diff --git a/TODO.md b/TODO.md index e4b6519..854a0f1 100644 --- a/TODO.md +++ b/TODO.md @@ -14,6 +14,8 @@ - [x] Extract - [x] Viz - [x] Math / Map Math + - [ ] Remove "By x" socket set let socket sets only be "Function"/"Expr"; then add a dynamic enum underneath to select "By x" based on data support. + - [ ] Filter the operations based on data support, ex. use positive-definiteness to guide cholesky. - [x] Math / Filter Math - [ ] Math / Reduce Math - [ ] Math / Operate Math @@ -34,8 +36,6 @@ - [x] Constants / Blender Constant - [ ] Web / Tidy3D Web Importer - - [ ] Change to output only a `FilePath`, which can be plugged into a Tidy3D File Importer. - - [ ] Implement caching, such that the file will only download if the file doesn't already exist. - [ ] Have a visual indicator for the current download status, with a manual re-download button. - [x] File Import / Material Import diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py index 4bd1c5d..8f37fd4 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py @@ -68,6 +68,9 @@ class FlowKind(enum.StrEnum): if kind == cls.LazyArrayRange: return value.rescale_to_unit(unit_system[socket_type]) + if kind == cls.Params: + return value.rescale_to_unit(unit_system[socket_type]) + msg = 'Tried to scale unknown kind' raise ValueError(msg) @@ -322,6 +325,28 @@ class LazyValueFuncFlow: supports_jax: bool = False supports_numba: bool = False + # Merging + def __or__( + self, + other: typ.Self, + ): + return LazyValueFuncFlow( + func=lambda *args, **kwargs: ( + self.func( + *list(args[: len(self.func_args)]), + **{k: v for k, v in kwargs.items() if k in self.func_kwargs}, + ), + other.func( + *list(args[len(self.func_args) :]), + **{k: v for k, v in kwargs.items() if k in other.func_kwargs}, + ), + ), + func_args=self.func_args + other.func_args, + func_kwargs=self.func_kwargs | other.func_kwargs, + supports_jax=self.supports_jax and other.supports_jax, + supports_numba=self.supports_numba and other.supports_numba, + ) + # Composition def compose_within( self, @@ -691,10 +716,21 @@ class ParamsFlow: func_args: list[typ.Any] = dataclasses.field(default_factory=list) func_kwargs: dict[str, typ.Any] = dataclasses.field(default_factory=dict) + def __or__( + self, + other: typ.Self, + ): + return ParamsFlow( + func_args=self.func_args + other.func_args, + func_kwargs=self.func_kwargs | other.func_kwargs, + ) + def compose_within( self, enclosing_func_args: list[tuple[type]] = (), enclosing_func_kwargs: dict[str, type] = MappingProxyType({}), + enclosing_func_arg_units: dict[str, type] = MappingProxyType({}), + enclosing_func_kwarg_units: dict[str, type] = MappingProxyType({}), ) -> typ.Self: return ParamsFlow( func_args=self.func_args + list(enclosing_func_args), @@ -718,26 +754,133 @@ class InfoFlow: return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()} @functools.cached_property - def dim_mathtypes(self) -> dict[str, int]: + def dim_mathtypes(self) -> dict[str, spux.MathType]: return { dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items() } @functools.cached_property - def dim_units(self) -> dict[str, int]: + def dim_units(self) -> dict[str, spux.Unit]: return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()} @functools.cached_property - def dim_idx_arrays(self) -> list[ArrayFlow]: + def dim_idx_arrays(self) -> list[jax.Array]: return [ dim_idx.realize().values if isinstance(dim_idx, LazyArrayRangeFlow) else dim_idx.values for dim_idx in self.dim_idx.values() ] - return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()} # Output Information - output_names: list[str] = dataclasses.field(default_factory=list) - output_mathtypes: dict[str, spux.MathType] = dataclasses.field(default_factory=dict) - output_units: dict[str, spux.Unit | None] = dataclasses.field(default_factory=dict) + output_name: str = dataclasses.field(default_factory=list) + output_shape: tuple[int, ...] | None = dataclasses.field(default=None) + output_mathtype: spux.MathType = dataclasses.field() + output_unit: spux.Unit | None = dataclasses.field() + + # Pinned Dimension Information + pinned_dim_names: list[str] = dataclasses.field(default_factory=list) + pinned_dim_values: dict[str, float | complex] = dataclasses.field( + default_factory=dict + ) + pinned_dim_mathtypes: dict[str, spux.MathType] = dataclasses.field( + default_factory=dict + ) + pinned_dim_units: dict[str, spux.Unit] = dataclasses.field(default_factory=dict) + + #################### + # - Methods + #################### + def delete_dimension(self, dim_name: str) -> typ.Self: + """Delete a dimension.""" + return InfoFlow( + # Dimensions + dim_names=[ + _dim_name for _dim_name in self.dim_names if _dim_name != dim_name + ], + dim_idx={ + _dim_name: dim_idx + for _dim_name, dim_idx in self.dim_idx.items() + if _dim_name != dim_name + }, + # Outputs + output_name=self.output_name, + output_shape=self.output_shape, + output_mathtype=self.output_mathtype, + output_unit=self.output_unit, + ) + + def swap_dimensions(self, dim_0_name: str, dim_1_name: str) -> typ.Self: + """Delete a dimension.""" + + # Compute Swapped Dimension Name List + def name_swapper(dim_name): + return ( + dim_name + if dim_name not in [dim_0_name, dim_1_name] + else {dim_0_name: dim_1_name, dim_1_name: dim_0_name}[dim_name] + ) + + dim_names = [name_swapper(dim_name) for dim_name in self.dim_names] + + # Compute Info + return InfoFlow( + # Dimensions + dim_names=dim_names, + dim_idx={dim_name: self.dim_idx[dim_name] for dim_name in dim_names}, + # Outputs + output_name=self.output_name, + output_shape=self.output_shape, + output_mathtype=self.output_mathtype, + output_unit=self.output_unit, + ) + + def set_output_mathtype(self, output_mathtype: spux.MathType) -> typ.Self: + """Set the MathType of a particular output name.""" + return InfoFlow( + dim_names=self.dim_names, + dim_idx=self.dim_idx, + # Outputs + output_name=self.output_name, + output_shape=self.output_shape, + output_mathtype=output_mathtype, + output_unit=self.output_unit, + ) + + def collapse_output( + self, + collapsed_name: str, + collapsed_mathtype: spux.MathType, + collapsed_unit: spux.Unit, + ) -> typ.Self: + return InfoFlow( + # Dimensions + dim_names=self.dim_names, + dim_idx=self.dim_idx, + output_name=collapsed_name, + output_shape=None, + output_mathtype=collapsed_mathtype, + output_unit=collapsed_unit, + ) + + @functools.cached_property + def shift_last_input(self): + """Shift the last input dimension to the output.""" + return InfoFlow( + # Dimensions + dim_names=self.dim_names[:-1], + dim_idx={ + dim_name: dim_idx + for dim_name, dim_idx in self.dim_idx.items() + if dim_name != self.dim_names[-1] + }, + # Outputs + output_name=self.output_name, + output_shape=( + (self.dim_lens[self.dim_names[-1]],) + if self.output_shape is None + else (self.dim_lens[self.dim_names[-1]], *self.output_shape) + ), + output_mathtype=self.output_mathtype, + output_unit=self.output_unit, + ) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/node_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/node_types.py index 5c5c00a..7429fcd 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/node_types.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/node_types.py @@ -15,6 +15,7 @@ class NodeType(blender_type_enum.BlenderTypeEnum): FilterMath = enum.auto() ReduceMath = enum.auto() OperateMath = enum.auto() + TransformMath = enum.auto() # Inputs WaveConstant = enum.auto() diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_units.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_units.py index 1ea0d33..0b7c16c 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_units.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_units.py @@ -260,7 +260,9 @@ SOCKET_UNITS = { } -def unit_to_socket_type(unit: spux.Unit) -> ST: +def unit_to_socket_type( + unit: spux.Unit | None, fallback_mathtype: spux.MathType | None = None +) -> ST: """Returns a SocketType that accepts the given unit. Only the unit-compatibility is taken into account; in the case of overlap, several the ordering of `SOCKET_UNITS` determines which is returned. @@ -269,6 +271,14 @@ def unit_to_socket_type(unit: spux.Unit) -> ST: Returns: **The first `SocketType` in `SOCKET_UNITS`, which contains the given unit as a valid possibility. """ + if unit is None and fallback_mathtype is not None: + return { + spux.MathType.Integer: ST.IntegerNumber, + spux.MathType.Rational: ST.RationalNumber, + spux.MathType.Real: ST.RealNumber, + spux.MathType.Complex: ST.ComplexNumber, + }[fallback_mathtype] + for socket_type, _units in SOCKET_UNITS.items(): if unit in _units['values'].values(): return socket_type diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py index cac504a..e60f48d 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py @@ -215,6 +215,15 @@ class ExtractDataNode(base.MaxwellSimNode): #################### # - UI #################### + def draw_label(self): + has_sim_data = self.sim_data_monitor_nametype is not None + has_monitor_data = self.monitor_data_components is not None + + if has_sim_data or has_monitor_data: + return f'Extract: {self.extract_filter}' + + return self.bl_label + def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None: """Draw node properties in the node. @@ -223,43 +232,6 @@ class ExtractDataNode(base.MaxwellSimNode): """ col.prop(self, self.blfields['extract_filter'], text='') - def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None: - """Draw dynamic information in the node, for user consideration. - - Parameters: - col: UI target for drawing. - """ - has_sim_data = self.sim_data_monitor_nametype is not None - has_monitor_data = self.monitor_data_components is not None - - if has_sim_data or has_monitor_data: - # Header - row = col.row() - row.alignment = 'CENTER' - if has_sim_data: - row.label(text=f'{len(self.sim_data_monitor_nametype)} Monitors') - elif has_monitor_data: - row.label(text=f'{self.monitor_data_type} Monitor Data') - - # Monitor Data Contents - ## TODO: More compact double-split - ## TODO: Output shape data. - ## TODO: Local ENUM_MANY tabs for visible column selection? - row = col.row() - box = row.box() - grid = box.grid_flow(row_major=True, columns=2, even_columns=True) - if has_sim_data: - for ( - monitor_name, - monitor_type, - ) in self.sim_data_monitor_nametype.items(): - grid.label(text=monitor_name) - grid.label(text=monitor_type.replace('Data', '')) - elif has_monitor_data: - for component_name in self.monitor_data_components: - grid.label(text=component_name) - grid.label(text=self.monitor_data_type) - #################### # - Events #################### @@ -416,9 +388,8 @@ class ExtractDataNode(base.MaxwellSimNode): else: return ct.FlowSignal.FlowPending - info_output_names = { - 'output_names': [props['extract_filter']], - } + info_output_name = props['extract_filter'] + info_output_shape = None # Compute InfoFlow from XArray ## XYZF: Field / Permittivity / FieldProjectionCartesian @@ -442,13 +413,14 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, - **info_output_names, - output_mathtypes={props['extract_filter']: spux.MathType.Complex}, - output_units={ - props['extract_filter']: spu.volt / spu.micrometer + output_name=props['extract_filter'], + output_shape=None, + output_mathtype=spux.MathType.Complex, + output_unit=( + spu.volt / spu.micrometer if props['monitor_data_type'] == 'Field' else None - }, + ), ) ## XYZT: FieldTime @@ -468,17 +440,14 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, - **info_output_names, - output_mathtypes={props['extract_filter']: spux.MathType.Complex}, - output_units={ - props['extract_filter']: ( - spu.volt / spu.micrometer - if props['extract_filter'].startswith('E') - else spu.ampere / spu.micrometer - ) + output_name=props['extract_filter'], + output_shape=None, + output_mathtype=spux.MathType.Complex, + output_unit=( + spu.volt / spu.micrometer if props['monitor_data_type'] == 'Field' else None - }, + ), ) ## F: Flux @@ -492,9 +461,10 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, - **info_output_names, - output_mathtypes={props['extract_filter']: spux.MathType.Real}, - output_units={props['extract_filter']: spu.watt}, + output_name=props['extract_filter'], + output_shape=None, + output_mathtype=spux.MathType.Real, + output_unit=spu.watt, ) ## T: FluxTime @@ -508,9 +478,10 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, - **info_output_names, - output_mathtypes={props['extract_filter']: spux.MathType.Real}, - output_units={props['extract_filter']: spu.watt}, + output_name=props['extract_filter'], + output_shape=None, + output_mathtype=spux.MathType.Real, + output_unit=spu.watt, ) ## RThetaPhiF: FieldProjectionAngle @@ -537,15 +508,14 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, - **info_output_names, - output_mathtypes={props['extract_filter']: spux.MathType.Real}, - output_units={ - props['extract_filter']: ( - spu.volt / spu.micrometer - if props['extract_filter'].startswith('E') - else spu.ampere / spu.micrometer - ) - }, + output_name=props['extract_filter'], + output_shape=None, + output_mathtype=spux.MathType.Real, + output_unit=( + spu.volt / spu.micrometer + if props['extract_filter'].startswith('E') + else spu.ampere / spu.micrometer + ), ) ## UxUyRF: FieldProjectionKSpace @@ -570,15 +540,14 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, - **info_output_names, - output_mathtypes={props['extract_filter']: spux.MathType.Real}, - output_units={ - props['extract_filter']: ( - spu.volt / spu.micrometer - if props['extract_filter'].startswith('E') - else spu.ampere / spu.micrometer - ) - }, + output_name=props['extract_filter'], + output_shape=None, + output_mathtype=spux.MathType.Real, + output_unit=( + spu.volt / spu.micrometer + if props['extract_filter'].startswith('E') + else spu.ampere / spu.micrometer + ), ) ## OrderxOrderyF: Diffraction @@ -600,15 +569,14 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, - **info_output_names, - output_mathtypes={props['extract_filter']: spux.MathType.Real}, - output_units={ - props['extract_filter']: ( - spu.volt / spu.micrometer - if props['extract_filter'].startswith('E') - else spu.ampere / spu.micrometer - ) - }, + output_name=props['extract_filter'], + output_shape=None, + output_mathtype=spux.MathType.Real, + output_unit=( + spu.volt / spu.micrometer + if props['extract_filter'].startswith('E') + else spu.ampere / spu.micrometer + ), ) msg = f'Unsupported Monitor Data Type {props["monitor_data_type"]} in "FlowKind.Info" of "{self.bl_label}"' diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py index 0119e77..0517859 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py @@ -1,14 +1,16 @@ -from . import filter_math, map_math, operate_math, reduce_math +from . import filter_math, map_math, operate_math, reduce_math, transform_math BL_REGISTER = [ *map_math.BL_REGISTER, *filter_math.BL_REGISTER, *reduce_math.BL_REGISTER, *operate_math.BL_REGISTER, + *transform_math.BL_REGISTER, ] BL_NODES = { **map_math.BL_NODES, **filter_math.BL_NODES, **reduce_math.BL_NODES, **operate_math.BL_NODES, + **transform_math.BL_NODES, } diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py index 20281c6..4407ba4 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py @@ -1,3 +1,5 @@ +"""Declares `FilterMathNode`.""" + import enum import typing as typ @@ -15,11 +17,21 @@ log = logger.get(__name__) class FilterMathNode(base.MaxwellSimNode): - """Reduces the dimensionality of data. + r"""Applies a function that operates on the shape of the array. + + The shape, type, and interpretation of the input/output data is dynamically shown. + + # Socket Sets + ## Dimensions + Alter the dimensions of the array. + + ## Interpret + Only alter the interpretation of the array data, which guides what it can be used for. + + These operations are **zero cost**, since the data itself is untouched. Attributes: operation: Operation to apply to the input. - dim: Dims to use when filtering data """ node_type = ct.NodeType.FilterMath @@ -29,8 +41,8 @@ class FilterMathNode(base.MaxwellSimNode): 'Data': sockets.DataSocketDef(format='jax'), } input_socket_sets: typ.ClassVar = { - 'By Dim': {}, - 'By Dim Value': {}, + 'Interpret': {}, + 'Dimensions': {}, } output_sockets: typ.ClassVar = { 'Data': sockets.DataSocketDef(format='jax'), @@ -43,10 +55,17 @@ class FilterMathNode(base.MaxwellSimNode): prop_ui=True, enum_cb=lambda self, _: self.search_operations() ) - dim: enum.Enum = bl_cache.BLField( + # Dimension Selection + dim_0: enum.Enum = bl_cache.BLField( + None, prop_ui=True, enum_cb=lambda self, _: self.search_dims() + ) + dim_1: enum.Enum = bl_cache.BLField( None, prop_ui=True, enum_cb=lambda self, _: self.search_dims() ) + #################### + # - Computed + #################### @property def data_info(self) -> ct.InfoFlow | None: info = self._compute_input('Data', kind=ct.FlowKind.Info) @@ -60,87 +79,119 @@ class FilterMathNode(base.MaxwellSimNode): #################### def search_operations(self) -> list[tuple[str, str, str]]: items = [] - if self.active_socket_set == 'By Dim': + if self.active_socket_set == 'Interpret': items += [ - ('SQUEEZE', 'del a | #=1', 'Squeeze'), + ('DIM_TO_VEC', '→ Vector', 'Shift last dimension to output.'), + ('DIMS_TO_MAT', '→ Matrix', 'Shift last 2 dimensions to output.'), ] - if self.active_socket_set == 'By Dim Value': + elif self.active_socket_set == 'Dimensions': items += [ - ('FIX', 'del a | i≈v', 'Fix Coordinate'), + ('PIN_LEN_ONE', 'pinₐ =1', 'Remove a len(1) dimension'), + ( + 'PIN', + 'pinₐ ≈v', + 'Remove a len(n) dimension by selecting an index', + ), + ('SWAP', 'a₁ ↔ a₂', 'Swap the position of two dimensions'), ] return [(*item, '', i) for i, item in enumerate(items)] #################### - # - Dim Search + # - Dimensions Search #################### def search_dims(self) -> list[ct.BLEnumElement]: - if self.data_info is not None: + if self.data_info is None: + return [] + + if self.operation == 'PIN_LEN_ONE': dims = [ - (dim_name, dim_name, dim_name, '', i) - for i, dim_name in enumerate(self.data_info.dim_names) + (dim_name, dim_name, f'Dimension "{dim_name}" of length 1') + for dim_name in self.data_info.dim_names + if self.data_info.dim_lens[dim_name] == 1 ] + elif self.operation in ['PIN', 'SWAP']: + dims = [ + (dim_name, dim_name, f'Dimension "{dim_name}"') + for dim_name in self.data_info.dim_names + ] + else: + return [] - # Squeeze: Dimension Must Have Length=1 - ## We must also correct the "NUMBER" of the enum. - if self.operation == 'SQUEEZE': - filtered_dims = [ - dim for dim in dims if self.data_info.dim_lens[dim[0]] == 1 - ] - return [(*dim[:-1], i) for i, dim in enumerate(filtered_dims)] - - return dims - return [] + return [(*dim, '', i) for i, dim in enumerate(dims)] #################### # - UI #################### + def draw_label(self): + labels = { + 'PIN_LEN_ONE': lambda: f'Filter: Pin {self.dim_0} (len=1)', + 'PIN': lambda: f'Filter: Pin {self.dim_0}', + 'SWAP': lambda: f'Filter: Swap {self.dim_0}|{self.dim_1}', + 'DIM_TO_VEC': lambda: 'Filter: -> Vector', + 'DIMS_TO_MAT': lambda: 'Filter: -> Matrix', + } + + if (label := labels.get(self.operation)) is not None: + return label() + + return self.bl_label + def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: layout.prop(self, self.blfields['operation'], text='') - if self.data_info is not None and self.data_info.dim_names: - layout.prop(self, self.blfields['dim'], text='') + + if self.active_socket_set == 'Dimensions': + if self.operation in ['PIN_LEN_ONE', 'PIN']: + layout.prop(self, self.blfields['dim_0'], text='') + + if self.operation == 'SWAP': + row = layout.row(align=True) + row.prop(self, self.blfields['dim_0'], text='') + row.prop(self, self.blfields['dim_1'], text='') #################### # - Events #################### @events.on_value_changed( - socket_name='Data', prop_name='active_socket_set', run_on_init=True, - input_sockets={'Data'}, ) - def on_any_change(self, input_sockets: dict): - if all( - not ct.FlowSignal.check_single( - input_socket_value, ct.FlowSignal.FlowPending - ) - for input_socket_value in input_sockets.values() - ): - self.operation = bl_cache.Signal.ResetEnumItems - self.dim = bl_cache.Signal.ResetEnumItems + def on_socket_set_changed(self): + self.operation = bl_cache.Signal.ResetEnumItems + + @events.on_value_changed( + # Trigger + socket_name='Data', + prop_name={'active_socket_set', 'operation'}, + run_on_init=True, + # Loaded + props={'operation'}, + ) + def on_any_change(self, props: dict) -> None: + self.dim_0 = bl_cache.Signal.ResetEnumItems + self.dim_1 = bl_cache.Signal.ResetEnumItems @events.on_value_changed( socket_name='Data', - prop_name='dim', + prop_name={'dim_0', 'dim_1', 'operation'}, ## run_on_init: Implicitly triggered. - props={'active_socket_set', 'dim'}, + props={'operation', 'dim_0', 'dim_1'}, input_sockets={'Data'}, input_socket_kinds={'Data': ct.FlowKind.Info}, ) def on_dim_change(self, props: dict, input_sockets: dict): - if input_sockets['Data'] == ct.FlowSignal.FlowPending: + has_data = not ct.FlowSignal.check(input_sockets['Data']) + if not has_data: return - # Add/Remove Input Socket "Value" - if ( - not ct.Flowsignal.check(input_sockets['Data']) - and props['active_socket_set'] == 'By Dim Value' - and props['dim'] != 'NONE' - ): + # "Dimensions"|"PIN": Add/Remove Input Socket + if props['operation'] == 'PIN' and props['dim_0'] != 'NONE': # Get Current and Wanted Socket Defs current_bl_socket = self.loose_input_sockets.get('Value') wanted_socket_def = sockets.SOCKET_DEFS[ - ct.unit_to_socket_type(input_sockets['Data'].dim_idx[props['dim']].unit) + ct.unit_to_socket_type( + input_sockets['Data'].dim_idx[props['dim_0']].unit + ) ] # Determine Whether to Declare New Loose Input SOcket @@ -151,7 +202,7 @@ class FilterMathNode(base.MaxwellSimNode): ): self.loose_input_sockets = { 'Value': wanted_socket_def(), - } ## TODO: Can we do the boilerplate in base.py? + } elif self.loose_input_sockets: self.loose_input_sockets = {} @@ -161,40 +212,51 @@ class FilterMathNode(base.MaxwellSimNode): @events.computes_output_socket( 'Data', kind=ct.FlowKind.LazyValueFunc, - props={'active_socket_set', 'operation', 'dim'}, + props={'operation', 'dim_0', 'dim_1'}, input_sockets={'Data'}, input_socket_kinds={'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}}, ) def compute_data(self, props: dict, input_sockets: dict): - # Retrieve Inputs lazy_value_func = input_sockets['Data'][ct.FlowKind.LazyValueFunc] info = input_sockets['Data'][ct.FlowKind.Info] # Check Flow - if ( - any(ct.FlowSignal.check(inp) for inp in [info, lazy_value_func]) - or props['operation'] == 'NONE' - ): + if any(ct.FlowSignal.check(inp) for inp in [info, lazy_value_func]): return ct.FlowSignal.FlowPending - # Compute Bound/Free Parameters - func_args = [int] if props['active_socket_set'] == 'By Dim Value' else [] - axis = info.dim_names.index(props['dim']) + # Compute Function Arguments + operation = props['operation'] + if operation == 'NONE': + return ct.FlowSignal.FlowPending - # Select Function - filter_func: typ.Callable[[jax.Array], jax.Array] = { - 'By Dim': {'SQUEEZE': lambda data: jnp.squeeze(data, axis)}, - 'By Dim Value': { - 'FIX': lambda data, fixed_axis_idx: jnp.take( - data, fixed_axis_idx, axis=axis - ) - }, - }[props['active_socket_set']][props['operation']] + ## Dimension(s) + dim_0 = props['dim_0'] + dim_1 = props['dim_1'] + if operation in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE': + return ct.FlowSignal.FlowPending + if operation == 'SWAP' and dim_1 == 'NONE': + return ct.FlowSignal.FlowPending + + ## Axis/Axes + axis_0 = info.dim_names.index(dim_0) if dim_0 != 'NONE' else None + axis_1 = info.dim_names.index(dim_1) if dim_1 != 'NONE' else None + + # Compose Output Function + filter_func = { + # Dimensions + 'PIN_LEN_ONE': lambda data: jnp.squeeze(data, axis_0), + 'PIN': lambda data, fixed_axis_idx: jnp.take( + data, fixed_axis_idx, axis=axis_0 + ), + 'SWAP': lambda data: jnp.swapaxes(data, axis_0, axis_1), + # Interpret + 'DIM_TO_VEC': lambda data: data, + 'DIMS_TO_MAT': lambda data: data, + }[props['operation']] - # Compose Function for Output return lazy_value_func.compose_within( filter_func, - enclosing_func_args=func_args, + enclosing_func_args=[int] if operation == 'PIN' else [], supports_jax=True, ) @@ -207,7 +269,6 @@ class FilterMathNode(base.MaxwellSimNode): }, ) def compute_array(self, output_sockets: dict) -> ct.ArrayFlow: - # Retrieve Inputs lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc] params = output_sockets['Data'][ct.FlowKind.Params] @@ -215,57 +276,54 @@ class FilterMathNode(base.MaxwellSimNode): if any(ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]): return ct.FlowSignal.FlowPending - # Compute Array return ct.ArrayFlow( values=lazy_value_func.func_jax(*params.func_args, **params.func_kwargs), - unit=None, ## TODO: Unit Propagation + unit=None, ) #################### - # - Compute Auxiliary: Info / Params + # - Compute Auxiliary: Info #################### @events.computes_output_socket( 'Data', kind=ct.FlowKind.Info, - props={'active_socket_set', 'dim', 'operation'}, + props={'dim_0', 'dim_1', 'operation'}, input_sockets={'Data'}, input_socket_kinds={'Data': ct.FlowKind.Info}, ) def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow: - # Retrieve Inputs info = input_sockets['Data'] # Check Flow - if ct.FlowSignal.check(info) or props['dim'] == 'NONE': + if ct.FlowSignal.check(info): return ct.FlowSignal.FlowPending - # Compute Information - ## Compute Info w/By-Operation Change to Dimensions - axis = info.dim_names.index(props['dim']) + # Collect Information + dim_0 = props['dim_0'] + dim_1 = props['dim_1'] - if (props['active_socket_set'], props['operation']) in [ - ('By Dim', 'SQUEEZE'), - ('By Dim Value', 'FIX'), - ] and info.dim_names: - return ct.InfoFlow( - dim_names=info.dim_names[:axis] + info.dim_names[axis + 1 :], - dim_idx={ - dim_name: dim_idx - for dim_name, dim_idx in info.dim_idx.items() - if dim_name != props['dim'] - }, - output_names=info.output_names, - output_mathtypes=info.output_mathtypes, - output_units=info.output_units, - ) + if props['operation'] in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE': + return ct.FlowSignal.FlowPending + if props['operation'] == 'SWAP' and dim_1 == 'NONE': + return ct.FlowSignal.FlowPending - msg = f'Active socket set {props["active_socket_set"]} and operation {props["operation"]} don\'t have an InfoFlow defined' - raise RuntimeError(msg) + return { + # Dimensions + 'PIN_LEN_ONE': lambda: info.delete_dimension(dim_0), + 'PIN': lambda: info.delete_dimension(dim_0), + 'SWAP': lambda: info.swap_dimensions(dim_0, dim_1), + # Interpret + 'DIM_TO_VEC': lambda: info.shift_last_input, + 'DIMS_TO_MAT': lambda: info.shift_last_input.shift_last_input, + }[props['operation']]() + #################### + # - Compute Auxiliary: Info + #################### @events.computes_output_socket( 'Data', kind=ct.FlowKind.Params, - props={'active_socket_set', 'dim', 'operation'}, + props={'dim_0', 'dim_1', 'operation'}, input_sockets={'Data', 'Value'}, input_socket_kinds={'Data': {ct.FlowKind.Info, ct.FlowKind.Params}}, input_sockets_optional={'Value': True}, @@ -273,35 +331,33 @@ class FilterMathNode(base.MaxwellSimNode): def compute_composed_params( self, props: dict, input_sockets: dict ) -> ct.ParamsFlow: - # Retrieve Inputs info = input_sockets['Data'][ct.FlowKind.Info] params = input_sockets['Data'][ct.FlowKind.Params] + # Check Flow if any(ct.FlowSignal.check(inp) for inp in [info, params]): return ct.FlowSignal.FlowPending - # Compute Composed Parameters - ## -> Only operations that add parameters. - ## -> A dimension must be selected. - ## -> There must be an input value. - if ( - (props['active_socket_set'], props['operation']) - in [ - ('By Dim Value', 'FIX'), - ] - and props['dim'] != 'NONE' - and not ct.FlowSignal.check(input_sockets['Value']) - ): - # Compute IDX Corresponding to Coordinate Value - ## -> Each dimension declares a unit-aware real number at each index. - ## -> "Value" is a unit-aware real number from loose input socket. - ## -> This finds the dimensional index closest to "Value". - ## Total Effect: Indexing by a unit-aware real number. - nearest_idx_to_value = info.dim_idx[props['dim']].nearest_idx_of( + # Collect Information + ## Dimensions + dim_0 = props['dim_0'] + dim_1 = props['dim_1'] + + if props['operation'] in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE': + return ct.FlowSignal.FlowPending + if props['operation'] == 'SWAP' and dim_1 == 'NONE': + return ct.FlowSignal.FlowPending + + ## Pinned Value + pinned_value = input_sockets['Value'] + has_pinned_value = not ct.FlowSignal.check(pinned_value) + + if props['operation'] == 'PIN' and has_pinned_value: + # Compute IDX Corresponding to Dimension Index + nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of( input_sockets['Value'], require_sorted=True ) - # Compose Parameters return params.compose_within(enclosing_func_args=[nearest_idx_to_value]) return params diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py index 8752965..58f47e0 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py @@ -138,6 +138,7 @@ class MapMathNode(base.MaxwellSimNode): ('SQ', 'v²', 'v^2 (by el)'), ('SQRT', '√v', 'sqrt(v) (by el)'), ('INV_SQRT', '1/√v', '1/sqrt(v) (by el)'), + None, # Trigonometry ('COS', 'cos v', 'cos(v) (by el)'), ('SIN', 'sin v', 'sin(v) (by el)'), @@ -148,6 +149,7 @@ class MapMathNode(base.MaxwellSimNode): ] elif self.active_socket_set in 'By Vector': items = [ + # Vector -> Number ('NORM_2', '||v||₂', 'norm(v, 2) (by Vec)'), ] elif self.active_socket_set == 'By Matrix': @@ -157,13 +159,16 @@ class MapMathNode(base.MaxwellSimNode): ('COND', 'κ(V)', 'cond(V) (by Mat)'), ('NORM_FRO', '||V||_F', 'norm(V, frobenius) (by Mat)'), ('RANK', 'rank V', 'rank(V) (by Mat)'), + None, # Matrix -> Array ('DIAG', 'diag V', 'diag(V) (by Mat)'), ('EIG_VALS', 'eigvals V', 'eigvals(V) (by Mat)'), ('SVD_VALS', 'svdvals V', 'diag(svd(V)) (by Mat)'), + None, # Matrix -> Matrix ('INV', 'V⁻¹', 'V^(-1) (by Mat)'), ('TRA', 'Vt', 'V^T (by Mat)'), + None, # Matrix -> Matrices ('QR', 'qr V', 'qr(V) -> Q·R (by Mat)'), ('CHOL', 'chol V', 'cholesky(V) -> V·V† (by Mat)'), @@ -175,7 +180,9 @@ class MapMathNode(base.MaxwellSimNode): msg = f'Active socket set {self.active_socket_set} is unknown' raise RuntimeError(msg) - return [(*item, '', i) for i, item in enumerate(items)] + return [ + (*item, '', i) if item is not None else None for i, item in enumerate(items) + ] def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: layout.prop(self, self.blfields['operation'], text='') @@ -185,8 +192,9 @@ class MapMathNode(base.MaxwellSimNode): #################### @events.on_value_changed( prop_name='active_socket_set', + run_on_init=True, ) - def on_operation_changed(self): + def on_socket_set_changed(self): self.operation = bl_cache.Signal.ResetEnumItems #################### @@ -204,11 +212,14 @@ class MapMathNode(base.MaxwellSimNode): input_sockets_optional={'Mapper': True}, ) def compute_data(self, props: dict, input_sockets: dict): + has_data = not ct.FlowSignal.check(input_sockets['Data']) if ( - ct.FlowSignal.check(input_sockets['Data']) or props['operation'] == 'NONE' - ) or ( - props['active_socket_set'] == 'Expr' - and ct.FlowSignal.check(input_sockets['Mapper']) + not has_data + or props['operation'] == 'NONE' + or ( + props['active_socket_set'] == 'Expr' + and ct.FlowSignal.check(input_sockets['Mapper']) + ) ): return ct.FlowSignal.FlowPending @@ -238,14 +249,14 @@ class MapMathNode(base.MaxwellSimNode): 'NORM_FRO': lambda data: jnp.linalg.matrix_norm(data, ord='fro'), 'RANK': lambda data: jnp.linalg.matrix_rank(data), # Matrix -> Vec - 'DIAG': lambda data: jnp.diag(data), - 'EIG_VALS': lambda data: jnp.eigvals(data), - 'SVD_VALS': lambda data: jnp.svdvals(data), + 'DIAG': lambda data: jnp.diagonal(data, axis1=-2, axis2=-1), + 'EIG_VALS': lambda data: jnp.linalg.eigvals(data), + 'SVD_VALS': lambda data: jnp.linalg.svdvals(data), # Matrix -> Matrix - 'INV': lambda data: jnp.inv(data), + 'INV': lambda data: jnp.linalg.inv(data), 'TRA': lambda data: jnp.matrix_transpose(data), # Matrix -> Matrices - 'QR': lambda data: jnp.inv(data), + 'QR': lambda data: jnp.linalg.qr(data), 'CHOL': lambda data: jnp.linalg.cholesky(data), 'SVD': lambda data: jnp.linalg.svd(data), }, @@ -298,28 +309,53 @@ class MapMathNode(base.MaxwellSimNode): return ct.FlowSignal.FlowPending # Complex -> Real - if props['active_socket_set'] == 'By Element': - if props['operation'] in [ - 'REAL', - 'IMAG', - 'ABS', - ]: - return ct.InfoFlow( - dim_names=info.dim_names, - dim_idx=info.dim_idx, - output_names=info.output_names, - output_mathtypes={ - output_name: ( - spux.MathType.Real - if output_mathtype == spux.MathType.Complex - else output_mathtype - ) - for output_name, output_mathtype in info.output_mathtypes.items() - }, - output_units=info.output_units, + if props['active_socket_set'] == 'By Element' and props['operation'] in [ + 'REAL', + 'IMAG', + 'ABS', + ]: + return info.set_output_mathtype(spux.MathType.Real) + + if props['active_socket_set'] == 'By Vector' and props['operation'] in [ + 'NORM_2' + ]: + return { + 'NORM_2': lambda: info.collapse_output( + collapsed_name=f'||{info.output_name}||₂', + collapsed_mathtype=spux.MathType.Real, + collapsed_unit=info.output_unit, ) - if props['active_socket_set'] == 'By Vector': - pass + }[props['operation']]() + + if props['active_socket_set'] == 'By Matrix' and props['operation'] in [ + 'DET', + 'COND', + 'NORM_FRO', + 'RANK', + ]: + return { + 'DET': lambda: info.collapse_output( + collapsed_name=f'det {info.output_name}', + collapsed_mathtype=info.output_mathtype, + collapsed_unit=info.output_unit, + ), + 'COND': lambda: info.collapse_output( + collapsed_name=f'κ({info.output_name})', + collapsed_mathtype=spux.MathType.Real, + collapsed_unit=None, + ), + 'NORM_FRO': lambda: info.collapse_output( + collapsed_name=f'||({info.output_name}||_F', + collapsed_mathtype=spux.MathType.Real, + collapsed_unit=info.output_unit, + ), + 'RANK': lambda: info.collapse_output( + collapsed_name=f'rank {info.output_name}', + collapsed_mathtype=spux.MathType.Integer, + collapsed_unit=None, + ), + }[props['operation']]() + return info @events.computes_output_socket( diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py index 512f639..774ec02 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py @@ -1,9 +1,12 @@ +import enum import typing as typ import bpy import jax.numpy as jnp +import sympy as sp -from blender_maxwell.utils import logger +from blender_maxwell.utils import bl_cache, logger +from blender_maxwell.utils import extra_sympy_units as spux from .... import contracts as ct from .... import sockets @@ -13,120 +16,465 @@ log = logger.get(__name__) class OperateMathNode(base.MaxwellSimNode): + r"""Applies a function that depends on two inputs. + + Attributes: + category: The category of operations to apply to the inputs. + **Only valid** categories can be chosen. + operation: The actual operation to apply to the inputs. + **Only valid** operations can be chosen. + """ + node_type = ct.NodeType.OperateMath bl_label = 'Operate Math' input_socket_sets: typ.ClassVar = { - 'Elementwise': { - 'Data L': sockets.AnySocketDef(), - 'Data R': sockets.AnySocketDef(), + 'Expr | Expr': { + 'Expr L': sockets.ExprSocketDef(), + 'Expr R': sockets.ExprSocketDef(), }, - ## TODO: Filter-array building operations - 'Vec-Vec': { - 'Data L': sockets.AnySocketDef(), - 'Data R': sockets.AnySocketDef(), + 'Data | Data': { + 'Data L': sockets.DataSocketDef( + format='jax', default_show_info_columns=False + ), + 'Data R': sockets.DataSocketDef( + format='jax', default_show_info_columns=False + ), }, - 'Mat-Vec': { - 'Data L': sockets.AnySocketDef(), - 'Data R': sockets.AnySocketDef(), + 'Expr | Data': { + 'Expr L': sockets.ExprSocketDef(), + 'Data R': sockets.DataSocketDef( + format='jax', default_show_info_columns=False + ), }, } - output_sockets: typ.ClassVar = { - 'Data': sockets.AnySocketDef(), + output_socket_sets: typ.ClassVar = { + 'Expr | Expr': { + 'Expr': sockets.ExprSocketDef(), + }, + 'Data | Data': { + 'Data': sockets.DataSocketDef( + format='jax', default_show_info_columns=False + ), + }, + 'Expr | Data': { + 'Data': sockets.DataSocketDef( + format='jax', default_show_info_columns=False + ), + }, } #################### # - Properties #################### - operation: bpy.props.EnumProperty( - name='Op', - description='Operation to apply to the two inputs', - items=lambda self, _: self.search_operations(), - update=lambda self, context: self.on_prop_changed('operation', context), + category: enum.Enum = bl_cache.BLField( + prop_ui=True, enum_cb=lambda self, _: self.search_categories() ) - def search_operations(self) -> list[tuple[str, str, str]]: + operation: enum.Enum = bl_cache.BLField( + prop_ui=True, enum_cb=lambda self, _: self.search_operations() + ) + + def search_categories(self) -> list[ct.BLEnumElement]: + """Deduce and return a list of valid categories for the current socket set and input data.""" + data_l_info = self._compute_input( + 'Data L', kind=ct.FlowKind.Info, optional=True + ) + data_r_info = self._compute_input( + 'Data R', kind=ct.FlowKind.Info, optional=True + ) + + has_data_l_info = not ct.FlowSignal.check(data_l_info) + has_data_r_info = not ct.FlowSignal.check(data_r_info) + + # Categories by Socket Set + NUMBER_NUMBER = ( + 'Number | Number', + 'Number | Number', + 'Operations between numerical elements', + ) + NUMBER_VECTOR = ( + 'Number | Vector', + 'Number | Vector', + 'Operations between numerical and vector elements', + ) + NUMBER_MATRIX = ( + 'Number | Matrix', + 'Number | Matrix', + 'Operations between numerical and matrix elements', + ) + VECTOR_VECTOR = ( + 'Vector | Vector', + 'Vector | Vector', + 'Operations between vector elements', + ) + MATRIX_VECTOR = ( + 'Matrix | Vector', + 'Matrix | Vector', + 'Operations between vector and matrix elements', + ) + MATRIX_MATRIX = ( + 'Matrix | Matrix', + 'Matrix | Matrix', + 'Operations between matrix elements', + ) + categories = [] + + ## Expr | Expr + if self.active_socket_set == 'Expr | Expr': + return [NUMBER_NUMBER] + + ## Data | Data + if ( + self.active_socket_set == 'Data | Data' + and has_data_l_info + and has_data_r_info + ): + # Check Valid Broadcasting + ## Number | Number + if data_l_info.output_shape is None and data_r_info.output_shape is None: + categories = [NUMBER_NUMBER] + + ## Number | Vector + elif ( + data_l_info.output_shape is None and len(data_r_info.output_shape) == 1 + ): + categories = [NUMBER_VECTOR] + + ## Number | Matrix + elif ( + data_l_info.output_shape is None and len(data_r_info.output_shape) == 2 + ): # noqa: PLR2004 + categories = [NUMBER_MATRIX] + + ## Vector | Vector + elif ( + len(data_l_info.output_shape) == 1 + and len(data_r_info.output_shape) == 1 + ): + categories = [VECTOR_VECTOR] + + ## Matrix | Vector + elif ( + len(data_l_info.output_shape) == 2 # noqa: PLR2004 + and len(data_r_info.output_shape) == 1 + ): + categories = [MATRIX_VECTOR] + + ## Matrix | Matrix + elif ( + len(data_l_info.output_shape) == 2 # noqa: PLR2004 + and len(data_r_info.output_shape) == 2 # noqa: PLR2004 + ): + categories = [MATRIX_MATRIX] + + ## Expr | Data + if self.active_socket_set == 'Expr | Data' and has_data_r_info: + if data_r_info.output_shape is None: + categories = [NUMBER_NUMBER] + else: + categories = { + 1: [NUMBER_NUMBER, NUMBER_VECTOR], + 2: [NUMBER_NUMBER, NUMBER_MATRIX], + }[len(data_r_info.output_shape)] + + return [ + (*category, '', i) if category is not None else None + for i, category in enumerate(categories) + ] + + def search_operations(self) -> list[ct.BLEnumElement]: items = [] - if self.active_socket_set == 'Elementwise': - items = [ - ('ADD', 'Add', 'L + R (by el)'), - ('SUB', 'Subtract', 'L - R (by el)'), - ('MUL', 'Multiply', 'L · R (by el)'), - ('DIV', 'Divide', 'L ÷ R (by el)'), - ('POW', 'Power', 'L^R (by el)'), - ('FMOD', 'Trunc Modulo', 'fmod(L,R) (by el)'), - ('ATAN2', 'atan2', 'atan2(L,R) (by el)'), - ('HEAVISIDE', 'Heaviside', '{0|L<0 1|L>0 R|L=0} (by el)'), + if self.category in ['Number | Number', 'Number | Vector', 'Number | Matrix']: + items += [ + ('ADD', 'L + R', 'Add'), + ('SUB', 'L - R', 'Subtract'), + ('MUL', 'L · R', 'Multiply'), + ('DIV', 'L ÷ R', 'Divide'), + ('POW', 'L^R', 'Power'), + ('ATAN2', 'atan2(L,R)', 'atan2(L,R)'), ] - elif self.active_socket_set in 'Vec | Vec': - items = [ - ('DOT', 'Dot', 'L · R'), - ('CROSS', 'Cross', 'L x R (by last-axis'), + if self.category in 'Vector | Vector': + if items: + items += [None] + items += [ + ('VEC_VEC_DOT', 'L · R', 'Vector-Vector Product'), + ('CROSS', 'L x R', 'Cross Product'), + ('PROJ', 'proj(L, R)', 'Projection'), ] - elif self.active_socket_set == 'Mat | Vec': - items = [ - ('DOT', 'Dot', 'L · R'), - ('LIN_SOLVE', 'Lin Solve', 'Lx = R -> x (by last-axis of R)'), - ('LSQ_SOLVE', 'LSq Solve', 'Lx = R ~> x (by last-axis of R)'), + if self.category == 'Matrix | Vector': + if items: + items += [None] + items += [ + ('MAT_VEC_DOT', 'L · R', 'Matrix-Vector Product'), + ('LIN_SOLVE', 'Lx = R -> x', 'Linear Solve'), + ('LSQ_SOLVE', 'Lx = R ~> x', 'Least Squares Solve'), ] - return items + if self.category == 'Matrix | Matrix': + if items: + items += [None] + items += [ + ('MAT_MAT_DOT', 'L · R', 'Matrix-Matrix Product'), + ] + + return [ + (*item, '', i) if item is not None else None for i, item in enumerate(items) + ] def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: - layout.prop(self, 'operation') + layout.prop(self, self.blfields['category'], text='') + layout.prop(self, self.blfields['operation'], text='') #################### - # - Properties + # - Events + #################### + @events.on_value_changed( + # Trigger + socket_name={'Expr L', 'Expr R', 'Data L', 'Data R'}, + prop_name='active_socket_set', + run_on_init=True, + ) + def on_socket_set_changed(self) -> None: + # Recompute Valid Categories + self.category = bl_cache.Signal.ResetEnumItems + self.operation = bl_cache.Signal.ResetEnumItems + + @events.on_value_changed( + prop_name='category', + run_on_init=True, + ) + def on_category_changed(self) -> None: + self.operation = bl_cache.Signal.ResetEnumItems + + #################### + # - Output + #################### + @events.computes_output_socket( + 'Expr', + kind=ct.FlowKind.Value, + props={'operation'}, + input_sockets={'Expr L', 'Expr R'}, + ) + def compute_expr(self, props: dict, input_sockets: dict): + expr_l = input_sockets['Expr L'] + expr_r = input_sockets['Expr R'] + + return { + 'ADD': lambda: expr_l + expr_r, + 'SUB': lambda: expr_l - expr_r, + 'MUL': lambda: expr_l * expr_r, + 'DIV': lambda: expr_l / expr_r, + 'POW': lambda: expr_l**expr_r, + 'ATAN2': lambda: sp.atan2(expr_r, expr_l), + }[props['operation']]() + + @events.computes_output_socket( + 'Data', + kind=ct.FlowKind.LazyValueFunc, + props={'operation'}, + input_sockets={'Data L', 'Data R'}, + input_socket_kinds={ + 'Data L': ct.FlowKind.LazyValueFunc, + 'Data R': ct.FlowKind.LazyValueFunc, + }, + input_sockets_optional={ + 'Data L': True, + 'Data R': True, + }, + ) + def compute_data(self, props: dict, input_sockets: dict): + data_l = input_sockets['Data L'] + data_r = input_sockets['Data R'] + has_data_l = not ct.FlowSignal.check(data_l) + + mapping_func = { + # Number | * + 'ADD': lambda datas: datas[0] + datas[1], + 'SUB': lambda datas: datas[0] - datas[1], + 'MUL': lambda datas: datas[0] * datas[1], + 'DIV': lambda datas: datas[0] / datas[1], + 'POW': lambda datas: datas[0] ** datas[1], + 'ATAN2': lambda datas: jnp.atan2(datas[1], datas[0]), + # Vector | Vector + 'VEC_VEC_DOT': lambda datas: jnp.matmul(datas[0], datas[1]), + 'CROSS': lambda datas: jnp.cross(datas[0], datas[1]), + # Matrix | Vector + 'MAT_VEC_DOT': lambda datas: jnp.matmul(datas[0], datas[1]), + 'LIN_SOLVE': lambda datas: jnp.linalg.solve(datas[0], datas[1]), + 'LSQ_SOLVE': lambda datas: jnp.linalg.lstsq(datas[0], datas[1]), + # Matrix | Matrix + 'MAT_MAT_DOT': lambda datas: jnp.matmul(datas[0], datas[1]), + }[props['operation']] + + # Compose by Socket Set + ## Data | Data + if has_data_l: + return (data_l | data_r).compose_within( + mapping_func, + supports_jax=True, + ) + + ## Expr | Data + expr_l_lazy_value_func = ct.LazyValueFuncFlow( + func=lambda expr_l_value: expr_l_value, + func_args=[typ.Any], + supports_jax=True, + ) + return (expr_l_lazy_value_func | data_r).compose_within( + mapping_func, + supports_jax=True, + ) + + @events.computes_output_socket( + 'Data', + kind=ct.FlowKind.Array, + output_sockets={'Data'}, + output_socket_kinds={ + 'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params}, + }, + ) + def compute_array(self, output_sockets: dict) -> ct.ArrayFlow: + lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc] + params = output_sockets['Data'][ct.FlowKind.Params] + + has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func) + has_params = not ct.FlowSignal.check(params) + + if has_lazy_value_func and has_params: + return ct.ArrayFlow( + values=lazy_value_func.func_jax( + *params.func_args, **params.func_kwargs + ), + unit=None, + ) + + return ct.FlowSignal.FlowPending + + #################### + # - Auxiliary: Params #################### @events.computes_output_socket( 'Data', + kind=ct.FlowKind.Params, props={'operation'}, - input_sockets={'Data L', 'Data R'}, + input_sockets={'Expr L', 'Data L', 'Data R'}, + input_socket_kinds={ + 'Expr L': ct.FlowKind.Value, + 'Data L': {ct.FlowKind.Info, ct.FlowKind.Params}, + 'Data R': {ct.FlowKind.Info, ct.FlowKind.Params}, + }, + input_sockets_optional={ + 'Expr L': True, + 'Data L': True, + 'Data R': True, + }, ) - def compute_data(self, props: dict, input_sockets: dict): - if self.active_socket_set == 'Elementwise': - # Element-Wise Arithmetic - if props['operation'] == 'ADD': - return input_sockets['Data L'] + input_sockets['Data R'] - if props['operation'] == 'SUB': - return input_sockets['Data L'] - input_sockets['Data R'] - if props['operation'] == 'MUL': - return input_sockets['Data L'] * input_sockets['Data R'] - if props['operation'] == 'DIV': - return input_sockets['Data L'] / input_sockets['Data R'] + def compute_data_params( + self, props, input_sockets + ) -> ct.ParamsFlow | ct.FlowSignal: + expr_l = input_sockets['Expr L'] + data_l_info = input_sockets['Data L'][ct.FlowKind.Info] + data_l_params = input_sockets['Data L'][ct.FlowKind.Params] + data_r_info = input_sockets['Data R'][ct.FlowKind.Info] + data_r_params = input_sockets['Data R'][ct.FlowKind.Params] - # Element-Wise Arithmetic - if props['operation'] == 'POW': - return input_sockets['Data L'] ** input_sockets['Data R'] + has_expr_l = not ct.FlowSignal.check(expr_l) + has_data_l_info = not ct.FlowSignal.check(data_l_info) + has_data_l_params = not ct.FlowSignal.check(data_l_params) + has_data_r_info = not ct.FlowSignal.check(data_r_info) + has_data_r_params = not ct.FlowSignal.check(data_r_params) - # Binary Trigonometry - if props['operation'] == 'ATAN2': - return jnp.atan2(input_sockets['Data L'], input_sockets['Data R']) + #log.critical((props, input_sockets)) - # Special Functions - if props['operation'] == 'HEAVISIDE': - return jnp.heaviside(input_sockets['Data L'], input_sockets['Data R']) + # Compose by Socket Set + ## Data | Data + if ( + has_data_l_info + and has_data_l_params + and has_data_r_info + and has_data_r_params + ): + return data_l_params | data_r_params - # Linear Algebra - if self.active_socket_set in {'Vec-Vec', 'Mat-Vec'}: - if props['operation'] == 'DOT': - return jnp.dot(input_sockets['Data L'], input_sockets['Data R']) + ## Expr | Data + if has_expr_l and has_data_r_info and has_data_r_params: + operation = props['operation'] + data_unit = data_r_info.output_unit - elif self.active_socket_set == 'Vec-Vec': - if props['operation'] == 'CROSS': - return jnp.cross(input_sockets['Data L'], input_sockets['Data R']) + # By Operation + ## Add/Sub: Scale to Output Unit + if operation in ['ADD', 'SUB', 'MUL', 'DIV']: + if not spux.uses_units(expr_l): + value = spux.sympy_to_python(expr_l) + else: + value = spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit)) - elif self.active_socket_set == 'Mat-Vec': - if props['operation'] == 'LIN_SOLVE': - return jnp.linalg.lstsq( - input_sockets['Data L'], input_sockets['Data R'] - ) - if props['operation'] == 'LSQ_SOLVE': - return jnp.linalg.solve( - input_sockets['Data L'], input_sockets['Data R'] + return data_r_params.compose_within( + enclosing_func_args=[value], ) - msg = 'Invalid operation' - raise ValueError(msg) + ## Pow: Doesn't Exist (?) + ## -> See https://math.stackexchange.com/questions/4326081/units-of-the-exponential-function + if operation == 'POW': + return ct.FlowSignal.FlowPending + + ## atan2(): Only Length + ## -> Implicitly presume that Data L/R use length units. + if operation == 'ATAN2': + if not spux.uses_units(expr_l): + value = spux.sympy_to_python(expr_l) + else: + value = spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit)) + + return data_r_params.compose_within( + enclosing_func_args=[value], + ) + + return data_r_params.compose_within( + enclosing_func_args=[ + spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit)) + ] + ) + + return ct.FlowSignal.FlowPending + + #################### + # - Auxiliary: Info + #################### + @events.computes_output_socket( + 'Data', + kind=ct.FlowKind.Info, + input_sockets={'Expr L', 'Data L', 'Data R'}, + input_socket_kinds={ + 'Expr L': ct.FlowKind.Value, + 'Data L': ct.FlowKind.Info, + 'Data R': ct.FlowKind.Info, + }, + input_sockets_optional={ + 'Expr L': True, + 'Data L': True, + 'Data R': True, + }, + ) + def compute_data_info(self, input_sockets: dict) -> ct.InfoFlow: + expr_l = input_sockets['Expr L'] + data_l_info = input_sockets['Data L'] + data_r_info = input_sockets['Data R'] + + has_expr_l = not ct.FlowSignal.check(expr_l) + has_data_l_info = not ct.FlowSignal.check(data_l_info) + has_data_r_info = not ct.FlowSignal.check(data_r_info) + + # Info by Socket Set + ## Data | Data + if has_data_l_info and has_data_r_info: + return data_r_info + + ## Expr | Data + if has_expr_l and has_data_r_info: + return data_r_info + + return ct.FlowSignal.FlowPending #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py new file mode 100644 index 0000000..ef45362 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py @@ -0,0 +1,165 @@ +"""Declares `TransformMathNode`.""" + +import enum +import typing as typ + +import bpy +import jax +import jax.numpy as jnp +import sympy as sp + +from blender_maxwell.utils import bl_cache, logger +from blender_maxwell.utils import extra_sympy_units as spux + +from .... import contracts as ct +from .... import sockets +from ... import base, events + +log = logger.get(__name__) + + +class TransformMathNode(base.MaxwellSimNode): + r"""Applies a function to the array as a whole, with arbitrary results. + + The shape, type, and interpretation of the input/output data is dynamically shown. + + # Socket Sets + ## Interpret + Reinterprets the `InfoFlow` of an array, **without changing it**. + + Attributes: + operation: Operation to apply to the input. + """ + + node_type = ct.NodeType.TransformMath + bl_label = 'Transform Math' + + input_sockets: typ.ClassVar = { + 'Data': sockets.DataSocketDef(format='jax'), + } + input_socket_sets: typ.ClassVar = { + 'Fourier': {}, + 'Affine': {}, + 'Convolve': {}, + } + output_sockets: typ.ClassVar = { + 'Data': sockets.DataSocketDef(format='jax'), + } + + #################### + # - Properties + #################### + operation: enum.Enum = bl_cache.BLField( + prop_ui=True, enum_cb=lambda self, _: self.search_operations() + ) + + def search_operations(self) -> list[ct.BLEnumElement]: + if self.active_socket_set == 'Fourier': # noqa: SIM114 + items = [] + elif self.active_socket_set == 'Affine': # noqa: SIM114 + items = [] + elif self.active_socket_set == 'Convolve': + items = [] + else: + msg = f'Active socket set {self.active_socket_set} is unknown' + raise RuntimeError(msg) + + return [(*item, '', i) for i, item in enumerate(items)] + + def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: + layout.prop(self, self.blfields['operation'], text='') + + #################### + # - Events + #################### + @events.on_value_changed( + prop_name='active_socket_set', + ) + def on_socket_set_changed(self): + self.operation = bl_cache.Signal.ResetEnumItems + + #################### + # - Compute: LazyValueFunc / Array + #################### + @events.computes_output_socket( + 'Data', + kind=ct.FlowKind.LazyValueFunc, + props={'active_socket_set', 'operation'}, + input_sockets={'Data'}, + input_socket_kinds={ + 'Data': ct.FlowKind.LazyValueFunc, + }, + ) + def compute_data(self, props: dict, input_sockets: dict): + has_data = not ct.FlowSignal.check(input_sockets['Data']) + if not has_data or props['operation'] == 'NONE': + return ct.FlowSignal.FlowPending + + mapping_func: typ.Callable[[jax.Array], jax.Array] = { + 'Fourier': {}, + 'Affine': {}, + 'Convolve': {}, + }[props['active_socket_set']][props['operation']] + + # Compose w/Lazy Root Function Data + return input_sockets['Data'].compose_within( + mapping_func, + supports_jax=True, + ) + + @events.computes_output_socket( + 'Data', + kind=ct.FlowKind.Array, + output_sockets={'Data'}, + output_socket_kinds={ + 'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params}, + }, + ) + def compute_array(self, output_sockets: dict) -> ct.ArrayFlow: + lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc] + params = output_sockets['Data'][ct.FlowKind.Params] + + if all(not ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]): + return ct.ArrayFlow( + values=lazy_value_func.func_jax( + *params.func_args, **params.func_kwargs + ), + unit=None, + ) + + return ct.FlowSignal.FlowPending + + #################### + # - Compute Auxiliary: Info / Params + #################### + @events.computes_output_socket( + 'Data', + kind=ct.FlowKind.Info, + props={'active_socket_set', 'operation'}, + input_sockets={'Data'}, + input_socket_kinds={'Data': ct.FlowKind.Info}, + ) + def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow: + info = input_sockets['Data'] + if ct.FlowSignal.check(info): + return ct.FlowSignal.FlowPending + + return info + + @events.computes_output_socket( + 'Data', + kind=ct.FlowKind.Params, + input_sockets={'Data'}, + input_socket_kinds={'Data': ct.FlowKind.Params}, + ) + def compute_data_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal: + return input_sockets['Data'] + + +#################### +# - Blender Registration +#################### +BL_REGISTER = [ + TransformMathNode, +] +BL_NODES = {ct.NodeType.TransformMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)} diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py index fa5fe78..b7b78cd 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py @@ -61,32 +61,34 @@ class VizMode(enum.StrEnum): @staticmethod def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None: + EMPTY = () + Z = spux.MathType.Integer + R = spux.MathType.Real + VM = VizMode + valid_viz_modes = { - ((), (spux.MathType.Real,)): [VizMode.Hist1D, VizMode.BoxPlot1D], - ((spux.MathType.Integer), (spux.MathType.Real)): [ - VizMode.Hist1D, - VizMode.BoxPlot1D, + (EMPTY, (None, R)): [VM.Hist1D, VM.BoxPlot1D], + ((Z), (None, R)): [ + VM.Hist1D, + VM.BoxPlot1D, ], - ((spux.MathType.Real,), (spux.MathType.Real,)): [ - VizMode.Curve2D, - VizMode.Points2D, - VizMode.Bar, + ((R,), (None, R)): [ + VM.Curve2D, + VM.Points2D, + VM.Bar, ], - ((spux.MathType.Real, spux.MathType.Integer), (spux.MathType.Real,)): [ - VizMode.Curves2D, - VizMode.FilledCurves2D, + ((R, Z), (None, R)): [ + VM.Curves2D, + VM.FilledCurves2D, ], - ((spux.MathType.Real, spux.MathType.Real), (spux.MathType.Real,)): [ - VizMode.Heatmap2D, + ((R, R), (None, R)): [ + VM.Heatmap2D, ], - ( - (spux.MathType.Real, spux.MathType.Real, spux.MathType.Real), - (spux.MathType.Real,), - ): [VizMode.SqueezedHeatmap2D, VizMode.Heatmap3D], + ((R, R, R), (None, R)): [VM.SqueezedHeatmap2D, VM.Heatmap3D], }.get( ( tuple(info.dim_mathtypes.values()), - tuple(info.output_mathtypes.values()), + (info.output_shape, info.output_mathtype), ) ) @@ -161,10 +163,10 @@ class VizTarget(enum.StrEnum): @staticmethod def to_name(value: typ.Self) -> str: return { - VizTarget.Plot2D: 'Image (Plot)', - VizTarget.Pixels: 'Image (Pixels)', - VizTarget.PixelsPlane: 'Image (Plane)', - VizTarget.Voxels: '3D Field', + VizTarget.Plot2D: 'Plot', + VizTarget.Pixels: 'Pixels', + VizTarget.PixelsPlane: 'Image Plane', + VizTarget.Voxels: 'Voxels', }[value] @staticmethod diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py index 6542610..1e0b844 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py @@ -1036,6 +1036,12 @@ class MaxwellSimNode(bpy.types.Node): # Generate New Instance ID self.reset_instance_id() + # Generate New Instance ID for Sockets + ## Sockets can't do this themselves. + for bl_sockets in [self.inputs, self.outputs]: + for bl_socket in bl_sockets: + bl_socket.reset_instance_id() + # Generate New Sim Node Name ## Blender will automatically add .001 so that `self.name` is unique. self.sim_node_name = self.name diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py index 2b0a5d9..b348d8c 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py @@ -14,7 +14,7 @@ class ScientificConstantNode(base.MaxwellSimNode): bl_label = 'Scientific Constant' output_sockets: typ.ClassVar = { - 'Value': sockets.AnySocketDef(), + 'Value': sockets.ExprSocketDef(), } #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py index e7df2b5..d77fb34 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py @@ -913,10 +913,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket): col = layout.column() row = col.row() row.alignment = 'RIGHT' - if self.is_linked: - self.draw_output_label_row(row, text) - else: - row.label(text=text) + self.draw_output_label_row(row, text) # Draw FlowKind.Info related Information if self.use_info_draw: diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/data.py index a8f5c03..c882fdf 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/data.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/data.py @@ -12,6 +12,10 @@ from .. import base log = logger.get(__name__) +def unicode_superscript(n): + return ''.join(['⁰¹²³⁴⁵⁶⁷⁸⁹'[ord(c) - ord('0')] for c in str(n)]) + + class DataInfoColumn(enum.StrEnum): Length = enum.auto() MathType = enum.auto() @@ -49,11 +53,11 @@ class DataBLSocket(base.MaxwellSimSocket): ## TODO: typ.Literal['xarray', 'jax'] show_info_columns: bool = bl_cache.BLField( - False, + True, prop_ui=True, ) info_columns: DataInfoColumn = bl_cache.BLField( - {DataInfoColumn.MathType, DataInfoColumn.Length}, prop_ui=True, enum_many=True + {DataInfoColumn.MathType, DataInfoColumn.Unit}, prop_ui=True, enum_many=True ) #################### @@ -71,8 +75,10 @@ class DataBLSocket(base.MaxwellSimSocket): # - UI #################### def draw_input_label_row(self, row: bpy.types.UILayout, text) -> None: - if self.format == 'jax': - row.label(text=text) + row.label(text=text) + + info = self.compute_data(kind=ct.FlowKind.Info) + if not ct.FlowSignal.check(info) and self.format == 'jax' and info.dim_names: row.prop(self, self.blfields['info_columns']) row.prop( self, @@ -83,7 +89,8 @@ class DataBLSocket(base.MaxwellSimSocket): ) def draw_output_label_row(self, row: bpy.types.UILayout, text) -> None: - if self.format == 'jax': + info = self.compute_data(kind=ct.FlowKind.Info) + if not ct.FlowSignal.check(info) and self.format == 'jax' and info.dim_names: row.prop( self, self.blfields['show_info_columns'], @@ -92,7 +99,8 @@ class DataBLSocket(base.MaxwellSimSocket): icon=ct.Icon.ToggleSocketInfo, ) row.prop(self, self.blfields['info_columns']) - row.label(text=text) + + row.label(text=text) def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None: if self.format == 'jax' and info.dim_names and self.show_info_columns: @@ -118,16 +126,27 @@ class DataBLSocket(base.MaxwellSimSocket): grid.label(text=spux.sp_to_str(dim_idx.unit)) # Outputs - for output_name in info.output_names: - grid.label(text=output_name) - if DataInfoColumn.Length in self.info_columns: - grid.label(text='', icon=ct.Icon.DataSocketOutput) - if DataInfoColumn.MathType in self.info_columns: - grid.label( - text=spux.MathType.to_str(info.output_mathtypes[output_name]) + grid.label(text=info.output_name) + if DataInfoColumn.Length in self.info_columns: + grid.label(text='', icon=ct.Icon.DataSocketOutput) + if DataInfoColumn.MathType in self.info_columns: + grid.label( + text=( + spux.MathType.to_str(info.output_mathtype) + + ( + 'ˣ'.join( + [ + unicode_superscript(out_axis) + for out_axis in info.output_shape + ] + ) + if info.output_shape + else '' + ) ) - if DataInfoColumn.Unit in self.info_columns: - grid.label(text=spux.sp_to_str(info.output_units[output_name])) + ) + if DataInfoColumn.Unit in self.info_columns: + grid.label(text=f'{spux.sp_to_str(info.output_unit)}') #################### @@ -137,9 +156,11 @@ class DataSocketDef(base.SocketDef): socket_type: ct.SocketType = ct.SocketType.Data format: typ.Literal['xarray', 'jax', 'monitor_data'] + default_show_info_columns: bool = True def init(self, bl_socket: DataBLSocket) -> None: bl_socket.format = self.format + bl_socket.default_show_info_columns = self.default_show_info_columns #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/expr.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/expr.py index 94ea61d..5f3cd8c 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/expr.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/expr.py @@ -86,9 +86,7 @@ class ExprBLSocket(base.MaxwellSimSocket): def lazy_value_func(self) -> ct.LazyValueFuncFlow: return ct.LazyValueFuncFlow( func=sp.lambdify(self.symbols, self.value, 'jax'), - func_args=[ - (sym.name, spux.sympy_to_python_type(sym)) for sym in self.symbols - ], + func_args=[spux.sympy_to_python_type(sym) for sym in self.symbols], supports_jax=True, ) diff --git a/src/blender_maxwell/utils/extra_sympy_units.py b/src/blender_maxwell/utils/extra_sympy_units.py index a3de641..78d57e9 100644 --- a/src/blender_maxwell/utils/extra_sympy_units.py +++ b/src/blender_maxwell/utils/extra_sympy_units.py @@ -31,6 +31,18 @@ class MathType(enum.StrEnum): Real = enum.auto() Complex = enum.auto() + def combine(*mathtypes: list[typ.Self]) -> typ.Self: + if MathType.Complex in mathtypes: + return MathType.Complex + elif MathType.Real in mathtypes: + return MathType.Real + elif MathType.Rational in mathtypes: + return MathType.Rational + elif MathType.Integer in mathtypes: + return MathType.Integer + elif MathType.Bool in mathtypes: + return MathType.Bool + @staticmethod def from_expr(sp_obj: SympyType) -> type: if isinstance(sp_obj, sp.logic.boolalg.Boolean): @@ -52,13 +64,13 @@ class MathType(enum.StrEnum): int: MathType.Integer, float: MathType.Real, complex: MathType.Complex, - #jnp.int32: MathType.Integer, - #jnp.int64: MathType.Integer, - #jnp.float32: MathType.Real, - #jnp.float64: MathType.Real, - #jnp.complex64: MathType.Complex, - #jnp.complex128: MathType.Complex, - #jnp.bool_: MathType.Bool, + # jnp.int32: MathType.Integer, + # jnp.int64: MathType.Integer, + # jnp.float32: MathType.Real, + # jnp.float64: MathType.Real, + # jnp.complex64: MathType.Complex, + # jnp.complex128: MathType.Complex, + # jnp.bool_: MathType.Bool, }[dtype] @staticmethod @@ -595,6 +607,21 @@ ComplexNumber: typ.TypeAlias = ConstrSympyExpr( ) Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber +# Number +PhysicalRealNumber: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=True, + allowed_sets={'integer', 'rational', 'real'}, + allowed_structures={'scalar'}, +) +PhysicalComplexNumber: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=True, + allowed_sets={'integer', 'rational', 'real', 'complex'}, + allowed_structures={'scalar'}, +) +PhysicalNumber: typ.TypeAlias = PhysicalRealNumber | PhysicalComplexNumber + # Vector Real3DVector: typ.TypeAlias = ConstrSympyExpr( allow_variables=False, diff --git a/src/blender_maxwell/utils/image_ops.py b/src/blender_maxwell/utils/image_ops.py index 6f64ada..1fdf328 100644 --- a/src/blender_maxwell/utils/image_ops.py +++ b/src/blender_maxwell/utils/image_ops.py @@ -117,8 +117,8 @@ def rgba_image_from_2d_map( def plot_hist_1d( data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis ) -> None: - y_name = info.output_names[0] - y_unit = info.output_units[y_name] + y_name = info.output_name + y_unit = info.output_unit ax.hist(data, bins=30, alpha=0.75) ax.set_title('Histogram') @@ -130,8 +130,8 @@ def plot_box_plot_1d( data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis ) -> None: x_name = info.dim_names[0] - y_name = info.output_names[0] - y_unit = info.output_units[y_name] + y_name = info.output_name + y_unit = info.output_unit ax.boxplot(data) ax.set_title('Box Plot') @@ -147,8 +147,8 @@ def plot_curve_2d( x_name = info.dim_names[0] x_unit = info.dim_units[x_name] - y_name = info.output_names[0] - y_unit = info.output_units[y_name] + y_name = info.output_name + y_unit = info.output_unit times.append(time.perf_counter() - times[0]) ax.plot(info.dim_idx_arrays[0], data) @@ -167,8 +167,8 @@ def plot_points_2d( ) -> None: x_name = info.dim_names[0] x_unit = info.dim_units[x_name] - y_name = info.output_names[0] - y_unit = info.output_units[y_name] + y_name = info.output_name + y_unit = info.output_unit ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6) ax.set_title('2D Points') @@ -179,8 +179,8 @@ def plot_points_2d( def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None: x_name = info.dim_names[0] x_unit = info.dim_units[x_name] - y_name = info.output_names[0] - y_unit = info.output_units[y_name] + y_name = info.output_name + y_unit = info.output_unit ax.bar(info.dim_idx_arrays[0], data, alpha=0.7) ax.set_title('2D Bar') @@ -194,8 +194,8 @@ def plot_curves_2d( ) -> None: x_name = info.dim_names[0] x_unit = info.dim_units[x_name] - y_name = info.output_names[0] - y_unit = info.output_units[y_name] + y_name = info.output_name + y_unit = info.output_unit for category in range(data.shape[1]): ax.plot(data[:, 0], data[:, 1]) @@ -211,8 +211,8 @@ def plot_filled_curves_2d( ) -> None: x_name = info.dim_names[0] x_unit = info.dim_units[x_name] - y_name = info.output_names[0] - y_unit = info.output_units[y_name] + y_name = info.output_name + y_unit = info.output_unit ax.fill_between(info.dim_arrays[0], data[:, 0], info.dim_arrays[0], data[:, 1]) ax.set_title('2D Curves')