diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py index 945516e..4bd1c5d 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py @@ -126,6 +126,10 @@ class ArrayFlow: def __len__(self) -> int: return len(self.values) + @functools.cached_property + def mathtype(self) -> spux.MathType: + return spux.MathType.from_pytype(type(self.values.item(0))) + def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int: """Find the index of the value that is closest to the given value. @@ -437,6 +441,27 @@ class LazyArrayRangeFlow: key=lambda sym: sym.name, ) + @functools.cached_property + def mathtype(self) -> spux.MathType: + # Get Start Mathtype + if isinstance(self.start, spux.SympyType): + start_mathtype = spux.MathType.from_expr(self.start) + else: + start_mathtype = spux.MathType.from_pytype(self.start) + + # Get Stop Mathtype + if isinstance(self.stop, spux.SympyType): + stop_mathtype = spux.MathType.from_expr(type(self.stop)) + else: + stop_mathtype = spux.MathType.from_pytype(type(self.stop)) + + # Check Equal + if start_mathtype != stop_mathtype: + msg = "Mathtypes of start and stop don't agree. Please fix!" + raise ValueError(msg) + + return start_mathtype + def __len__(self): return self.steps @@ -688,4 +713,31 @@ class InfoFlow: default_factory=dict ) ## TODO: Rename to dim_idxs - ## TODO: Validation, esp. length of dims. Pydantic? + @functools.cached_property + def dim_lens(self) -> dict[str, int]: + return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()} + + @functools.cached_property + def dim_mathtypes(self) -> dict[str, int]: + return { + dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items() + } + + @functools.cached_property + def dim_units(self) -> dict[str, int]: + return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()} + + @functools.cached_property + def dim_idx_arrays(self) -> list[ArrayFlow]: + return [ + dim_idx.realize().values + if isinstance(dim_idx, LazyArrayRangeFlow) + else dim_idx.values + for dim_idx in self.dim_idx.values() + ] + return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()} + + # Output Information + output_names: list[str] = dataclasses.field(default_factory=list) + output_mathtypes: dict[str, spux.MathType] = dataclasses.field(default_factory=dict) + output_units: dict[str, spux.Unit | None] = dataclasses.field(default_factory=dict) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py index b90fd69..c686d65 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py @@ -1,4 +1,5 @@ import io +import time import typing as typ import bpy @@ -205,61 +206,57 @@ class ManagedBLImage(base.ManagedObj): dpi: int | None = None, bl_select: bool = False, ): - # time_start = time.perf_counter() + times = [time.perf_counter()] import matplotlib.pyplot as plt - # log.debug('Imported PyPlot (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) # Compute Plot Dimensions aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = ( self.gen_image_geometry(width_inches, height_inches, dpi) ) - # log.debug('Computed MPL Geometry (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) - # log.debug( - # 'Creating MPL Axes (aspect=%f, width=%f, height=%f)', - # aspect_ratio, - # _width_inches, - # _height_inches, - # ) # Create MPL Figure, Axes, and Compute Figure Geometry fig, ax = plt.subplots( figsize=[_width_inches, _height_inches], dpi=_dpi, ) - # log.debug('Created MPL Axes (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) ax.set_aspect(aspect_ratio) + times.append(time.perf_counter() - times[0]) cmp_width_px, cmp_height_px = fig.canvas.get_width_height() - ## Use computed pixel w/h to preempt off-by-one size errors. + times.append(time.perf_counter() - times[0]) ax.set_aspect('auto') ## Workaround aspect-ratio bugs - # log.debug('Set MPL Aspect (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) # Plot w/User Parameter func_plotter(ax) - # log.debug('User Plot Function (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) # Save Figure to BytesIO with io.BytesIO() as buff: - # log.debug('Made BytesIO (%f)', time.perf_counter() - time_start) fig.savefig(buff, format='raw', dpi=dpi) - # log.debug('Saved Figure to BytesIO (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) buff.seek(0) image_data = np.frombuffer( buff.getvalue(), dtype=np.uint8, ).reshape([cmp_height_px, cmp_width_px, -1]) - # log.debug('Set Image Data (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) image_data = np.flipud(image_data).astype(np.float32) / 255 - # log.debug('Flipped Image Data (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) plt.close(fig) # Optimized Write to Blender Image bl_image = self.bl_image(cmp_width_px, cmp_height_px, 'RGBA', 'uint8') - # log.debug('Made BL Image (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) bl_image.pixels.foreach_set(image_data.ravel()) - # log.debug('Set BL Image Pixels (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) bl_image.update() - # log.debug('Updated BL Image (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) if bl_select: self.bl_select() + times.append(time.perf_counter() - times[0]) + #log.critical('Timing of MPL Plot: %s', str(times)) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py index eea08fa..bbc428a 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py @@ -1,12 +1,13 @@ +import enum import typing as typ -import enum import bpy import jax import jax.numpy as jnp import sympy.physics.units as spu from blender_maxwell.utils import bl_cache, logger +from blender_maxwell.utils import extra_sympy_units as spux from ... import contracts as ct from ... import sockets @@ -45,11 +46,13 @@ class ExtractDataNode(base.MaxwellSimNode): ) # Sim Data - sim_data_monitor_nametype: dict[str, str] = bl_cache.BLField({}) + sim_data_monitor_nametype: dict[str, str] = bl_cache.BLField( + {}, use_prop_update=False + ) # Monitor Data - monitor_data_type: str = bl_cache.BLField('') - monitor_data_components: list[str] = bl_cache.BLField([]) + monitor_data_type: str = bl_cache.BLField('', use_prop_update=False) + monitor_data_components: list[str] = bl_cache.BLField([], use_prop_update=False) #################### # - Computed Properties @@ -177,7 +180,7 @@ class ExtractDataNode(base.MaxwellSimNode): self.extract_filter = bl_cache.Signal.ResetEnumItems #################### - # - Output: Value + # - Output: Sim Data -> Monitor Data #################### @events.computes_output_socket( 'Monitor Data', @@ -186,34 +189,52 @@ class ExtractDataNode(base.MaxwellSimNode): input_sockets={'Sim Data'}, ) def compute_monitor_data(self, props: dict, input_sockets: dict): - return input_sockets['Sim Data'].monitor_data[props['extract_filter']] + if input_sockets['Sim Data'] is not None and props['extract_filter'] != 'NONE': + return input_sockets['Sim Data'].monitor_data[props['extract_filter']] + return None + + #################### + # - Output: Monitor Data -> Data + #################### @events.computes_output_socket( 'Data', kind=ct.FlowKind.Array, props={'extract_filter'}, input_sockets={'Monitor Data'}, + input_socket_kinds={'Monitor Data': ct.FlowKind.Value}, ) - def compute_data(self, props: dict, input_sockets: dict) -> jax.Array: - xarray_data = getattr(input_sockets['Monitor Data'], props['extract_filter']) - return jnp.array(xarray_data.data) ## TODO: Can it be done without a copy? + def compute_data(self, props: dict, input_sockets: dict) -> jax.Array | None: + if ( + input_sockets['Monitor Data'] is not None + and props['extract_filter'] != 'NONE' + ): + xarray_data = getattr( + input_sockets['Monitor Data'], props['extract_filter'] + ) + return jnp.array(xarray_data.data) + ## TODO: Let the array itself have its output unit too! + + return None - #################### - # - Output: LazyValueFunc - #################### @events.computes_output_socket( 'Data', kind=ct.FlowKind.LazyValueFunc, output_sockets={'Data'}, output_socket_kinds={'Data': ct.FlowKind.Array}, ) - def compute_extracted_data_lazy(self, output_sockets: dict) -> ct.LazyValueFuncFlow: - return ct.LazyValueFuncFlow( - func=lambda: output_sockets['Data'], supports_jax=True - ) + def compute_extracted_data_lazy( + self, output_sockets: dict + ) -> ct.LazyValueFuncFlow | None: + if output_sockets['Data'] is not None: + return ct.LazyValueFuncFlow( + func=lambda: output_sockets['Data'], supports_jax=True + ) + + return None #################### - # - Output: Info + # - Auxiliary: Monitor Data -> Data #################### @events.computes_output_socket( 'Data', @@ -221,6 +242,7 @@ class ExtractDataNode(base.MaxwellSimNode): props={'monitor_data_type', 'extract_filter'}, input_sockets={'Monitor Data'}, input_socket_kinds={'Monitor Data': ct.FlowKind.Value}, + input_sockets_optional={'Monitor Data': True}, ) def compute_extracted_data_info( self, props: dict, input_sockets: dict @@ -234,12 +256,16 @@ class ExtractDataNode(base.MaxwellSimNode): else: return ct.InfoFlow() + info_output_names = { + 'output_names': [props['extract_filter']], + } + # Compute InfoFlow from XArray ## XYZF: Field / Permittivity / FieldProjectionCartesian if props['monitor_data_type'] in { 'Field', 'Permittivity', - 'FieldProjectionCartesian', + #'FieldProjectionCartesian', }: return ct.InfoFlow( dim_names=['x', 'y', 'z', 'f'], @@ -256,6 +282,13 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, + **info_output_names, + output_mathtypes={props['extract_filter']: spux.MathType.Complex}, + output_units={ + props['extract_filter']: spu.volt / spu.micrometer + if props['monitor_data_type'] == 'Field' + else None + }, ) ## XYZT: FieldTime @@ -275,6 +308,17 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, + **info_output_names, + output_mathtypes={props['extract_filter']: spux.MathType.Complex}, + output_units={ + props['extract_filter']: ( + spu.volt / spu.micrometer + if props['extract_filter'].startswith('E') + else spu.ampere / spu.micrometer + ) + if props['monitor_data_type'] == 'Field' + else None + }, ) ## F: Flux @@ -288,6 +332,9 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, + **info_output_names, + output_mathtypes={props['extract_filter']: spux.MathType.Real}, + output_units={props['extract_filter']: spu.watt}, ) ## T: FluxTime @@ -301,6 +348,9 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, + **info_output_names, + output_mathtypes={props['extract_filter']: spux.MathType.Real}, + output_units={props['extract_filter']: spu.watt}, ) ## RThetaPhiF: FieldProjectionAngle @@ -327,6 +377,15 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, + **info_output_names, + output_mathtypes={props['extract_filter']: spux.MathType.Real}, + output_units={ + props['extract_filter']: ( + spu.volt / spu.micrometer + if props['extract_filter'].startswith('E') + else spu.ampere / spu.micrometer + ) + }, ) ## UxUyRF: FieldProjectionKSpace @@ -351,6 +410,15 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, + **info_output_names, + output_mathtypes={props['extract_filter']: spux.MathType.Real}, + output_units={ + props['extract_filter']: ( + spu.volt / spu.micrometer + if props['extract_filter'].startswith('E') + else spu.ampere / spu.micrometer + ) + }, ) ## OrderxOrderyF: Diffraction @@ -372,6 +440,15 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, + **info_output_names, + output_mathtypes={props['extract_filter']: spux.MathType.Real}, + output_units={ + props['extract_filter']: ( + spu.volt / spu.micrometer + if props['extract_filter'].startswith('E') + else spu.ampere / spu.micrometer + ) + }, ) msg = f'Unsupported Monitor Data Type {props["monitor_data_type"]} in "FlowKind.Info" of "{self.bl_label}"' diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py index 6c5fac0..18e6a98 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py @@ -47,8 +47,9 @@ class FilterMathNode(base.MaxwellSimNode): None, prop_ui=True, enum_cb=lambda self, _: self.search_dims() ) - dim_names: list[str] = bl_cache.BLField([]) - dim_lens: dict[str, int] = bl_cache.BLField({}) + @property + def _info(self) -> ct.InfoFlow: + return self._compute_input('Data', kind=ct.FlowKind.Info) #################### # - Operation Search @@ -70,16 +71,16 @@ class FilterMathNode(base.MaxwellSimNode): # - Dim Search #################### def search_dims(self) -> list[ct.BLEnumElement]: - if self.dim_names: + if (info := self._info).dim_names: dims = [ (dim_name, dim_name, dim_name, '', i) - for i, dim_name in enumerate(self.dim_names) + for i, dim_name in enumerate(info.dim_names) ] # Squeeze: Dimension Must Have Length=1 ## We must also correct the "NUMBER" of the enum. if self.operation == 'SQUEEZE': - filtered_dims = [dim for dim in dims if self.dim_lens[dim[0]] == 1] + filtered_dims = [dim for dim in dims if info.dim_lens[dim[0]] == 1] return [(*dim[:-1], i) for i, dim in enumerate(filtered_dims)] return dims @@ -90,7 +91,7 @@ class FilterMathNode(base.MaxwellSimNode): #################### def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: layout.prop(self, self.blfields['operation'], text='') - if self.dim_names: + if self._info.dim_names: layout.prop(self, self.blfields['dim'], text='') #################### @@ -98,52 +99,49 @@ class FilterMathNode(base.MaxwellSimNode): #################### @events.on_value_changed( prop_name='active_socket_set', + run_on_init=True, ) def on_socket_set_changed(self): self.operation = bl_cache.Signal.ResetEnumItems @events.on_value_changed( - socket_name={'Data'}, - prop_name={'active_socket_set'}, + socket_name='Data', + prop_name='active_socket_set', props={'active_socket_set'}, input_sockets={'Data'}, input_socket_kinds={'Data': ct.FlowKind.Info}, - input_sockets_optional={'Data': True}, - run_on_init=True, + # run_on_init=True, ) def on_any_change(self, props: dict, input_sockets: dict): - # Set Dimension Names from InfoFlow - if input_sockets['Data'].dim_names: - self.dim_names = input_sockets['Data'].dim_names - self.dim_lens = { - dim_name: len(dim_idx) - for dim_name, dim_idx in input_sockets['Data'].dim_idx.items() - } - else: - self.dim_names = [] - self.dim_lens = {} - - # Reset Enum self.dim = bl_cache.Signal.ResetEnumItems @events.on_value_changed( + socket_name='Data', prop_name='dim', props={'active_socket_set', 'dim'}, input_sockets={'Data'}, input_socket_kinds={'Data': ct.FlowKind.Info}, - input_sockets_optional={'Data': True}, + # run_on_init=True, ) def on_dim_change(self, props: dict, input_sockets: dict): # Add/Remove Input Socket "Value" - if props['active_socket_set'] == 'By Dim Value' and props['dim'] != 'NONE': + if ( + input_sockets['Data'] != ct.InfoFlow() + and props['active_socket_set'] == 'By Dim Value' + and props['dim'] != 'NONE' + ): # Get Current and Wanted Socket Defs - current_socket_def = self.loose_input_sockets.get('Value') + current_bl_socket = self.loose_input_sockets.get('Value') wanted_socket_def = sockets.SOCKET_DEFS[ ct.unit_to_socket_type(input_sockets['Data'].dim_idx[props['dim']].unit) ] # Determine Whether to Declare New Loose Input SOcket - if current_socket_def is None or current_socket_def != wanted_socket_def: + if ( + current_bl_socket is None + or sockets.SOCKET_DEFS[current_bl_socket.socket_type] + != wanted_socket_def + ): self.loose_input_sockets = { 'Value': wanted_socket_def(), } @@ -225,7 +223,7 @@ class FilterMathNode(base.MaxwellSimNode): # Compute Bound/Free Parameters ## Empty Dimension -> Empty InfoFlow - if props['dim'] != 'NONE': + if input_sockets['Data'] != ct.InfoFlow() and props['dim'] != 'NONE': axis = info.dim_names.index(props['dim']) else: return ct.InfoFlow() @@ -243,6 +241,9 @@ class FilterMathNode(base.MaxwellSimNode): for dim_name, dim_idx in info.dim_idx.items() if dim_name != props['dim'] }, + output_names=info.output_names, + output_mathtypes=info.output_mathtypes, + output_units=info.output_units, ) # Fallback to Empty InfoFlow diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py index 731805b..57b250b 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py @@ -6,7 +6,8 @@ import jax import jax.numpy as jnp import sympy as sp -from blender_maxwell.utils import logger, bl_cache +from blender_maxwell.utils import bl_cache, logger +from blender_maxwell.utils import extra_sympy_units as spux from .... import contracts as ct from .... import sockets @@ -143,7 +144,7 @@ class MapMathNode(base.MaxwellSimNode): 'SINC': lambda data: jnp.sinc(data), }, 'By Vector': { - 'NORM_2': lambda data: jnp.norm(data, ord=2, axis=-1), + 'NORM_2': lambda data: jnp.linalg.norm(data, ord=2, axis=-1), }, 'By Matrix': { # Matrix -> Number @@ -196,11 +197,34 @@ class MapMathNode(base.MaxwellSimNode): @events.computes_output_socket( 'Data', kind=ct.FlowKind.Info, + props={'active_socket_set', 'operation'}, input_sockets={'Data'}, input_socket_kinds={'Data': ct.FlowKind.Info}, ) - def compute_data_info(self, input_sockets: dict) -> ct.InfoFlow: - return input_sockets['Data'] + def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow: + info = input_sockets['Data'] + + # Complex -> Real + if props['active_socket_set'] == 'By Element' and props['operation'] in [ + 'REAL', + 'IMAG', + 'ABS', + ]: + return ct.InfoFlow( + dim_names=info.dim_names, + dim_idx=info.dim_idx, + output_names=info.output_names, + output_mathtypes={ + output_name: ( + spux.MathType.Real + if output_mathtype == spux.MathType.Complex + else output_mathtype + ) + for output_name, output_mathtype in info.output_mathtypes.items() + }, + output_units=info.output_units, + ) + return info @events.computes_output_socket( 'Data', diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py index 21d65f6..469ae56 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py @@ -2,8 +2,11 @@ import enum import typing as typ import bpy +import jaxtyping as jtyp +import matplotlib.axis as mpl_ax from blender_maxwell.utils import bl_cache, image_ops, logger +from blender_maxwell.utils import extra_sympy_units as spux from ... import contracts as ct from ... import managed_objs, sockets @@ -12,9 +15,168 @@ from .. import base, events log = logger.get(__name__) +class VizMode(enum.StrEnum): + """Available visualization modes. + + **NOTE**: >1D output dimensions currently have no viz. + + Plots for `() -> ℝ`: + - Hist1D: Bin-summed distribution. + - BoxPlot1D: Box-plot describing the distribution. + + Plots for `(ℤ) -> ℝ`: + - BoxPlots1D: Side-by-side boxplots w/equal y axis. + + Plots for `(ℝ) -> ℝ`: + - Curve2D: Standard line-curve w/smooth interpolation + - Points2D: Scatterplot of individual points. + - Bar: Value to height of a barplot. + + Plots for `(ℝ, ℤ) -> ℝ`: + - Curves2D: Layered Curve2Ds with unique colors. + - FilledCurves2D: Layered Curve2Ds with filled space between. + + Plots for `(ℝ, ℝ) -> ℝ`: + - Heatmap2D: Colormapped image with value at each pixel. + + Plots for `(ℝ, ℝ, ℝ) -> ℝ`: + - SqueezedHeatmap2D: 3D-embeddable heatmap for when one of the axes is 1. + - Heatmap3D: Colormapped field with value at each voxel. + """ + + Hist1D = enum.auto() + BoxPlot1D = enum.auto() + + Curve2D = enum.auto() + Points2D = enum.auto() + Bar = enum.auto() + + Curves2D = enum.auto() + FilledCurves2D = enum.auto() + + Heatmap2D = enum.auto() + + SqueezedHeatmap2D = enum.auto() + Heatmap3D = enum.auto() + + @staticmethod + def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None: + valid_viz_modes = { + ((), (spux.MathType.Real,)): [VizMode.Hist1D, VizMode.BoxPlot1D], + ((spux.MathType.Integer), (spux.MathType.Real)): [ + VizMode.Hist1D, + VizMode.BoxPlot1D, + ], + ((spux.MathType.Real,), (spux.MathType.Real,)): [ + VizMode.Curve2D, + VizMode.Points2D, + VizMode.Bar, + ], + ((spux.MathType.Real, spux.MathType.Integer), (spux.MathType.Real,)): [ + VizMode.Curves2D, + VizMode.FilledCurves2D, + ], + ((spux.MathType.Real, spux.MathType.Real), (spux.MathType.Real,)): [ + VizMode.Heatmap2D, + ], + ( + (spux.MathType.Real, spux.MathType.Real, spux.MathType.Real), + (spux.MathType.Real,), + ): [VizMode.SqueezedHeatmap2D, VizMode.Heatmap3D], + }.get( + ( + tuple(info.dim_mathtypes.values()), + tuple(info.output_mathtypes.values()), + ) + ) + + if valid_viz_modes is None: + return [] + + return valid_viz_modes + + @staticmethod + def to_plotter( + value: typ.Self, + ) -> typ.Callable[ + [jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None + ]: + return { + VizMode.Hist1D: image_ops.plot_hist_1d, + VizMode.BoxPlot1D: image_ops.plot_box_plot_1d, + VizMode.Curve2D: image_ops.plot_curve_2d, + VizMode.Points2D: image_ops.plot_points_2d, + VizMode.Bar: image_ops.plot_bar, + VizMode.Curves2D: image_ops.plot_curves_2d, + VizMode.FilledCurves2D: image_ops.plot_filled_curves_2d, + VizMode.Heatmap2D: image_ops.plot_heatmap_2d, + # NO PLOTTER: VizMode.SqueezedHeatmap2D + # NO PLOTTER: VizMode.Heatmap3D + }[value] + + @staticmethod + def to_name(value: typ.Self) -> str: + return { + VizMode.Hist1D: 'Histogram', + VizMode.BoxPlot1D: 'Box Plot', + VizMode.Curve2D: 'Curve', + VizMode.Points2D: 'Points', + VizMode.Bar: 'Bar', + VizMode.Curves2D: 'Curves', + VizMode.FilledCurves2D: 'Filled Curves', + VizMode.Heatmap2D: 'Heatmap', + VizMode.SqueezedHeatmap2D: 'Heatmap (Squeezed)', + VizMode.Heatmap3D: 'Heatmap (3D)', + }[value] + + @staticmethod + def to_icon(value: typ.Self) -> ct.BLIcon: + return '' + + +class VizTarget(enum.StrEnum): + """Available visualization targets.""" + + Plot2D = enum.auto() + Pixels = enum.auto() + PixelsPlane = enum.auto() + Voxels = enum.auto() + + @staticmethod + def valid_targets_for(viz_mode: VizMode) -> list[typ.Self] | None: + return { + 'NONE': [], + VizMode.Hist1D: [VizTarget.Plot2D], + VizMode.BoxPlot1D: [VizTarget.Plot2D], + VizMode.Curve2D: [VizTarget.Plot2D], + VizMode.Points2D: [VizTarget.Plot2D], + VizMode.Bar: [VizTarget.Plot2D], + VizMode.Curves2D: [VizTarget.Plot2D], + VizMode.FilledCurves2D: [VizTarget.Plot2D], + VizMode.Heatmap2D: [VizTarget.Plot2D, VizTarget.Pixels], + VizMode.SqueezedHeatmap2D: [VizTarget.Pixels, VizTarget.PixelsPlane], + VizMode.Heatmap3D: [VizTarget.Voxels], + }[viz_mode] + + @staticmethod + def to_name(value: typ.Self) -> str: + return { + VizTarget.Plot2D: 'Image (Plot)', + VizTarget.Pixels: 'Image (Pixels)', + VizTarget.PixelsPlane: 'Image (Plane)', + VizTarget.Voxels: '3D Field', + }[value] + + @staticmethod + def to_icon(value: typ.Self) -> ct.BLIcon: + return '' + + class VizNode(base.MaxwellSimNode): """Node for visualizing simulation data, by querying its monitors. + Auto-detects the correct plot type based on the input data: + Attributes: colormap: Colormap to apply to 0..1 output. @@ -40,24 +202,91 @@ class VizNode(base.MaxwellSimNode): ##################### ## - Properties ##################### + viz_mode: enum.Enum = bl_cache.BLField( + prop_ui=True, enum_cb=lambda self, _: self.search_modes() + ) + viz_target: enum.Enum = bl_cache.BLField( + prop_ui=True, enum_cb=lambda self, _: self.search_targets() + ) + + # Mode-Dependent Properties colormap: image_ops.Colormap = bl_cache.BLField( image_ops.Colormap.Viridis, prop_ui=True ) + ##################### + ## - Mode Searcher + ##################### + @property + def _info(self) -> ct.InfoFlow: + return self._compute_input('Data', kind=ct.FlowKind.Info) + + def search_modes(self) -> list[ct.BLEnumElement]: + info = self._info + return [ + ( + viz_mode, + VizMode.to_name(viz_mode), + VizMode.to_name(viz_mode), + VizMode.to_icon(viz_mode), + i, + ) + for i, viz_mode in enumerate(VizMode.valid_modes_for(info)) + ] + + ##################### + ## - Target Searcher + ##################### + def search_targets(self) -> list[ct.BLEnumElement]: + return [ + ( + viz_target, + VizTarget.to_name(viz_target), + VizTarget.to_name(viz_target), + VizTarget.to_icon(viz_target), + i, + ) + for i, viz_target in enumerate(VizTarget.valid_targets_for(self.viz_mode)) + ] + ##################### ## - UI ##################### def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout): - col.prop(self, self.blfields['colormap'], text='') + col.prop(self, self.blfields['viz_mode'], text='') + col.prop(self, self.blfields['viz_target'], text='') + if self.viz_target in [VizTarget.Pixels, VizTarget.PixelsPlane]: + col.prop(self, self.blfields['colormap'], text='') + + #################### + # - Events + #################### + @events.on_value_changed( + socket_name='Data', + input_sockets={'Data'}, + input_socket_kinds={'Data': ct.FlowKind.Info}, + input_sockets_optional={'Data': True}, + run_on_init=True, + ) + def on_socket_set_changed(self, input_sockets: dict): + self.viz_mode = bl_cache.Signal.ResetEnumItems + self.viz_target = bl_cache.Signal.ResetEnumItems + + @events.on_value_changed( + prop_name='viz_mode', + # run_on_init=True, + ) + def on_viz_mode_changed(self): + self.viz_target = bl_cache.Signal.ResetEnumItems ##################### ## - Plotting ##################### @events.on_show_plot( managed_objs={'plot'}, - props={'colormap'}, + props={'viz_mode', 'viz_target', 'colormap'}, input_sockets={'Data'}, - input_socket_kinds={'Data': ct.FlowKind.Array}, + input_socket_kinds={'Data': {ct.FlowKind.Array, ct.FlowKind.Info}}, input_sockets_optional={'Data': True}, stop_propagation=True, ) @@ -67,13 +296,32 @@ class VizNode(base.MaxwellSimNode): input_sockets: dict, props: dict, ): - if input_sockets['Data'] is not None: + array_flow = input_sockets['Data'][ct.FlowKind.Array] + info = input_sockets['Data'][ct.FlowKind.Info] + + if input_sockets['Data'] is None: + return + + if props['viz_target'] == VizTarget.Plot2D: + managed_objs['plot'].mpl_plot_to_image( + lambda ax: VizMode.to_plotter(props['viz_mode'])( + array_flow.values, info, ax + ), + bl_select=True, + ) + if props['viz_target'] == VizTarget.Pixels: managed_objs['plot'].map_2d_to_image( - input_sockets['Data'].values, + array_flow.values, colormap=props['colormap'], bl_select=True, ) + if props['viz_target'] == VizTarget.PixelsPlane: + raise NotImplementedError + + if props['viz_target'] == VizTarget.Voxels: + raise NotImplementedError + #################### # - Blender Registration diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py index 85f0afc..3029b1a 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py @@ -843,10 +843,12 @@ class MaxwellSimNode(bpy.types.Node): # Propagate Event to All Sockets in "Trigger Direction" ## The trigger chain goes node/socket/node/socket/... if not stop_propagation: - triggered_sockets = self._bl_sockets( - direc=ct.FlowEvent.flow_direction[event] - ) + direc = ct.FlowEvent.flow_direction[event] + triggered_sockets = self._bl_sockets(direc=direc) for bl_socket in triggered_sockets: + if direc == 'output' and not bl_socket.is_linked: + continue + # log.critical( # '![%s] Propagating: (%s, %s)', # self.sim_node_name, diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py index b58e66b..73785db 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py @@ -505,10 +505,10 @@ class MaxwellSimSocket(bpy.types.NodeSocket): socket_name=self.name, socket_kinds=socket_kinds, ) - - self.node.trigger_event( - event, socket_name=self.name, socket_kinds=socket_kinds - ) + else: + self.node.trigger_event( + event, socket_name=self.name, socket_kinds=socket_kinds + ) # Output Socket | Input Flow if self.is_output and flow_direction == 'input': diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/data.py index 1ed0d08..fcf55a8 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/data.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/data.py @@ -49,14 +49,14 @@ class DataBLSocket(base.MaxwellSimSocket): columns=3, row_major=True, even_columns=True, - #even_rows=True, + # even_rows=True, align=True, ) # Grid Header - #grid.label(text='Dim') - #grid.label(text='Len') - #grid.label(text='Unit') + # grid.label(text='Dim') + # grid.label(text='Len') + # grid.label(text='Unit') # Dimension Names for dim_name in info.dim_names: diff --git a/src/blender_maxwell/utils/bl_cache.py b/src/blender_maxwell/utils/bl_cache.py index 344daaf..74ceaa5 100644 --- a/src/blender_maxwell/utils/bl_cache.py +++ b/src/blender_maxwell/utils/bl_cache.py @@ -828,7 +828,6 @@ class BLField: current_items = self._enum_cb(bl_instance, None) # Only Change if Changes Need Making - ## if old_items != current_items: # Set Enum to First Item ## Prevents the seemingly "missing" enum element bug. @@ -837,7 +836,7 @@ class BLField: ## -> Infinite recursion if we don't check current value. ## -> May cause a hiccup (chains will trigger twice) ## To work, there **must** be a guaranteed-available string at 0,0. - first_old_value = old_items(bl_instance, None)[0][0] + first_old_value = old_items[0][0] current_value = self._cached_bl_property.__get__( bl_instance, bl_instance.__class__ ) diff --git a/src/blender_maxwell/utils/extra_sympy_units.py b/src/blender_maxwell/utils/extra_sympy_units.py index 28f6736..14ad141 100644 --- a/src/blender_maxwell/utils/extra_sympy_units.py +++ b/src/blender_maxwell/utils/extra_sympy_units.py @@ -10,9 +10,11 @@ Attributes: Should be used via the `ConstrSympyExpr`, which also adds expression validation. """ +import enum import itertools import typing as typ +import jax.numpy as jnp import pydantic as pyd import sympy as sp import sympy.physics.units as spu @@ -22,6 +24,54 @@ from pydantic_core import core_schema as pyd_core_schema SympyType = sp.Basic | sp.Expr | sp.MatrixBase | sp.MutableDenseMatrix | spu.Quantity +class MathType(enum.StrEnum): + Bool = enum.auto() + Integer = enum.auto() + Rational = enum.auto() + Real = enum.auto() + Complex = enum.auto() + + @staticmethod + def from_expr(sp_obj: SympyType) -> type: + if isinstance(sp_obj, sp.logic.boolalg.Boolean): + return MathType.Bool + if sp_obj.is_integer: + return MathType.Integer + if sp_obj.is_rational or sp_obj.is_real: + return MathType.Real + if sp_obj.is_complex: + return MathType.Complex + + msg = "Can't determine MathType from sympy object: {sp_obj}" + raise ValueError(msg) + + @staticmethod + def from_pytype(dtype) -> type: + return { + bool: MathType.Bool, + int: MathType.Integer, + float: MathType.Real, + complex: MathType.Complex, + #jnp.int32: MathType.Integer, + #jnp.int64: MathType.Integer, + #jnp.float32: MathType.Real, + #jnp.float64: MathType.Real, + #jnp.complex64: MathType.Complex, + #jnp.complex128: MathType.Complex, + #jnp.bool_: MathType.Bool, + }[dtype] + + @staticmethod + def to_dtype(value: typ.Self) -> type: + return { + MathType.Bool: bool, + MathType.Integer: int, + MathType.Rational: float, + MathType.Real: float, + MathType.Complex: complex, + }[value] + + #################### # - Units #################### diff --git a/src/blender_maxwell/utils/image_ops.py b/src/blender_maxwell/utils/image_ops.py index 260aec1..6f64ada 100644 --- a/src/blender_maxwell/utils/image_ops.py +++ b/src/blender_maxwell/utils/image_ops.py @@ -1,12 +1,14 @@ """Useful image processing operations for use in the addon.""" import enum +import time import typing as typ import jax import jax.numpy as jnp import jaxtyping as jtyp import matplotlib +import matplotlib.axis as mpl_ax from blender_maxwell import contracts as ct from blender_maxwell.utils import logger @@ -106,3 +108,129 @@ def rgba_image_from_2d_map( return rgba_image_from_2d_map__grayscale(map_2d) return rgba_image_from_2d_map__grayscale(map_2d) + + +#################### +# - Plotters +#################### +# () -> ℝ +def plot_hist_1d( + data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis +) -> None: + y_name = info.output_names[0] + y_unit = info.output_units[y_name] + + ax.hist(data, bins=30, alpha=0.75) + ax.set_title('Histogram') + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + + +# (ℤ) -> ℝ +def plot_box_plot_1d( + data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis +) -> None: + x_name = info.dim_names[0] + y_name = info.output_names[0] + y_unit = info.output_units[y_name] + + ax.boxplot(data) + ax.set_title('Box Plot') + ax.set_xlabel(f'{x_name}') + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + + +# (ℝ) -> ℝ +def plot_curve_2d( + data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis +) -> None: + times = [time.perf_counter()] + + x_name = info.dim_names[0] + x_unit = info.dim_units[x_name] + y_name = info.output_names[0] + y_unit = info.output_units[y_name] + + times.append(time.perf_counter() - times[0]) + ax.plot(info.dim_idx_arrays[0], data) + times.append(time.perf_counter() - times[0]) + ax.set_title('2D Curve') + times.append(time.perf_counter() - times[0]) + ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) + times.append(time.perf_counter() - times[0]) + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + times.append(time.perf_counter() - times[0]) + # log.critical('Timing of Curve2D: %s', str(times)) + + +def plot_points_2d( + data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis +) -> None: + x_name = info.dim_names[0] + x_unit = info.dim_units[x_name] + y_name = info.output_names[0] + y_unit = info.output_units[y_name] + + ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6) + ax.set_title('2D Points') + ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + + +def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None: + x_name = info.dim_names[0] + x_unit = info.dim_units[x_name] + y_name = info.output_names[0] + y_unit = info.output_units[y_name] + + ax.bar(info.dim_idx_arrays[0], data, alpha=0.7) + ax.set_title('2D Bar') + ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + + +# (ℝ, ℤ) -> ℝ +def plot_curves_2d( + data: jtyp.Float32[jtyp.Array, 'x_size categories'], info, ax: mpl_ax.Axis +) -> None: + x_name = info.dim_names[0] + x_unit = info.dim_units[x_name] + y_name = info.output_names[0] + y_unit = info.output_units[y_name] + + for category in range(data.shape[1]): + ax.plot(data[:, 0], data[:, 1]) + + ax.set_title('2D Curves') + ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + ax.legend() + + +def plot_filled_curves_2d( + data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis +) -> None: + x_name = info.dim_names[0] + x_unit = info.dim_units[x_name] + y_name = info.output_names[0] + y_unit = info.output_units[y_name] + + ax.fill_between(info.dim_arrays[0], data[:, 0], info.dim_arrays[0], data[:, 1]) + ax.set_title('2D Curves') + ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + + +# (ℝ, ℝ) -> ℝ +def plot_heatmap_2d( + data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis +) -> None: + x_name = info.dim_names[0] + x_unit = info.dim_units[x_name] + y_name = info.dim_names[1] + y_unit = info.dim_units[y_name] + + heatmap = ax.imshow(data, aspect='auto', interpolation='none') + ax.figure.colorbar(heatmap, ax=ax) + ax.set_title('Heatmap') + ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))