fix: Crashes on enum changes
parent
7fa6e3a3ec
commit
e7d3ecf48e
|
@ -2,13 +2,10 @@ import io
|
|||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib
|
||||
import matplotlib.axis as mpl_ax
|
||||
import numpy as np
|
||||
|
||||
from blender_maxwell.utils import logger
|
||||
from blender_maxwell.utils import image_ops, logger
|
||||
|
||||
from .. import contracts as ct
|
||||
from . import base
|
||||
|
@ -18,64 +15,6 @@ log = logger.get(__name__)
|
|||
AREA_TYPE = 'IMAGE_EDITOR'
|
||||
SPACE_TYPE = 'IMAGE_EDITOR'
|
||||
|
||||
# Colormap
|
||||
_MPL_CM = matplotlib.cm.get_cmap('viridis', 512)
|
||||
VIRIDIS_COLORMAP = jnp.array([_MPL_CM(i)[:3] for i in range(512)])
|
||||
|
||||
|
||||
####################
|
||||
# - Image Functions
|
||||
####################
|
||||
def apply_colormap(normalized_data, colormap):
|
||||
# Linear interpolation between colormap points
|
||||
n_colors = colormap.shape[0]
|
||||
indices = normalized_data * (n_colors - 1)
|
||||
lower_idx = jnp.floor(indices).astype(jnp.int32)
|
||||
upper_idx = jnp.ceil(indices).astype(jnp.int32)
|
||||
alpha = indices - lower_idx
|
||||
|
||||
lower_colors = jax.vmap(lambda i: colormap[i])(lower_idx)
|
||||
upper_colors = jax.vmap(lambda i: colormap[i])(upper_idx)
|
||||
|
||||
return (1 - alpha)[..., None] * lower_colors + alpha[..., None] * upper_colors
|
||||
|
||||
|
||||
@jax.jit
|
||||
def rgba_image_from_2d_map__viridis(map_2d):
|
||||
amplitude = jnp.abs(map_2d)
|
||||
amplitude_normalized = (amplitude - amplitude.min()) / (
|
||||
amplitude.max() - amplitude.min()
|
||||
)
|
||||
rgb_array = apply_colormap(amplitude_normalized, VIRIDIS_COLORMAP)
|
||||
alpha_channel = jnp.ones_like(amplitude_normalized)
|
||||
return jnp.dstack((rgb_array, alpha_channel))
|
||||
|
||||
|
||||
@jax.jit
|
||||
def rgba_image_from_2d_map__grayscale(map_2d):
|
||||
amplitude = jnp.abs(map_2d)
|
||||
amplitude_normalized = (amplitude - amplitude.min()) / (
|
||||
amplitude.max() - amplitude.min()
|
||||
)
|
||||
rgb_array = jnp.stack([amplitude_normalized] * 3, axis=-1)
|
||||
alpha_channel = jnp.ones_like(amplitude_normalized)
|
||||
return jnp.dstack((rgb_array, alpha_channel))
|
||||
|
||||
|
||||
def rgba_image_from_2d_map(map_2d, colormap: str | None = None):
|
||||
"""RGBA Image from a map of 2D coordinates to values.
|
||||
|
||||
Parameters:
|
||||
map_2d: Shape (width, height, value).
|
||||
|
||||
Returns:
|
||||
Image as a JAX array of shape (height, width, 4)
|
||||
"""
|
||||
if colormap == 'VIRIDIS':
|
||||
return rgba_image_from_2d_map__viridis(map_2d)
|
||||
if colormap == 'GRAYSCALE':
|
||||
return rgba_image_from_2d_map__grayscale(map_2d)
|
||||
|
||||
|
||||
####################
|
||||
# - Managed BL Image
|
||||
|
@ -235,7 +174,7 @@ class ManagedBLImage(base.ManagedObj):
|
|||
self, map_2d, colormap: str | None = 'VIRIDIS', bl_select: bool = False
|
||||
):
|
||||
self.data_to_image(
|
||||
lambda _: rgba_image_from_2d_map(map_2d, colormap=colormap),
|
||||
lambda _: image_ops.rgba_image_from_2d_map(map_2d, colormap=colormap),
|
||||
bl_select=bl_select,
|
||||
)
|
||||
|
||||
|
|
|
@ -123,7 +123,7 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
self.dim_names = []
|
||||
self.dim_lens = {}
|
||||
|
||||
# Reset String Searcher
|
||||
# Reset Enum
|
||||
self.dim = bl_cache.Signal.ResetEnumItems
|
||||
|
||||
@events.on_value_changed(
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import enum
|
||||
import typing as typ
|
||||
|
||||
import bpy
|
||||
|
||||
from blender_maxwell.utils import logger
|
||||
from blender_maxwell.utils import bl_cache, image_ops, logger
|
||||
|
||||
from ... import contracts as ct
|
||||
from ... import managed_objs, sockets
|
||||
|
@ -12,7 +13,12 @@ log = logger.get(__name__)
|
|||
|
||||
|
||||
class VizNode(base.MaxwellSimNode):
|
||||
"""Node for visualizing simulation data, by querying its monitors."""
|
||||
"""Node for visualizing simulation data, by querying its monitors.
|
||||
|
||||
Attributes:
|
||||
colormap: Colormap to apply to 0..1 output.
|
||||
|
||||
"""
|
||||
|
||||
node_type = ct.NodeType.Viz
|
||||
bl_label = 'Viz'
|
||||
|
@ -34,31 +40,25 @@ class VizNode(base.MaxwellSimNode):
|
|||
#####################
|
||||
## - Properties
|
||||
#####################
|
||||
colormap: bpy.props.EnumProperty(
|
||||
name='Colormap',
|
||||
description='Colormap to apply to grayscale output',
|
||||
items=[
|
||||
('VIRIDIS', 'Viridis', 'Good default colormap'),
|
||||
('GRAYSCALE', 'Grayscale', 'Barebones'),
|
||||
],
|
||||
default='VIRIDIS',
|
||||
update=lambda self, context: self.on_prop_changed('colormap', context),
|
||||
colormap: image_ops.Colormap = bl_cache.BLField(
|
||||
image_ops.Colormap.Viridis, prop_ui=True
|
||||
)
|
||||
|
||||
#####################
|
||||
## - UI
|
||||
#####################
|
||||
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
|
||||
col.prop(self, 'colormap')
|
||||
col.prop(self, self.blfields['colormap'], text='')
|
||||
|
||||
#####################
|
||||
## - Plotting
|
||||
#####################
|
||||
@events.on_show_plot(
|
||||
managed_objs={'plot'},
|
||||
props={'colormap'},
|
||||
input_sockets={'Data'},
|
||||
input_socket_kinds={'Data': ct.FlowKind.Array},
|
||||
props={'colormap'},
|
||||
input_sockets_optional={'Data': True},
|
||||
stop_propagation=True,
|
||||
)
|
||||
def on_show_plot(
|
||||
|
@ -67,6 +67,7 @@ class VizNode(base.MaxwellSimNode):
|
|||
input_sockets: dict,
|
||||
props: dict,
|
||||
):
|
||||
if input_sockets['Data'] is not None:
|
||||
managed_objs['plot'].map_2d_to_image(
|
||||
input_sockets['Data'].values,
|
||||
colormap=props['colormap'],
|
||||
|
|
|
@ -740,7 +740,7 @@ class BLField:
|
|||
str(value),
|
||||
AttrType.to_name(value),
|
||||
AttrType.to_name(value), ## TODO: From AttrType.__doc__
|
||||
AttrType.to_icon(),
|
||||
AttrType.to_icon(value),
|
||||
i if not self._enum_many else 2**i,
|
||||
)
|
||||
for i, value in enumerate(list(AttrType))
|
||||
|
@ -824,6 +824,12 @@ class BLField:
|
|||
|
||||
def __set__(self, bl_instance: BLInstance | None, value: typ.Any) -> None:
|
||||
if value == Signal.ResetEnumItems:
|
||||
old_items = self._safe_enum_cb(bl_instance, None)
|
||||
current_items = self._enum_cb(bl_instance, None)
|
||||
|
||||
# Only Change if Changes Need Making
|
||||
## <Proverb. 'Fun things to say in jail'. Ca 887BCE. >
|
||||
if old_items != current_items:
|
||||
# Set Enum to First Item
|
||||
## Prevents the seemingly "missing" enum element bug.
|
||||
## -> Caused by the old int still trying to hang on after.
|
||||
|
@ -831,7 +837,7 @@ class BLField:
|
|||
## -> Infinite recursion if we don't check current value.
|
||||
## -> May cause a hiccup (chains will trigger twice)
|
||||
## To work, there **must** be a guaranteed-available string at 0,0.
|
||||
first_old_value = self._safe_enum_cb(bl_instance, None)[0][0]
|
||||
first_old_value = old_items(bl_instance, None)[0][0]
|
||||
current_value = self._cached_bl_property.__get__(
|
||||
bl_instance, bl_instance.__class__
|
||||
)
|
||||
|
@ -847,6 +853,11 @@ class BLField:
|
|||
self._cached_bl_property.__set__(bl_instance, Signal.InvalidateCache)
|
||||
|
||||
elif value == Signal.ResetStrSearch:
|
||||
old_items = self._safe_str_cb(bl_instance, None)
|
||||
current_items = self._str_cb(bl_instance, None)
|
||||
|
||||
# Only Change if Changes Need Making
|
||||
if old_items != current_items:
|
||||
# Set String to ''
|
||||
## Prevents the presence of an invalid value not in the new search.
|
||||
## -> Infinite recursion if we don't check current value for ''.
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
"""Useful image processing operations for use in the addon."""
|
||||
|
||||
import enum
|
||||
import typing as typ
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jaxtyping as jtyp
|
||||
import matplotlib
|
||||
|
||||
from blender_maxwell import contracts as ct
|
||||
from blender_maxwell.utils import logger
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
####################
|
||||
# - Constants
|
||||
####################
|
||||
_MPL_CM = matplotlib.cm.get_cmap('viridis', 512)
|
||||
VIRIDIS_COLORMAP: jtyp.Float32[jtyp.Array, '512 3'] = jnp.array(
|
||||
[_MPL_CM(i)[:3] for i in range(512)]
|
||||
)
|
||||
|
||||
|
||||
class Colormap(enum.StrEnum):
|
||||
"""Available colormaps.
|
||||
|
||||
Attributes:
|
||||
Viridis: Good general-purpose colormap.
|
||||
Grayscale: Simple black and white mapping.
|
||||
"""
|
||||
|
||||
Viridis = enum.auto()
|
||||
Grayscale = enum.auto()
|
||||
|
||||
@staticmethod
|
||||
def to_name(value: typ.Self) -> str:
|
||||
return {
|
||||
Colormap.Viridis: 'Viridis',
|
||||
Colormap.Grayscale: 'Grayscale',
|
||||
}[value]
|
||||
|
||||
@staticmethod
|
||||
def to_icon(value: typ.Self) -> ct.BLIcon:
|
||||
return ''
|
||||
|
||||
|
||||
####################
|
||||
# - Colormap: (X,Y,1 -> Value) -> (X,Y,4 -> Value)
|
||||
####################
|
||||
def apply_colormap(
|
||||
normalized_data: jtyp.Float32[jtyp.Array, 'width height 4'],
|
||||
colormap: jtyp.Float32[jtyp.Array, '512 3'],
|
||||
):
|
||||
# Linear interpolation between colormap points
|
||||
n_colors = colormap.shape[0]
|
||||
indices = normalized_data * (n_colors - 1)
|
||||
lower_idx = jnp.floor(indices).astype(jnp.int32)
|
||||
upper_idx = jnp.ceil(indices).astype(jnp.int32)
|
||||
alpha = indices - lower_idx
|
||||
|
||||
lower_colors = jax.vmap(lambda i: colormap[i])(lower_idx)
|
||||
upper_colors = jax.vmap(lambda i: colormap[i])(upper_idx)
|
||||
|
||||
return (1 - alpha)[..., None] * lower_colors + alpha[..., None] * upper_colors
|
||||
|
||||
|
||||
@jax.jit
|
||||
def rgba_image_from_2d_map__viridis(map_2d: jtyp.Float32[jtyp.Array, 'width height 4']):
|
||||
amplitude = jnp.abs(map_2d)
|
||||
amplitude_normalized = (amplitude - amplitude.min()) / (
|
||||
amplitude.max() - amplitude.min()
|
||||
)
|
||||
rgb_array = apply_colormap(amplitude_normalized, VIRIDIS_COLORMAP)
|
||||
alpha_channel = jnp.ones_like(amplitude_normalized)
|
||||
return jnp.dstack((rgb_array, alpha_channel))
|
||||
|
||||
|
||||
@jax.jit
|
||||
def rgba_image_from_2d_map__grayscale(
|
||||
map_2d: jtyp.Float32[jtyp.Array, 'width height 4'],
|
||||
):
|
||||
amplitude = jnp.abs(map_2d)
|
||||
amplitude_normalized = (amplitude - amplitude.min()) / (
|
||||
amplitude.max() - amplitude.min()
|
||||
)
|
||||
rgb_array = jnp.stack([amplitude_normalized] * 3, axis=-1)
|
||||
alpha_channel = jnp.ones_like(amplitude_normalized)
|
||||
return jnp.dstack((rgb_array, alpha_channel))
|
||||
|
||||
|
||||
def rgba_image_from_2d_map(
|
||||
map_2d: jtyp.Float32[jtyp.Array, 'width height 4'], colormap: str | None = None
|
||||
):
|
||||
"""RGBA Image from a map of 2D coordinates to values.
|
||||
|
||||
Parameters:
|
||||
map_2d: The 2D value map.
|
||||
|
||||
Returns:
|
||||
Image as a JAX array of shape (height, width, 4)
|
||||
"""
|
||||
if colormap == Colormap.Viridis:
|
||||
return rgba_image_from_2d_map__viridis(map_2d)
|
||||
if colormap == Colormap.Grayscale:
|
||||
return rgba_image_from_2d_map__grayscale(map_2d)
|
||||
|
||||
return rgba_image_from_2d_map__grayscale(map_2d)
|
Loading…
Reference in New Issue