From a3defd3c1ca045a2991e1235feee5daa674da710 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sofus=20Albert=20H=C3=B8gsbro=20Rose?= Date: Tue, 23 Apr 2024 19:27:45 +0200 Subject: [PATCH] feat: Complete matplotlib plotting system. The Viz node now detects the shape of the data, and presents compatible plot options. Not all are implemented, but a few quite important ones are. Additionally, a number of dataflow-related bugs were investigated and fixed. A few were truly damaging, but many simply resulted in gross inefficiencies - we must be careful declaring BLFields that are updated in hot loops! Moreover, it is exceptionally easy to add more as needed, as we analyze more and more sims. The only limit is `matplotlib`, which is... well, yeah. Due to the BLField work, the dynamicness of the Viz node is quite under control, so there will not be any critical issues there. The plotting lags (70ms total in the hot loop), but that's actually entirely fixeable. It's also entirely the `managed_bl_image`'s fault. Fixing these inefficiencies will also make Tidy3D's builtin plots near-realtime, incidentally. We profiled the following currently: - 25ms: Creating `fig = plt.subplots`. We can reuse fig per-managed image. - 43ms: The BytesIO roundtrip, including `savefig`. We can instead use the Agg backend, `fig.canvas.draw()`, and a `np.frombuffer` to both plot directly to the memory location, - ~3ms: Actual plotting functions in `image_ops`. They are seriously fast. - ~0ms: Blitting pixels to the Blender image - this was optimized in 4.1, and it shows; the time to copy the data over is essentially nothing. --- .../maxwell_sim_nodes/contracts/flow_kinds.py | 54 +++- .../managed_objs/managed_bl_image.py | 37 ++- .../nodes/analysis/extract_data.py | 113 ++++++-- .../nodes/analysis/math/filter_math.py | 55 ++-- .../nodes/analysis/math/map_math.py | 32 ++- .../maxwell_sim_nodes/nodes/analysis/viz.py | 258 +++++++++++++++++- .../maxwell_sim_nodes/nodes/base.py | 8 +- .../maxwell_sim_nodes/sockets/base.py | 8 +- .../maxwell_sim_nodes/sockets/basic/data.py | 8 +- src/blender_maxwell/utils/bl_cache.py | 3 +- .../utils/extra_sympy_units.py | 50 ++++ src/blender_maxwell/utils/image_ops.py | 128 +++++++++ 12 files changed, 666 insertions(+), 88 deletions(-) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py index 945516e..4bd1c5d 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py @@ -126,6 +126,10 @@ class ArrayFlow: def __len__(self) -> int: return len(self.values) + @functools.cached_property + def mathtype(self) -> spux.MathType: + return spux.MathType.from_pytype(type(self.values.item(0))) + def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int: """Find the index of the value that is closest to the given value. @@ -437,6 +441,27 @@ class LazyArrayRangeFlow: key=lambda sym: sym.name, ) + @functools.cached_property + def mathtype(self) -> spux.MathType: + # Get Start Mathtype + if isinstance(self.start, spux.SympyType): + start_mathtype = spux.MathType.from_expr(self.start) + else: + start_mathtype = spux.MathType.from_pytype(self.start) + + # Get Stop Mathtype + if isinstance(self.stop, spux.SympyType): + stop_mathtype = spux.MathType.from_expr(type(self.stop)) + else: + stop_mathtype = spux.MathType.from_pytype(type(self.stop)) + + # Check Equal + if start_mathtype != stop_mathtype: + msg = "Mathtypes of start and stop don't agree. Please fix!" + raise ValueError(msg) + + return start_mathtype + def __len__(self): return self.steps @@ -688,4 +713,31 @@ class InfoFlow: default_factory=dict ) ## TODO: Rename to dim_idxs - ## TODO: Validation, esp. length of dims. Pydantic? + @functools.cached_property + def dim_lens(self) -> dict[str, int]: + return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()} + + @functools.cached_property + def dim_mathtypes(self) -> dict[str, int]: + return { + dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items() + } + + @functools.cached_property + def dim_units(self) -> dict[str, int]: + return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()} + + @functools.cached_property + def dim_idx_arrays(self) -> list[ArrayFlow]: + return [ + dim_idx.realize().values + if isinstance(dim_idx, LazyArrayRangeFlow) + else dim_idx.values + for dim_idx in self.dim_idx.values() + ] + return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()} + + # Output Information + output_names: list[str] = dataclasses.field(default_factory=list) + output_mathtypes: dict[str, spux.MathType] = dataclasses.field(default_factory=dict) + output_units: dict[str, spux.Unit | None] = dataclasses.field(default_factory=dict) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py index b90fd69..c686d65 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py @@ -1,4 +1,5 @@ import io +import time import typing as typ import bpy @@ -205,61 +206,57 @@ class ManagedBLImage(base.ManagedObj): dpi: int | None = None, bl_select: bool = False, ): - # time_start = time.perf_counter() + times = [time.perf_counter()] import matplotlib.pyplot as plt - # log.debug('Imported PyPlot (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) # Compute Plot Dimensions aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = ( self.gen_image_geometry(width_inches, height_inches, dpi) ) - # log.debug('Computed MPL Geometry (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) - # log.debug( - # 'Creating MPL Axes (aspect=%f, width=%f, height=%f)', - # aspect_ratio, - # _width_inches, - # _height_inches, - # ) # Create MPL Figure, Axes, and Compute Figure Geometry fig, ax = plt.subplots( figsize=[_width_inches, _height_inches], dpi=_dpi, ) - # log.debug('Created MPL Axes (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) ax.set_aspect(aspect_ratio) + times.append(time.perf_counter() - times[0]) cmp_width_px, cmp_height_px = fig.canvas.get_width_height() - ## Use computed pixel w/h to preempt off-by-one size errors. + times.append(time.perf_counter() - times[0]) ax.set_aspect('auto') ## Workaround aspect-ratio bugs - # log.debug('Set MPL Aspect (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) # Plot w/User Parameter func_plotter(ax) - # log.debug('User Plot Function (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) # Save Figure to BytesIO with io.BytesIO() as buff: - # log.debug('Made BytesIO (%f)', time.perf_counter() - time_start) fig.savefig(buff, format='raw', dpi=dpi) - # log.debug('Saved Figure to BytesIO (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) buff.seek(0) image_data = np.frombuffer( buff.getvalue(), dtype=np.uint8, ).reshape([cmp_height_px, cmp_width_px, -1]) - # log.debug('Set Image Data (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) image_data = np.flipud(image_data).astype(np.float32) / 255 - # log.debug('Flipped Image Data (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) plt.close(fig) # Optimized Write to Blender Image bl_image = self.bl_image(cmp_width_px, cmp_height_px, 'RGBA', 'uint8') - # log.debug('Made BL Image (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) bl_image.pixels.foreach_set(image_data.ravel()) - # log.debug('Set BL Image Pixels (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) bl_image.update() - # log.debug('Updated BL Image (%f)', time.perf_counter() - time_start) + times.append(time.perf_counter() - times[0]) if bl_select: self.bl_select() + times.append(time.perf_counter() - times[0]) + #log.critical('Timing of MPL Plot: %s', str(times)) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py index eea08fa..bbc428a 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py @@ -1,12 +1,13 @@ +import enum import typing as typ -import enum import bpy import jax import jax.numpy as jnp import sympy.physics.units as spu from blender_maxwell.utils import bl_cache, logger +from blender_maxwell.utils import extra_sympy_units as spux from ... import contracts as ct from ... import sockets @@ -45,11 +46,13 @@ class ExtractDataNode(base.MaxwellSimNode): ) # Sim Data - sim_data_monitor_nametype: dict[str, str] = bl_cache.BLField({}) + sim_data_monitor_nametype: dict[str, str] = bl_cache.BLField( + {}, use_prop_update=False + ) # Monitor Data - monitor_data_type: str = bl_cache.BLField('') - monitor_data_components: list[str] = bl_cache.BLField([]) + monitor_data_type: str = bl_cache.BLField('', use_prop_update=False) + monitor_data_components: list[str] = bl_cache.BLField([], use_prop_update=False) #################### # - Computed Properties @@ -177,7 +180,7 @@ class ExtractDataNode(base.MaxwellSimNode): self.extract_filter = bl_cache.Signal.ResetEnumItems #################### - # - Output: Value + # - Output: Sim Data -> Monitor Data #################### @events.computes_output_socket( 'Monitor Data', @@ -186,34 +189,52 @@ class ExtractDataNode(base.MaxwellSimNode): input_sockets={'Sim Data'}, ) def compute_monitor_data(self, props: dict, input_sockets: dict): - return input_sockets['Sim Data'].monitor_data[props['extract_filter']] + if input_sockets['Sim Data'] is not None and props['extract_filter'] != 'NONE': + return input_sockets['Sim Data'].monitor_data[props['extract_filter']] + return None + + #################### + # - Output: Monitor Data -> Data + #################### @events.computes_output_socket( 'Data', kind=ct.FlowKind.Array, props={'extract_filter'}, input_sockets={'Monitor Data'}, + input_socket_kinds={'Monitor Data': ct.FlowKind.Value}, ) - def compute_data(self, props: dict, input_sockets: dict) -> jax.Array: - xarray_data = getattr(input_sockets['Monitor Data'], props['extract_filter']) - return jnp.array(xarray_data.data) ## TODO: Can it be done without a copy? + def compute_data(self, props: dict, input_sockets: dict) -> jax.Array | None: + if ( + input_sockets['Monitor Data'] is not None + and props['extract_filter'] != 'NONE' + ): + xarray_data = getattr( + input_sockets['Monitor Data'], props['extract_filter'] + ) + return jnp.array(xarray_data.data) + ## TODO: Let the array itself have its output unit too! + + return None - #################### - # - Output: LazyValueFunc - #################### @events.computes_output_socket( 'Data', kind=ct.FlowKind.LazyValueFunc, output_sockets={'Data'}, output_socket_kinds={'Data': ct.FlowKind.Array}, ) - def compute_extracted_data_lazy(self, output_sockets: dict) -> ct.LazyValueFuncFlow: - return ct.LazyValueFuncFlow( - func=lambda: output_sockets['Data'], supports_jax=True - ) + def compute_extracted_data_lazy( + self, output_sockets: dict + ) -> ct.LazyValueFuncFlow | None: + if output_sockets['Data'] is not None: + return ct.LazyValueFuncFlow( + func=lambda: output_sockets['Data'], supports_jax=True + ) + + return None #################### - # - Output: Info + # - Auxiliary: Monitor Data -> Data #################### @events.computes_output_socket( 'Data', @@ -221,6 +242,7 @@ class ExtractDataNode(base.MaxwellSimNode): props={'monitor_data_type', 'extract_filter'}, input_sockets={'Monitor Data'}, input_socket_kinds={'Monitor Data': ct.FlowKind.Value}, + input_sockets_optional={'Monitor Data': True}, ) def compute_extracted_data_info( self, props: dict, input_sockets: dict @@ -234,12 +256,16 @@ class ExtractDataNode(base.MaxwellSimNode): else: return ct.InfoFlow() + info_output_names = { + 'output_names': [props['extract_filter']], + } + # Compute InfoFlow from XArray ## XYZF: Field / Permittivity / FieldProjectionCartesian if props['monitor_data_type'] in { 'Field', 'Permittivity', - 'FieldProjectionCartesian', + #'FieldProjectionCartesian', }: return ct.InfoFlow( dim_names=['x', 'y', 'z', 'f'], @@ -256,6 +282,13 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, + **info_output_names, + output_mathtypes={props['extract_filter']: spux.MathType.Complex}, + output_units={ + props['extract_filter']: spu.volt / spu.micrometer + if props['monitor_data_type'] == 'Field' + else None + }, ) ## XYZT: FieldTime @@ -275,6 +308,17 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, + **info_output_names, + output_mathtypes={props['extract_filter']: spux.MathType.Complex}, + output_units={ + props['extract_filter']: ( + spu.volt / spu.micrometer + if props['extract_filter'].startswith('E') + else spu.ampere / spu.micrometer + ) + if props['monitor_data_type'] == 'Field' + else None + }, ) ## F: Flux @@ -288,6 +332,9 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, + **info_output_names, + output_mathtypes={props['extract_filter']: spux.MathType.Real}, + output_units={props['extract_filter']: spu.watt}, ) ## T: FluxTime @@ -301,6 +348,9 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, + **info_output_names, + output_mathtypes={props['extract_filter']: spux.MathType.Real}, + output_units={props['extract_filter']: spu.watt}, ) ## RThetaPhiF: FieldProjectionAngle @@ -327,6 +377,15 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, + **info_output_names, + output_mathtypes={props['extract_filter']: spux.MathType.Real}, + output_units={ + props['extract_filter']: ( + spu.volt / spu.micrometer + if props['extract_filter'].startswith('E') + else spu.ampere / spu.micrometer + ) + }, ) ## UxUyRF: FieldProjectionKSpace @@ -351,6 +410,15 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, + **info_output_names, + output_mathtypes={props['extract_filter']: spux.MathType.Real}, + output_units={ + props['extract_filter']: ( + spu.volt / spu.micrometer + if props['extract_filter'].startswith('E') + else spu.ampere / spu.micrometer + ) + }, ) ## OrderxOrderyF: Diffraction @@ -372,6 +440,15 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, + **info_output_names, + output_mathtypes={props['extract_filter']: spux.MathType.Real}, + output_units={ + props['extract_filter']: ( + spu.volt / spu.micrometer + if props['extract_filter'].startswith('E') + else spu.ampere / spu.micrometer + ) + }, ) msg = f'Unsupported Monitor Data Type {props["monitor_data_type"]} in "FlowKind.Info" of "{self.bl_label}"' diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py index 6c5fac0..18e6a98 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py @@ -47,8 +47,9 @@ class FilterMathNode(base.MaxwellSimNode): None, prop_ui=True, enum_cb=lambda self, _: self.search_dims() ) - dim_names: list[str] = bl_cache.BLField([]) - dim_lens: dict[str, int] = bl_cache.BLField({}) + @property + def _info(self) -> ct.InfoFlow: + return self._compute_input('Data', kind=ct.FlowKind.Info) #################### # - Operation Search @@ -70,16 +71,16 @@ class FilterMathNode(base.MaxwellSimNode): # - Dim Search #################### def search_dims(self) -> list[ct.BLEnumElement]: - if self.dim_names: + if (info := self._info).dim_names: dims = [ (dim_name, dim_name, dim_name, '', i) - for i, dim_name in enumerate(self.dim_names) + for i, dim_name in enumerate(info.dim_names) ] # Squeeze: Dimension Must Have Length=1 ## We must also correct the "NUMBER" of the enum. if self.operation == 'SQUEEZE': - filtered_dims = [dim for dim in dims if self.dim_lens[dim[0]] == 1] + filtered_dims = [dim for dim in dims if info.dim_lens[dim[0]] == 1] return [(*dim[:-1], i) for i, dim in enumerate(filtered_dims)] return dims @@ -90,7 +91,7 @@ class FilterMathNode(base.MaxwellSimNode): #################### def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: layout.prop(self, self.blfields['operation'], text='') - if self.dim_names: + if self._info.dim_names: layout.prop(self, self.blfields['dim'], text='') #################### @@ -98,52 +99,49 @@ class FilterMathNode(base.MaxwellSimNode): #################### @events.on_value_changed( prop_name='active_socket_set', + run_on_init=True, ) def on_socket_set_changed(self): self.operation = bl_cache.Signal.ResetEnumItems @events.on_value_changed( - socket_name={'Data'}, - prop_name={'active_socket_set'}, + socket_name='Data', + prop_name='active_socket_set', props={'active_socket_set'}, input_sockets={'Data'}, input_socket_kinds={'Data': ct.FlowKind.Info}, - input_sockets_optional={'Data': True}, - run_on_init=True, + # run_on_init=True, ) def on_any_change(self, props: dict, input_sockets: dict): - # Set Dimension Names from InfoFlow - if input_sockets['Data'].dim_names: - self.dim_names = input_sockets['Data'].dim_names - self.dim_lens = { - dim_name: len(dim_idx) - for dim_name, dim_idx in input_sockets['Data'].dim_idx.items() - } - else: - self.dim_names = [] - self.dim_lens = {} - - # Reset Enum self.dim = bl_cache.Signal.ResetEnumItems @events.on_value_changed( + socket_name='Data', prop_name='dim', props={'active_socket_set', 'dim'}, input_sockets={'Data'}, input_socket_kinds={'Data': ct.FlowKind.Info}, - input_sockets_optional={'Data': True}, + # run_on_init=True, ) def on_dim_change(self, props: dict, input_sockets: dict): # Add/Remove Input Socket "Value" - if props['active_socket_set'] == 'By Dim Value' and props['dim'] != 'NONE': + if ( + input_sockets['Data'] != ct.InfoFlow() + and props['active_socket_set'] == 'By Dim Value' + and props['dim'] != 'NONE' + ): # Get Current and Wanted Socket Defs - current_socket_def = self.loose_input_sockets.get('Value') + current_bl_socket = self.loose_input_sockets.get('Value') wanted_socket_def = sockets.SOCKET_DEFS[ ct.unit_to_socket_type(input_sockets['Data'].dim_idx[props['dim']].unit) ] # Determine Whether to Declare New Loose Input SOcket - if current_socket_def is None or current_socket_def != wanted_socket_def: + if ( + current_bl_socket is None + or sockets.SOCKET_DEFS[current_bl_socket.socket_type] + != wanted_socket_def + ): self.loose_input_sockets = { 'Value': wanted_socket_def(), } @@ -225,7 +223,7 @@ class FilterMathNode(base.MaxwellSimNode): # Compute Bound/Free Parameters ## Empty Dimension -> Empty InfoFlow - if props['dim'] != 'NONE': + if input_sockets['Data'] != ct.InfoFlow() and props['dim'] != 'NONE': axis = info.dim_names.index(props['dim']) else: return ct.InfoFlow() @@ -243,6 +241,9 @@ class FilterMathNode(base.MaxwellSimNode): for dim_name, dim_idx in info.dim_idx.items() if dim_name != props['dim'] }, + output_names=info.output_names, + output_mathtypes=info.output_mathtypes, + output_units=info.output_units, ) # Fallback to Empty InfoFlow diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py index 731805b..57b250b 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py @@ -6,7 +6,8 @@ import jax import jax.numpy as jnp import sympy as sp -from blender_maxwell.utils import logger, bl_cache +from blender_maxwell.utils import bl_cache, logger +from blender_maxwell.utils import extra_sympy_units as spux from .... import contracts as ct from .... import sockets @@ -143,7 +144,7 @@ class MapMathNode(base.MaxwellSimNode): 'SINC': lambda data: jnp.sinc(data), }, 'By Vector': { - 'NORM_2': lambda data: jnp.norm(data, ord=2, axis=-1), + 'NORM_2': lambda data: jnp.linalg.norm(data, ord=2, axis=-1), }, 'By Matrix': { # Matrix -> Number @@ -196,11 +197,34 @@ class MapMathNode(base.MaxwellSimNode): @events.computes_output_socket( 'Data', kind=ct.FlowKind.Info, + props={'active_socket_set', 'operation'}, input_sockets={'Data'}, input_socket_kinds={'Data': ct.FlowKind.Info}, ) - def compute_data_info(self, input_sockets: dict) -> ct.InfoFlow: - return input_sockets['Data'] + def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow: + info = input_sockets['Data'] + + # Complex -> Real + if props['active_socket_set'] == 'By Element' and props['operation'] in [ + 'REAL', + 'IMAG', + 'ABS', + ]: + return ct.InfoFlow( + dim_names=info.dim_names, + dim_idx=info.dim_idx, + output_names=info.output_names, + output_mathtypes={ + output_name: ( + spux.MathType.Real + if output_mathtype == spux.MathType.Complex + else output_mathtype + ) + for output_name, output_mathtype in info.output_mathtypes.items() + }, + output_units=info.output_units, + ) + return info @events.computes_output_socket( 'Data', diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py index 21d65f6..469ae56 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py @@ -2,8 +2,11 @@ import enum import typing as typ import bpy +import jaxtyping as jtyp +import matplotlib.axis as mpl_ax from blender_maxwell.utils import bl_cache, image_ops, logger +from blender_maxwell.utils import extra_sympy_units as spux from ... import contracts as ct from ... import managed_objs, sockets @@ -12,9 +15,168 @@ from .. import base, events log = logger.get(__name__) +class VizMode(enum.StrEnum): + """Available visualization modes. + + **NOTE**: >1D output dimensions currently have no viz. + + Plots for `() -> ℝ`: + - Hist1D: Bin-summed distribution. + - BoxPlot1D: Box-plot describing the distribution. + + Plots for `(ℤ) -> ℝ`: + - BoxPlots1D: Side-by-side boxplots w/equal y axis. + + Plots for `(ℝ) -> ℝ`: + - Curve2D: Standard line-curve w/smooth interpolation + - Points2D: Scatterplot of individual points. + - Bar: Value to height of a barplot. + + Plots for `(ℝ, ℤ) -> ℝ`: + - Curves2D: Layered Curve2Ds with unique colors. + - FilledCurves2D: Layered Curve2Ds with filled space between. + + Plots for `(ℝ, ℝ) -> ℝ`: + - Heatmap2D: Colormapped image with value at each pixel. + + Plots for `(ℝ, ℝ, ℝ) -> ℝ`: + - SqueezedHeatmap2D: 3D-embeddable heatmap for when one of the axes is 1. + - Heatmap3D: Colormapped field with value at each voxel. + """ + + Hist1D = enum.auto() + BoxPlot1D = enum.auto() + + Curve2D = enum.auto() + Points2D = enum.auto() + Bar = enum.auto() + + Curves2D = enum.auto() + FilledCurves2D = enum.auto() + + Heatmap2D = enum.auto() + + SqueezedHeatmap2D = enum.auto() + Heatmap3D = enum.auto() + + @staticmethod + def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None: + valid_viz_modes = { + ((), (spux.MathType.Real,)): [VizMode.Hist1D, VizMode.BoxPlot1D], + ((spux.MathType.Integer), (spux.MathType.Real)): [ + VizMode.Hist1D, + VizMode.BoxPlot1D, + ], + ((spux.MathType.Real,), (spux.MathType.Real,)): [ + VizMode.Curve2D, + VizMode.Points2D, + VizMode.Bar, + ], + ((spux.MathType.Real, spux.MathType.Integer), (spux.MathType.Real,)): [ + VizMode.Curves2D, + VizMode.FilledCurves2D, + ], + ((spux.MathType.Real, spux.MathType.Real), (spux.MathType.Real,)): [ + VizMode.Heatmap2D, + ], + ( + (spux.MathType.Real, spux.MathType.Real, spux.MathType.Real), + (spux.MathType.Real,), + ): [VizMode.SqueezedHeatmap2D, VizMode.Heatmap3D], + }.get( + ( + tuple(info.dim_mathtypes.values()), + tuple(info.output_mathtypes.values()), + ) + ) + + if valid_viz_modes is None: + return [] + + return valid_viz_modes + + @staticmethod + def to_plotter( + value: typ.Self, + ) -> typ.Callable[ + [jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None + ]: + return { + VizMode.Hist1D: image_ops.plot_hist_1d, + VizMode.BoxPlot1D: image_ops.plot_box_plot_1d, + VizMode.Curve2D: image_ops.plot_curve_2d, + VizMode.Points2D: image_ops.plot_points_2d, + VizMode.Bar: image_ops.plot_bar, + VizMode.Curves2D: image_ops.plot_curves_2d, + VizMode.FilledCurves2D: image_ops.plot_filled_curves_2d, + VizMode.Heatmap2D: image_ops.plot_heatmap_2d, + # NO PLOTTER: VizMode.SqueezedHeatmap2D + # NO PLOTTER: VizMode.Heatmap3D + }[value] + + @staticmethod + def to_name(value: typ.Self) -> str: + return { + VizMode.Hist1D: 'Histogram', + VizMode.BoxPlot1D: 'Box Plot', + VizMode.Curve2D: 'Curve', + VizMode.Points2D: 'Points', + VizMode.Bar: 'Bar', + VizMode.Curves2D: 'Curves', + VizMode.FilledCurves2D: 'Filled Curves', + VizMode.Heatmap2D: 'Heatmap', + VizMode.SqueezedHeatmap2D: 'Heatmap (Squeezed)', + VizMode.Heatmap3D: 'Heatmap (3D)', + }[value] + + @staticmethod + def to_icon(value: typ.Self) -> ct.BLIcon: + return '' + + +class VizTarget(enum.StrEnum): + """Available visualization targets.""" + + Plot2D = enum.auto() + Pixels = enum.auto() + PixelsPlane = enum.auto() + Voxels = enum.auto() + + @staticmethod + def valid_targets_for(viz_mode: VizMode) -> list[typ.Self] | None: + return { + 'NONE': [], + VizMode.Hist1D: [VizTarget.Plot2D], + VizMode.BoxPlot1D: [VizTarget.Plot2D], + VizMode.Curve2D: [VizTarget.Plot2D], + VizMode.Points2D: [VizTarget.Plot2D], + VizMode.Bar: [VizTarget.Plot2D], + VizMode.Curves2D: [VizTarget.Plot2D], + VizMode.FilledCurves2D: [VizTarget.Plot2D], + VizMode.Heatmap2D: [VizTarget.Plot2D, VizTarget.Pixels], + VizMode.SqueezedHeatmap2D: [VizTarget.Pixels, VizTarget.PixelsPlane], + VizMode.Heatmap3D: [VizTarget.Voxels], + }[viz_mode] + + @staticmethod + def to_name(value: typ.Self) -> str: + return { + VizTarget.Plot2D: 'Image (Plot)', + VizTarget.Pixels: 'Image (Pixels)', + VizTarget.PixelsPlane: 'Image (Plane)', + VizTarget.Voxels: '3D Field', + }[value] + + @staticmethod + def to_icon(value: typ.Self) -> ct.BLIcon: + return '' + + class VizNode(base.MaxwellSimNode): """Node for visualizing simulation data, by querying its monitors. + Auto-detects the correct plot type based on the input data: + Attributes: colormap: Colormap to apply to 0..1 output. @@ -40,24 +202,91 @@ class VizNode(base.MaxwellSimNode): ##################### ## - Properties ##################### + viz_mode: enum.Enum = bl_cache.BLField( + prop_ui=True, enum_cb=lambda self, _: self.search_modes() + ) + viz_target: enum.Enum = bl_cache.BLField( + prop_ui=True, enum_cb=lambda self, _: self.search_targets() + ) + + # Mode-Dependent Properties colormap: image_ops.Colormap = bl_cache.BLField( image_ops.Colormap.Viridis, prop_ui=True ) + ##################### + ## - Mode Searcher + ##################### + @property + def _info(self) -> ct.InfoFlow: + return self._compute_input('Data', kind=ct.FlowKind.Info) + + def search_modes(self) -> list[ct.BLEnumElement]: + info = self._info + return [ + ( + viz_mode, + VizMode.to_name(viz_mode), + VizMode.to_name(viz_mode), + VizMode.to_icon(viz_mode), + i, + ) + for i, viz_mode in enumerate(VizMode.valid_modes_for(info)) + ] + + ##################### + ## - Target Searcher + ##################### + def search_targets(self) -> list[ct.BLEnumElement]: + return [ + ( + viz_target, + VizTarget.to_name(viz_target), + VizTarget.to_name(viz_target), + VizTarget.to_icon(viz_target), + i, + ) + for i, viz_target in enumerate(VizTarget.valid_targets_for(self.viz_mode)) + ] + ##################### ## - UI ##################### def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout): - col.prop(self, self.blfields['colormap'], text='') + col.prop(self, self.blfields['viz_mode'], text='') + col.prop(self, self.blfields['viz_target'], text='') + if self.viz_target in [VizTarget.Pixels, VizTarget.PixelsPlane]: + col.prop(self, self.blfields['colormap'], text='') + + #################### + # - Events + #################### + @events.on_value_changed( + socket_name='Data', + input_sockets={'Data'}, + input_socket_kinds={'Data': ct.FlowKind.Info}, + input_sockets_optional={'Data': True}, + run_on_init=True, + ) + def on_socket_set_changed(self, input_sockets: dict): + self.viz_mode = bl_cache.Signal.ResetEnumItems + self.viz_target = bl_cache.Signal.ResetEnumItems + + @events.on_value_changed( + prop_name='viz_mode', + # run_on_init=True, + ) + def on_viz_mode_changed(self): + self.viz_target = bl_cache.Signal.ResetEnumItems ##################### ## - Plotting ##################### @events.on_show_plot( managed_objs={'plot'}, - props={'colormap'}, + props={'viz_mode', 'viz_target', 'colormap'}, input_sockets={'Data'}, - input_socket_kinds={'Data': ct.FlowKind.Array}, + input_socket_kinds={'Data': {ct.FlowKind.Array, ct.FlowKind.Info}}, input_sockets_optional={'Data': True}, stop_propagation=True, ) @@ -67,13 +296,32 @@ class VizNode(base.MaxwellSimNode): input_sockets: dict, props: dict, ): - if input_sockets['Data'] is not None: + array_flow = input_sockets['Data'][ct.FlowKind.Array] + info = input_sockets['Data'][ct.FlowKind.Info] + + if input_sockets['Data'] is None: + return + + if props['viz_target'] == VizTarget.Plot2D: + managed_objs['plot'].mpl_plot_to_image( + lambda ax: VizMode.to_plotter(props['viz_mode'])( + array_flow.values, info, ax + ), + bl_select=True, + ) + if props['viz_target'] == VizTarget.Pixels: managed_objs['plot'].map_2d_to_image( - input_sockets['Data'].values, + array_flow.values, colormap=props['colormap'], bl_select=True, ) + if props['viz_target'] == VizTarget.PixelsPlane: + raise NotImplementedError + + if props['viz_target'] == VizTarget.Voxels: + raise NotImplementedError + #################### # - Blender Registration diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py index 85f0afc..3029b1a 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py @@ -843,10 +843,12 @@ class MaxwellSimNode(bpy.types.Node): # Propagate Event to All Sockets in "Trigger Direction" ## The trigger chain goes node/socket/node/socket/... if not stop_propagation: - triggered_sockets = self._bl_sockets( - direc=ct.FlowEvent.flow_direction[event] - ) + direc = ct.FlowEvent.flow_direction[event] + triggered_sockets = self._bl_sockets(direc=direc) for bl_socket in triggered_sockets: + if direc == 'output' and not bl_socket.is_linked: + continue + # log.critical( # '![%s] Propagating: (%s, %s)', # self.sim_node_name, diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py index b58e66b..73785db 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py @@ -505,10 +505,10 @@ class MaxwellSimSocket(bpy.types.NodeSocket): socket_name=self.name, socket_kinds=socket_kinds, ) - - self.node.trigger_event( - event, socket_name=self.name, socket_kinds=socket_kinds - ) + else: + self.node.trigger_event( + event, socket_name=self.name, socket_kinds=socket_kinds + ) # Output Socket | Input Flow if self.is_output and flow_direction == 'input': diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/data.py index 1ed0d08..fcf55a8 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/data.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/data.py @@ -49,14 +49,14 @@ class DataBLSocket(base.MaxwellSimSocket): columns=3, row_major=True, even_columns=True, - #even_rows=True, + # even_rows=True, align=True, ) # Grid Header - #grid.label(text='Dim') - #grid.label(text='Len') - #grid.label(text='Unit') + # grid.label(text='Dim') + # grid.label(text='Len') + # grid.label(text='Unit') # Dimension Names for dim_name in info.dim_names: diff --git a/src/blender_maxwell/utils/bl_cache.py b/src/blender_maxwell/utils/bl_cache.py index 344daaf..74ceaa5 100644 --- a/src/blender_maxwell/utils/bl_cache.py +++ b/src/blender_maxwell/utils/bl_cache.py @@ -828,7 +828,6 @@ class BLField: current_items = self._enum_cb(bl_instance, None) # Only Change if Changes Need Making - ## if old_items != current_items: # Set Enum to First Item ## Prevents the seemingly "missing" enum element bug. @@ -837,7 +836,7 @@ class BLField: ## -> Infinite recursion if we don't check current value. ## -> May cause a hiccup (chains will trigger twice) ## To work, there **must** be a guaranteed-available string at 0,0. - first_old_value = old_items(bl_instance, None)[0][0] + first_old_value = old_items[0][0] current_value = self._cached_bl_property.__get__( bl_instance, bl_instance.__class__ ) diff --git a/src/blender_maxwell/utils/extra_sympy_units.py b/src/blender_maxwell/utils/extra_sympy_units.py index 28f6736..14ad141 100644 --- a/src/blender_maxwell/utils/extra_sympy_units.py +++ b/src/blender_maxwell/utils/extra_sympy_units.py @@ -10,9 +10,11 @@ Attributes: Should be used via the `ConstrSympyExpr`, which also adds expression validation. """ +import enum import itertools import typing as typ +import jax.numpy as jnp import pydantic as pyd import sympy as sp import sympy.physics.units as spu @@ -22,6 +24,54 @@ from pydantic_core import core_schema as pyd_core_schema SympyType = sp.Basic | sp.Expr | sp.MatrixBase | sp.MutableDenseMatrix | spu.Quantity +class MathType(enum.StrEnum): + Bool = enum.auto() + Integer = enum.auto() + Rational = enum.auto() + Real = enum.auto() + Complex = enum.auto() + + @staticmethod + def from_expr(sp_obj: SympyType) -> type: + if isinstance(sp_obj, sp.logic.boolalg.Boolean): + return MathType.Bool + if sp_obj.is_integer: + return MathType.Integer + if sp_obj.is_rational or sp_obj.is_real: + return MathType.Real + if sp_obj.is_complex: + return MathType.Complex + + msg = "Can't determine MathType from sympy object: {sp_obj}" + raise ValueError(msg) + + @staticmethod + def from_pytype(dtype) -> type: + return { + bool: MathType.Bool, + int: MathType.Integer, + float: MathType.Real, + complex: MathType.Complex, + #jnp.int32: MathType.Integer, + #jnp.int64: MathType.Integer, + #jnp.float32: MathType.Real, + #jnp.float64: MathType.Real, + #jnp.complex64: MathType.Complex, + #jnp.complex128: MathType.Complex, + #jnp.bool_: MathType.Bool, + }[dtype] + + @staticmethod + def to_dtype(value: typ.Self) -> type: + return { + MathType.Bool: bool, + MathType.Integer: int, + MathType.Rational: float, + MathType.Real: float, + MathType.Complex: complex, + }[value] + + #################### # - Units #################### diff --git a/src/blender_maxwell/utils/image_ops.py b/src/blender_maxwell/utils/image_ops.py index 260aec1..6f64ada 100644 --- a/src/blender_maxwell/utils/image_ops.py +++ b/src/blender_maxwell/utils/image_ops.py @@ -1,12 +1,14 @@ """Useful image processing operations for use in the addon.""" import enum +import time import typing as typ import jax import jax.numpy as jnp import jaxtyping as jtyp import matplotlib +import matplotlib.axis as mpl_ax from blender_maxwell import contracts as ct from blender_maxwell.utils import logger @@ -106,3 +108,129 @@ def rgba_image_from_2d_map( return rgba_image_from_2d_map__grayscale(map_2d) return rgba_image_from_2d_map__grayscale(map_2d) + + +#################### +# - Plotters +#################### +# () -> ℝ +def plot_hist_1d( + data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis +) -> None: + y_name = info.output_names[0] + y_unit = info.output_units[y_name] + + ax.hist(data, bins=30, alpha=0.75) + ax.set_title('Histogram') + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + + +# (ℤ) -> ℝ +def plot_box_plot_1d( + data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis +) -> None: + x_name = info.dim_names[0] + y_name = info.output_names[0] + y_unit = info.output_units[y_name] + + ax.boxplot(data) + ax.set_title('Box Plot') + ax.set_xlabel(f'{x_name}') + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + + +# (ℝ) -> ℝ +def plot_curve_2d( + data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis +) -> None: + times = [time.perf_counter()] + + x_name = info.dim_names[0] + x_unit = info.dim_units[x_name] + y_name = info.output_names[0] + y_unit = info.output_units[y_name] + + times.append(time.perf_counter() - times[0]) + ax.plot(info.dim_idx_arrays[0], data) + times.append(time.perf_counter() - times[0]) + ax.set_title('2D Curve') + times.append(time.perf_counter() - times[0]) + ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) + times.append(time.perf_counter() - times[0]) + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + times.append(time.perf_counter() - times[0]) + # log.critical('Timing of Curve2D: %s', str(times)) + + +def plot_points_2d( + data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis +) -> None: + x_name = info.dim_names[0] + x_unit = info.dim_units[x_name] + y_name = info.output_names[0] + y_unit = info.output_units[y_name] + + ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6) + ax.set_title('2D Points') + ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + + +def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None: + x_name = info.dim_names[0] + x_unit = info.dim_units[x_name] + y_name = info.output_names[0] + y_unit = info.output_units[y_name] + + ax.bar(info.dim_idx_arrays[0], data, alpha=0.7) + ax.set_title('2D Bar') + ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + + +# (ℝ, ℤ) -> ℝ +def plot_curves_2d( + data: jtyp.Float32[jtyp.Array, 'x_size categories'], info, ax: mpl_ax.Axis +) -> None: + x_name = info.dim_names[0] + x_unit = info.dim_units[x_name] + y_name = info.output_names[0] + y_unit = info.output_units[y_name] + + for category in range(data.shape[1]): + ax.plot(data[:, 0], data[:, 1]) + + ax.set_title('2D Curves') + ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + ax.legend() + + +def plot_filled_curves_2d( + data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis +) -> None: + x_name = info.dim_names[0] + x_unit = info.dim_units[x_name] + y_name = info.output_names[0] + y_unit = info.output_units[y_name] + + ax.fill_between(info.dim_arrays[0], data[:, 0], info.dim_arrays[0], data[:, 1]) + ax.set_title('2D Curves') + ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + + +# (ℝ, ℝ) -> ℝ +def plot_heatmap_2d( + data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis +) -> None: + x_name = info.dim_names[0] + x_unit = info.dim_units[x_name] + y_name = info.dim_names[1] + y_unit = info.dim_units[y_name] + + heatmap = ax.imshow(data, aspect='auto', interpolation='none') + ax.figure.colorbar(heatmap, ax=ax) + ax.set_title('Heatmap') + ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) + ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))