fix: Crashes on enum changes

main
Sofus Albert Høgsbro Rose 2024-04-23 11:51:24 +02:00
parent 7fa6e3a3ec
commit e7d3ecf48e
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
6 changed files with 173 additions and 114 deletions

View File

@ -2,13 +2,10 @@ import io
import typing as typ import typing as typ
import bpy import bpy
import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.axis as mpl_ax import matplotlib.axis as mpl_ax
import numpy as np import numpy as np
from blender_maxwell.utils import logger from blender_maxwell.utils import image_ops, logger
from .. import contracts as ct from .. import contracts as ct
from . import base from . import base
@ -18,64 +15,6 @@ log = logger.get(__name__)
AREA_TYPE = 'IMAGE_EDITOR' AREA_TYPE = 'IMAGE_EDITOR'
SPACE_TYPE = 'IMAGE_EDITOR' SPACE_TYPE = 'IMAGE_EDITOR'
# Colormap
_MPL_CM = matplotlib.cm.get_cmap('viridis', 512)
VIRIDIS_COLORMAP = jnp.array([_MPL_CM(i)[:3] for i in range(512)])
####################
# - Image Functions
####################
def apply_colormap(normalized_data, colormap):
# Linear interpolation between colormap points
n_colors = colormap.shape[0]
indices = normalized_data * (n_colors - 1)
lower_idx = jnp.floor(indices).astype(jnp.int32)
upper_idx = jnp.ceil(indices).astype(jnp.int32)
alpha = indices - lower_idx
lower_colors = jax.vmap(lambda i: colormap[i])(lower_idx)
upper_colors = jax.vmap(lambda i: colormap[i])(upper_idx)
return (1 - alpha)[..., None] * lower_colors + alpha[..., None] * upper_colors
@jax.jit
def rgba_image_from_2d_map__viridis(map_2d):
amplitude = jnp.abs(map_2d)
amplitude_normalized = (amplitude - amplitude.min()) / (
amplitude.max() - amplitude.min()
)
rgb_array = apply_colormap(amplitude_normalized, VIRIDIS_COLORMAP)
alpha_channel = jnp.ones_like(amplitude_normalized)
return jnp.dstack((rgb_array, alpha_channel))
@jax.jit
def rgba_image_from_2d_map__grayscale(map_2d):
amplitude = jnp.abs(map_2d)
amplitude_normalized = (amplitude - amplitude.min()) / (
amplitude.max() - amplitude.min()
)
rgb_array = jnp.stack([amplitude_normalized] * 3, axis=-1)
alpha_channel = jnp.ones_like(amplitude_normalized)
return jnp.dstack((rgb_array, alpha_channel))
def rgba_image_from_2d_map(map_2d, colormap: str | None = None):
"""RGBA Image from a map of 2D coordinates to values.
Parameters:
map_2d: Shape (width, height, value).
Returns:
Image as a JAX array of shape (height, width, 4)
"""
if colormap == 'VIRIDIS':
return rgba_image_from_2d_map__viridis(map_2d)
if colormap == 'GRAYSCALE':
return rgba_image_from_2d_map__grayscale(map_2d)
#################### ####################
# - Managed BL Image # - Managed BL Image
@ -235,7 +174,7 @@ class ManagedBLImage(base.ManagedObj):
self, map_2d, colormap: str | None = 'VIRIDIS', bl_select: bool = False self, map_2d, colormap: str | None = 'VIRIDIS', bl_select: bool = False
): ):
self.data_to_image( self.data_to_image(
lambda _: rgba_image_from_2d_map(map_2d, colormap=colormap), lambda _: image_ops.rgba_image_from_2d_map(map_2d, colormap=colormap),
bl_select=bl_select, bl_select=bl_select,
) )

View File

