From e7d3ecf48e0eea7ae8eff8f49b33e0491eb6be68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sofus=20Albert=20H=C3=B8gsbro=20Rose?= Date: Tue, 23 Apr 2024 11:51:24 +0200 Subject: [PATCH] fix: Crashes on enum changes --- .../managed_objs/managed_bl_image.py | 65 +---------- .../nodes/analysis/math/filter_math.py | 2 +- .../maxwell_sim_nodes/nodes/analysis/viz.py | 37 +++--- src/blender_maxwell/nodeps/utils/image_ops.py | 0 src/blender_maxwell/utils/bl_cache.py | 75 ++++++------ src/blender_maxwell/utils/image_ops.py | 108 ++++++++++++++++++ 6 files changed, 173 insertions(+), 114 deletions(-) delete mode 100644 src/blender_maxwell/nodeps/utils/image_ops.py create mode 100644 src/blender_maxwell/utils/image_ops.py diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py index 1364ccf..b90fd69 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py @@ -2,13 +2,10 @@ import io import typing as typ import bpy -import jax -import jax.numpy as jnp -import matplotlib import matplotlib.axis as mpl_ax import numpy as np -from blender_maxwell.utils import logger +from blender_maxwell.utils import image_ops, logger from .. import contracts as ct from . import base @@ -18,64 +15,6 @@ log = logger.get(__name__) AREA_TYPE = 'IMAGE_EDITOR' SPACE_TYPE = 'IMAGE_EDITOR' -# Colormap -_MPL_CM = matplotlib.cm.get_cmap('viridis', 512) -VIRIDIS_COLORMAP = jnp.array([_MPL_CM(i)[:3] for i in range(512)]) - - -#################### -# - Image Functions -#################### -def apply_colormap(normalized_data, colormap): - # Linear interpolation between colormap points - n_colors = colormap.shape[0] - indices = normalized_data * (n_colors - 1) - lower_idx = jnp.floor(indices).astype(jnp.int32) - upper_idx = jnp.ceil(indices).astype(jnp.int32) - alpha = indices - lower_idx - - lower_colors = jax.vmap(lambda i: colormap[i])(lower_idx) - upper_colors = jax.vmap(lambda i: colormap[i])(upper_idx) - - return (1 - alpha)[..., None] * lower_colors + alpha[..., None] * upper_colors - - -@jax.jit -def rgba_image_from_2d_map__viridis(map_2d): - amplitude = jnp.abs(map_2d) - amplitude_normalized = (amplitude - amplitude.min()) / ( - amplitude.max() - amplitude.min() - ) - rgb_array = apply_colormap(amplitude_normalized, VIRIDIS_COLORMAP) - alpha_channel = jnp.ones_like(amplitude_normalized) - return jnp.dstack((rgb_array, alpha_channel)) - - -@jax.jit -def rgba_image_from_2d_map__grayscale(map_2d): - amplitude = jnp.abs(map_2d) - amplitude_normalized = (amplitude - amplitude.min()) / ( - amplitude.max() - amplitude.min() - ) - rgb_array = jnp.stack([amplitude_normalized] * 3, axis=-1) - alpha_channel = jnp.ones_like(amplitude_normalized) - return jnp.dstack((rgb_array, alpha_channel)) - - -def rgba_image_from_2d_map(map_2d, colormap: str | None = None): - """RGBA Image from a map of 2D coordinates to values. - - Parameters: - map_2d: Shape (width, height, value). - - Returns: - Image as a JAX array of shape (height, width, 4) - """ - if colormap == 'VIRIDIS': - return rgba_image_from_2d_map__viridis(map_2d) - if colormap == 'GRAYSCALE': - return rgba_image_from_2d_map__grayscale(map_2d) - #################### # - Managed BL Image @@ -235,7 +174,7 @@ class ManagedBLImage(base.ManagedObj): self, map_2d, colormap: str | None = 'VIRIDIS', bl_select: bool = False ): self.data_to_image( - lambda _: rgba_image_from_2d_map(map_2d, colormap=colormap), + lambda _: image_ops.rgba_image_from_2d_map(map_2d, colormap=colormap), bl_select=bl_select, ) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py index 66a21be..6c5fac0 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py @@ -123,7 +123,7 @@ class FilterMathNode(base.MaxwellSimNode): self.dim_names = [] self.dim_lens = {} - # Reset String Searcher + # Reset Enum self.dim = bl_cache.Signal.ResetEnumItems @events.on_value_changed( diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py index 4329019..21d65f6 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py @@ -1,8 +1,9 @@ +import enum import typing as typ import bpy -from blender_maxwell.utils import logger +from blender_maxwell.utils import bl_cache, image_ops, logger from ... import contracts as ct from ... import managed_objs, sockets @@ -12,7 +13,12 @@ log = logger.get(__name__) class VizNode(base.MaxwellSimNode): - """Node for visualizing simulation data, by querying its monitors.""" + """Node for visualizing simulation data, by querying its monitors. + + Attributes: + colormap: Colormap to apply to 0..1 output. + + """ node_type = ct.NodeType.Viz bl_label = 'Viz' @@ -34,31 +40,25 @@ class VizNode(base.MaxwellSimNode): ##################### ## - Properties ##################### - colormap: bpy.props.EnumProperty( - name='Colormap', - description='Colormap to apply to grayscale output', - items=[ - ('VIRIDIS', 'Viridis', 'Good default colormap'), - ('GRAYSCALE', 'Grayscale', 'Barebones'), - ], - default='VIRIDIS', - update=lambda self, context: self.on_prop_changed('colormap', context), + colormap: image_ops.Colormap = bl_cache.BLField( + image_ops.Colormap.Viridis, prop_ui=True ) ##################### ## - UI ##################### def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout): - col.prop(self, 'colormap') + col.prop(self, self.blfields['colormap'], text='') ##################### ## - Plotting ##################### @events.on_show_plot( managed_objs={'plot'}, + props={'colormap'}, input_sockets={'Data'}, input_socket_kinds={'Data': ct.FlowKind.Array}, - props={'colormap'}, + input_sockets_optional={'Data': True}, stop_propagation=True, ) def on_show_plot( @@ -67,11 +67,12 @@ class VizNode(base.MaxwellSimNode): input_sockets: dict, props: dict, ): - managed_objs['plot'].map_2d_to_image( - input_sockets['Data'].values, - colormap=props['colormap'], - bl_select=True, - ) + if input_sockets['Data'] is not None: + managed_objs['plot'].map_2d_to_image( + input_sockets['Data'].values, + colormap=props['colormap'], + bl_select=True, + ) #################### diff --git a/src/blender_maxwell/nodeps/utils/image_ops.py b/src/blender_maxwell/nodeps/utils/image_ops.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/blender_maxwell/utils/bl_cache.py b/src/blender_maxwell/utils/bl_cache.py index ed5ffaa..344daaf 100644 --- a/src/blender_maxwell/utils/bl_cache.py +++ b/src/blender_maxwell/utils/bl_cache.py @@ -740,7 +740,7 @@ class BLField: str(value), AttrType.to_name(value), AttrType.to_name(value), ## TODO: From AttrType.__doc__ - AttrType.to_icon(), + AttrType.to_icon(value), i if not self._enum_many else 2**i, ) for i, value in enumerate(list(AttrType)) @@ -824,42 +824,53 @@ class BLField: def __set__(self, bl_instance: BLInstance | None, value: typ.Any) -> None: if value == Signal.ResetEnumItems: - # Set Enum to First Item - ## Prevents the seemingly "missing" enum element bug. - ## -> Caused by the old int still trying to hang on after. - ## -> We can mitigate this by preemptively setting the enum. - ## -> Infinite recursion if we don't check current value. - ## -> May cause a hiccup (chains will trigger twice) - ## To work, there **must** be a guaranteed-available string at 0,0. - first_old_value = self._safe_enum_cb(bl_instance, None)[0][0] - current_value = self._cached_bl_property.__get__( - bl_instance, bl_instance.__class__ - ) - if current_value != first_old_value: - self._cached_bl_property.__set__(bl_instance, first_old_value) + old_items = self._safe_enum_cb(bl_instance, None) + current_items = self._enum_cb(bl_instance, None) - # Pop the Cached Enum Items - ## The next time Blender asks for the enum items, it'll update. - self._enum_cb_cache.pop(bl_instance.instance_id, None) + # Only Change if Changes Need Making + ## + if old_items != current_items: + # Set Enum to First Item + ## Prevents the seemingly "missing" enum element bug. + ## -> Caused by the old int still trying to hang on after. + ## -> We can mitigate this by preemptively setting the enum. + ## -> Infinite recursion if we don't check current value. + ## -> May cause a hiccup (chains will trigger twice) + ## To work, there **must** be a guaranteed-available string at 0,0. + first_old_value = old_items(bl_instance, None)[0][0] + current_value = self._cached_bl_property.__get__( + bl_instance, bl_instance.__class__ + ) + if current_value != first_old_value: + self._cached_bl_property.__set__(bl_instance, first_old_value) - # Invalidate the Getter Cache - ## The next time the user runs __get__, they'll get the new value. - self._cached_bl_property.__set__(bl_instance, Signal.InvalidateCache) + # Pop the Cached Enum Items + ## The next time Blender asks for the enum items, it'll update. + self._enum_cb_cache.pop(bl_instance.instance_id, None) + + # Invalidate the Getter Cache + ## The next time the user runs __get__, they'll get the new value. + self._cached_bl_property.__set__(bl_instance, Signal.InvalidateCache) elif value == Signal.ResetStrSearch: - # Set String to '' - ## Prevents the presence of an invalid value not in the new search. - ## -> Infinite recursion if we don't check current value for ''. - ## -> May cause a hiccup (chains will trigger twice) - current_value = self._cached_bl_property.__get__( - bl_instance, bl_instance.__class__ - ) - if current_value != '': - self._cached_bl_property.__set__(bl_instance, '') + old_items = self._safe_str_cb(bl_instance, None) + current_items = self._str_cb(bl_instance, None) - # Pop the Cached String Search Items - ## The next time Blender does a str search, it'll update. - self._str_cb_cache.pop(bl_instance.instance_id, None) + # Only Change if Changes Need Making + if old_items != current_items: + # Set String to '' + ## Prevents the presence of an invalid value not in the new search. + ## -> Infinite recursion if we don't check current value for ''. + ## -> May cause a hiccup (chains will trigger twice) + current_value = self._cached_bl_property.__get__( + bl_instance, bl_instance.__class__ + ) + if current_value != '': + self._cached_bl_property.__set__(bl_instance, '') + + # Pop the Cached String Search Items + ## The next time Blender does a str search, it'll update. + self._str_cb_cache.pop(bl_instance.instance_id, None) else: self._cached_bl_property.__set__(bl_instance, value) diff --git a/src/blender_maxwell/utils/image_ops.py b/src/blender_maxwell/utils/image_ops.py new file mode 100644 index 0000000..260aec1 --- /dev/null +++ b/src/blender_maxwell/utils/image_ops.py @@ -0,0 +1,108 @@ +"""Useful image processing operations for use in the addon.""" + +import enum +import typing as typ + +import jax +import jax.numpy as jnp +import jaxtyping as jtyp +import matplotlib + +from blender_maxwell import contracts as ct +from blender_maxwell.utils import logger + +log = logger.get(__name__) + +#################### +# - Constants +#################### +_MPL_CM = matplotlib.cm.get_cmap('viridis', 512) +VIRIDIS_COLORMAP: jtyp.Float32[jtyp.Array, '512 3'] = jnp.array( + [_MPL_CM(i)[:3] for i in range(512)] +) + + +class Colormap(enum.StrEnum): + """Available colormaps. + + Attributes: + Viridis: Good general-purpose colormap. + Grayscale: Simple black and white mapping. + """ + + Viridis = enum.auto() + Grayscale = enum.auto() + + @staticmethod + def to_name(value: typ.Self) -> str: + return { + Colormap.Viridis: 'Viridis', + Colormap.Grayscale: 'Grayscale', + }[value] + + @staticmethod + def to_icon(value: typ.Self) -> ct.BLIcon: + return '' + + +#################### +# - Colormap: (X,Y,1 -> Value) -> (X,Y,4 -> Value) +#################### +def apply_colormap( + normalized_data: jtyp.Float32[jtyp.Array, 'width height 4'], + colormap: jtyp.Float32[jtyp.Array, '512 3'], +): + # Linear interpolation between colormap points + n_colors = colormap.shape[0] + indices = normalized_data * (n_colors - 1) + lower_idx = jnp.floor(indices).astype(jnp.int32) + upper_idx = jnp.ceil(indices).astype(jnp.int32) + alpha = indices - lower_idx + + lower_colors = jax.vmap(lambda i: colormap[i])(lower_idx) + upper_colors = jax.vmap(lambda i: colormap[i])(upper_idx) + + return (1 - alpha)[..., None] * lower_colors + alpha[..., None] * upper_colors + + +@jax.jit +def rgba_image_from_2d_map__viridis(map_2d: jtyp.Float32[jtyp.Array, 'width height 4']): + amplitude = jnp.abs(map_2d) + amplitude_normalized = (amplitude - amplitude.min()) / ( + amplitude.max() - amplitude.min() + ) + rgb_array = apply_colormap(amplitude_normalized, VIRIDIS_COLORMAP) + alpha_channel = jnp.ones_like(amplitude_normalized) + return jnp.dstack((rgb_array, alpha_channel)) + + +@jax.jit +def rgba_image_from_2d_map__grayscale( + map_2d: jtyp.Float32[jtyp.Array, 'width height 4'], +): + amplitude = jnp.abs(map_2d) + amplitude_normalized = (amplitude - amplitude.min()) / ( + amplitude.max() - amplitude.min() + ) + rgb_array = jnp.stack([amplitude_normalized] * 3, axis=-1) + alpha_channel = jnp.ones_like(amplitude_normalized) + return jnp.dstack((rgb_array, alpha_channel)) + + +def rgba_image_from_2d_map( + map_2d: jtyp.Float32[jtyp.Array, 'width height 4'], colormap: str | None = None +): + """RGBA Image from a map of 2D coordinates to values. + + Parameters: + map_2d: The 2D value map. + + Returns: + Image as a JAX array of shape (height, width, 4) + """ + if colormap == Colormap.Viridis: + return rgba_image_from_2d_map__viridis(map_2d) + if colormap == Colormap.Grayscale: + return rgba_image_from_2d_map__grayscale(map_2d) + + return rgba_image_from_2d_map__grayscale(map_2d)