@ -123,7 +123,7 @@ class FilterMathNode(base.MaxwellSimNode):
self.dim_names = [] self.dim_names = []
self.dim_lens = {} self.dim_lens = {}
# Reset String Searcher # Reset Enum
self.dim = bl_cache.Signal.ResetEnumItems self.dim = bl_cache.Signal.ResetEnumItems
@events.on_value_changed( @events.on_value_changed(

View File

@ -1,8 +1,9 @@
import enum
import typing as typ import typing as typ
import bpy import bpy
from blender_maxwell.utils import logger from blender_maxwell.utils import bl_cache, image_ops, logger
from ... import contracts as ct from ... import contracts as ct
from ... import managed_objs, sockets from ... import managed_objs, sockets
@ -12,7 +13,12 @@ log = logger.get(__name__)
class VizNode(base.MaxwellSimNode): class VizNode(base.MaxwellSimNode):
"""Node for visualizing simulation data, by querying its monitors.""" """Node for visualizing simulation data, by querying its monitors.
Attributes:
colormap: Colormap to apply to 0..1 output.
"""
node_type = ct.NodeType.Viz node_type = ct.NodeType.Viz
bl_label = 'Viz' bl_label = 'Viz'
@ -34,31 +40,25 @@ class VizNode(base.MaxwellSimNode):
##################### #####################
## - Properties ## - Properties
##################### #####################
colormap: bpy.props.EnumProperty( colormap: image_ops.Colormap = bl_cache.BLField(
name='Colormap', image_ops.Colormap.Viridis, prop_ui=True
description='Colormap to apply to grayscale output',
items=[
('VIRIDIS', 'Viridis', 'Good default colormap'),
('GRAYSCALE', 'Grayscale', 'Barebones'),
],
default='VIRIDIS',
update=lambda self, context: self.on_prop_changed('colormap', context),
) )
##################### #####################
## - UI ## - UI
##################### #####################
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout): def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
col.prop(self, 'colormap') col.prop(self, self.blfields['colormap'], text='')
##################### #####################
## - Plotting ## - Plotting
##################### #####################
@events.on_show_plot( @events.on_show_plot(
managed_objs={'plot'}, managed_objs={'plot'},
props={'colormap'},
input_sockets={'Data'}, input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Array}, input_socket_kinds={'Data': ct.FlowKind.Array},
props={'colormap'}, input_sockets_optional={'Data': True},
stop_propagation=True, stop_propagation=True,
) )
def on_show_plot( def on_show_plot(
@ -67,11 +67,12 @@ class VizNode(base.MaxwellSimNode):
input_sockets: dict, input_sockets: dict,
props: dict, props: dict,
): ):
managed_objs['plot'].map_2d_to_image( if input_sockets['Data'] is not None:
input_sockets['Data'].values, managed_objs['plot'].map_2d_to_image(
colormap=props['colormap'], input_sockets['Data'].values,
bl_select=True, colormap=props['colormap'],
) bl_select=True,
)
#################### ####################

View File

@ -740,7 +740,7 @@ class BLField:
str(value), str(value),
AttrType.to_name(value), AttrType.to_name(value),
AttrType.to_name(value), ## TODO: From AttrType.__doc__ AttrType.to_name(value), ## TODO: From AttrType.__doc__
AttrType.to_icon(), AttrType.to_icon(value),
i if not self._enum_many else 2**i, i if not self._enum_many else 2**i,
) )
for i, value in enumerate(list(AttrType)) for i, value in enumerate(list(AttrType))
@ -824,42 +824,53 @@ class BLField:
def __set__(self, bl_instance: BLInstance | None, value: typ.Any) -> None: def __set__(self, bl_instance: BLInstance | None, value: typ.Any) -> None:
if value == Signal.ResetEnumItems: if value == Signal.ResetEnumItems:
# Set Enum to First Item old_items = self._safe_enum_cb(bl_instance, None)
## Prevents the seemingly "missing" enum element bug. current_items = self._enum_cb(bl_instance, None)
## -> Caused by the old int still trying to hang on after.
## -> We can mitigate this by preemptively setting the enum.
## -> Infinite recursion if we don't check current value.
## -> May cause a hiccup (chains will trigger twice)
## To work, there **must** be a guaranteed-available string at 0,0.
first_old_value = self._safe_enum_cb(bl_instance, None)[0][0]
current_value = self._cached_bl_property.__get__(
bl_instance, bl_instance.__class__
)
if current_value != first_old_value:
self._cached_bl_property.__set__(bl_instance, first_old_value)
# Pop the Cached Enum Items # Only Change if Changes Need Making
## The next time Blender asks for the enum items, it'll update. ## <Proverb. 'Fun things to say in jail'. Ca 887BCE. >
self._enum_cb_cache.pop(bl_instance.instance_id, None) if old_items != current_items:
# Set Enum to First Item
## Prevents the seemingly "missing" enum element bug.
## -> Caused by the old int still trying to hang on after.
## -> We can mitigate this by preemptively setting the enum.
## -> Infinite recursion if we don't check current value.
## -> May cause a hiccup (chains will trigger twice)
## To work, there **must** be a guaranteed-available string at 0,0.
first_old_value = old_items(bl_instance, None)[0][0]
current_value = self._cached_bl_property.__get__(
bl_instance, bl_instance.__class__
)
if current_value != first_old_value:
self._cached_bl_property.__set__(bl_instance, first_old_value)
# Invalidate the Getter Cache # Pop the Cached Enum Items
## The next time the user runs __get__, they'll get the new value. ## The next time Blender asks for the enum items, it'll update.
self._cached_bl_property.__set__(bl_instance, Signal.InvalidateCache) self._enum_cb_cache.pop(bl_instance.instance_id, None)
# Invalidate the Getter Cache
## The next time the user runs __get__, they'll get the new value.
self._cached_bl_property.__set__(bl_instance, Signal.InvalidateCache)
elif value == Signal.ResetStrSearch: elif value == Signal.ResetStrSearch:
# Set String to '' old_items = self._safe_str_cb(bl_instance, None)
## Prevents the presence of an invalid value not in the new search. current_items = self._str_cb(bl_instance, None)
## -> Infinite recursion if we don't check current value for ''.
## -> May cause a hiccup (chains will trigger twice)
current_value = self._cached_bl_property.__get__(
bl_instance, bl_instance.__class__
)
if current_value != '':
self._cached_bl_property.__set__(bl_instance, '')
# Pop the Cached String Search Items # Only Change if Changes Need Making
## The next time Blender does a str search, it'll update. if old_items != current_items:
self._str_cb_cache.pop(bl_instance.instance_id, None) # Set String to ''
## Prevents the presence of an invalid value not in the new search.
## -> Infinite recursion if we don't check current value for ''.
## -> May cause a hiccup (chains will trigger twice)
current_value = self._cached_bl_property.__get__(
bl_instance, bl_instance.__class__
)
if current_value != '':
self._cached_bl_property.__set__(bl_instance, '')
# Pop the Cached String Search Items
## The next time Blender does a str search, it'll update.
self._str_cb_cache.pop(bl_instance.instance_id, None)
else: else:
self._cached_bl_property.__set__(bl_instance, value) self._cached_bl_property.__set__(bl_instance, value)

View File

@ -0,0 +1,108 @@
"""Useful image processing operations for use in the addon."""
import enum
import typing as typ
import jax
import jax.numpy as jnp
import jaxtyping as jtyp
import matplotlib
from blender_maxwell import contracts as ct
from blender_maxwell.utils import logger
log = logger.get(__name__)
####################
# - Constants
####################
_MPL_CM = matplotlib.cm.get_cmap('viridis', 512)
VIRIDIS_COLORMAP: jtyp.Float32[jtyp.Array, '512 3'] = jnp.array(
[_MPL_CM(i)[:3] for i in range(512)]
)
class Colormap(enum.StrEnum):
"""Available colormaps.
Attributes:
Viridis: Good general-purpose colormap.
Grayscale: Simple black and white mapping.
"""
Viridis = enum.auto()
Grayscale = enum.auto()
@staticmethod
def to_name(value: typ.Self) -> str:
return {
Colormap.Viridis: 'Viridis',
Colormap.Grayscale: 'Grayscale',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> ct.BLIcon:
return ''
####################
# - Colormap: (X,Y,1 -> Value) -> (X,Y,4 -> Value)
####################
def apply_colormap(
normalized_data: jtyp.Float32[jtyp.Array, 'width height 4'],
colormap: jtyp.Float32[jtyp.Array, '512 3'],
):
# Linear interpolation between colormap points
n_colors = colormap.shape[0]
indices = normalized_data * (n_colors - 1)
lower_idx = jnp.floor(indices).astype(jnp.int32)
upper_idx = jnp.ceil(indices).astype(jnp.int32)
alpha = indices - lower_idx
lower_colors = jax.vmap(lambda i: colormap[i])(lower_idx)
upper_colors = jax.vmap(lambda i: colormap[i])(upper_idx)
return (1 - alpha)[..., None] * lower_colors + alpha[..., None] * upper_colors
@jax.jit
def rgba_image_from_2d_map__viridis(map_2d: jtyp.Float32[jtyp.Array, 'width height 4']):
amplitude = jnp.abs(map_2d)
amplitude_normalized = (amplitude - amplitude.min()) / (
amplitude.max() - amplitude.min()
)
rgb_array = apply_colormap(amplitude_normalized, VIRIDIS_COLORMAP)
alpha_channel = jnp.ones_like(amplitude_normalized)
return jnp.dstack((rgb_array, alpha_channel))
@jax.jit
def rgba_image_from_2d_map__grayscale(
map_2d: jtyp.Float32[jtyp.Array, 'width height 4'],
):
amplitude = jnp.abs(map_2d)
amplitude_normalized = (amplitude - amplitude.min()) / (
amplitude.max() - amplitude.min()
)
rgb_array = jnp.stack([amplitude_normalized] * 3, axis=-1)
alpha_channel = jnp.ones_like(amplitude_normalized)
return jnp.dstack((rgb_array, alpha_channel))
def rgba_image_from_2d_map(
map_2d: jtyp.Float32[jtyp.Array, 'width height 4'], colormap: str | None = None
):
"""RGBA Image from a map of 2D coordinates to values.
Parameters:
map_2d: The 2D value map.
Returns:
Image as a JAX array of shape (height, width, 4)
"""
if colormap == Colormap.Viridis:
return rgba_image_from_2d_map__viridis(map_2d)
if colormap == Colormap.Grayscale:
return rgba_image_from_2d_map__grayscale(map_2d)
return rgba_image_from_2d_map__grayscale(map_2d)