Compare commits

..

2 Commits

7 changed files with 250 additions and 180 deletions

View File

@ -2,13 +2,10 @@ import io
import typing as typ import typing as typ
import bpy import bpy
import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.axis as mpl_ax import matplotlib.axis as mpl_ax
import numpy as np import numpy as np
from blender_maxwell.utils import logger from blender_maxwell.utils import image_ops, logger
from .. import contracts as ct from .. import contracts as ct
from . import base from . import base
@ -18,64 +15,6 @@ log = logger.get(__name__)
AREA_TYPE = 'IMAGE_EDITOR' AREA_TYPE = 'IMAGE_EDITOR'
SPACE_TYPE = 'IMAGE_EDITOR' SPACE_TYPE = 'IMAGE_EDITOR'
# Colormap
_MPL_CM = matplotlib.cm.get_cmap('viridis', 512)
VIRIDIS_COLORMAP = jnp.array([_MPL_CM(i)[:3] for i in range(512)])
####################
# - Image Functions
####################
def apply_colormap(normalized_data, colormap):
# Linear interpolation between colormap points
n_colors = colormap.shape[0]
indices = normalized_data * (n_colors - 1)
lower_idx = jnp.floor(indices).astype(jnp.int32)
upper_idx = jnp.ceil(indices).astype(jnp.int32)
alpha = indices - lower_idx
lower_colors = jax.vmap(lambda i: colormap[i])(lower_idx)
upper_colors = jax.vmap(lambda i: colormap[i])(upper_idx)
return (1 - alpha)[..., None] * lower_colors + alpha[..., None] * upper_colors
@jax.jit
def rgba_image_from_2d_map__viridis(map_2d):
amplitude = jnp.abs(map_2d)
amplitude_normalized = (amplitude - amplitude.min()) / (
amplitude.max() - amplitude.min()
)
rgb_array = apply_colormap(amplitude_normalized, VIRIDIS_COLORMAP)
alpha_channel = jnp.ones_like(amplitude_normalized)
return jnp.dstack((rgb_array, alpha_channel))
@jax.jit
def rgba_image_from_2d_map__grayscale(map_2d):
amplitude = jnp.abs(map_2d)
amplitude_normalized = (amplitude - amplitude.min()) / (
amplitude.max() - amplitude.min()
)
rgb_array = jnp.stack([amplitude_normalized] * 3, axis=-1)
alpha_channel = jnp.ones_like(amplitude_normalized)
return jnp.dstack((rgb_array, alpha_channel))
def rgba_image_from_2d_map(map_2d, colormap: str | None = None):
"""RGBA Image from a map of 2D coordinates to values.
Parameters:
map_2d: Shape (width, height, value).
Returns:
Image as a JAX array of shape (height, width, 4)
"""
if colormap == 'VIRIDIS':
return rgba_image_from_2d_map__viridis(map_2d)
if colormap == 'GRAYSCALE':
return rgba_image_from_2d_map__grayscale(map_2d)
#################### ####################
# - Managed BL Image # - Managed BL Image
@ -235,7 +174,7 @@ class ManagedBLImage(base.ManagedObj):
self, map_2d, colormap: str | None = 'VIRIDIS', bl_select: bool = False self, map_2d, colormap: str | None = 'VIRIDIS', bl_select: bool = False
): ):
self.data_to_image( self.data_to_image(
lambda _: rgba_image_from_2d_map(map_2d, colormap=colormap), lambda _: image_ops.rgba_image_from_2d_map(map_2d, colormap=colormap),
bl_select=bl_select, bl_select=bl_select,
) )

View File

@ -1,5 +1,6 @@
import typing as typ import typing as typ
import enum
import bpy import bpy
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
@ -15,7 +16,12 @@ log = logger.get(__name__)
class ExtractDataNode(base.MaxwellSimNode): class ExtractDataNode(base.MaxwellSimNode):
"""Node for extracting data from particular objects.""" """Node for extracting data from particular objects.
Attributes:
extract_filter: Identifier for data to extract from the input.
"""
node_type = ct.NodeType.ExtractData node_type = ct.NodeType.ExtractData
bl_label = 'Extract' bl_label = 'Extract'
@ -32,11 +38,10 @@ class ExtractDataNode(base.MaxwellSimNode):
#################### ####################
# - Properties # - Properties
#################### ####################
extract_filter: bpy.props.StringProperty( extract_filter: enum.Enum = bl_cache.BLField(
name='Extract Filter', None,
description='Data to extract from the input', prop_ui=True,
search=lambda self, _, edit_text: self.search_extract_filters(edit_text), enum_cb=lambda self, _: self.search_extract_filters(),
update=lambda self, context: self.on_prop_changed('extract_filter', context),
) )
# Sim Data # Sim Data
@ -49,41 +54,30 @@ class ExtractDataNode(base.MaxwellSimNode):
#################### ####################
# - Computed Properties # - Computed Properties
#################### ####################
@bl_cache.cached_bl_property(persist=False) @property
def has_sim_data(self) -> bool: def has_sim_data(self) -> bool:
return ( return self.active_socket_set == 'Sim Data' and self.sim_data_monitor_nametype
self.active_socket_set == 'Sim Data'
and self.inputs['Sim Data'].is_linked
and self.sim_data_monitor_nametype
)
@bl_cache.cached_bl_property(persist=False) @property
def has_monitor_data(self) -> bool: def has_monitor_data(self) -> bool:
return ( return self.active_socket_set == 'Monitor Data' and self.monitor_data_type
self.active_socket_set == 'Monitor Data'
and self.inputs['Monitor Data'].is_linked
and self.monitor_data_type
)
#################### ####################
# - Extraction Filter Search # - Extraction Filter Search
#################### ####################
def search_extract_filters(self, edit_text: str) -> list[tuple[str, str, str]]: def search_extract_filters(self) -> list[ct.BLEnumElement]:
if self.has_sim_data: if self.has_sim_data:
return [ return [
( (monitor_name, monitor_name, monitor_type.removesuffix('Data'), '', i)
monitor_name, for i, (monitor_name, monitor_type) in enumerate(
monitor_type.removesuffix('Data'), self.sim_data_monitor_nametype.items()
) )
for monitor_name, monitor_type in self.sim_data_monitor_nametype.items()
if edit_text == '' or edit_text.lower() in monitor_name.lower()
] ]
if self.has_monitor_data: if self.has_monitor_data:
return [ return [
(component_name, f' {component_name[1]}-Pol') (component_name, component_name, f' {component_name[1]}-Pol', '', i)
for component_name in self.monitor_data_components for i, component_name in enumerate(self.monitor_data_components)
if (edit_text == '' or edit_text.lower() in component_name.lower())
] ]
return [] return []
@ -92,7 +86,7 @@ class ExtractDataNode(base.MaxwellSimNode):
# - UI # - UI
#################### ####################
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None: def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
col.prop(self, 'extract_filter', text='') col.prop(self, self.blfields['extract_filter'], text='')
def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None: def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
if self.has_sim_data or self.has_monitor_data: if self.has_sim_data or self.has_monitor_data:
@ -108,7 +102,9 @@ class ExtractDataNode(base.MaxwellSimNode):
row = col.row() row = col.row()
box = row.box() box = row.box()
grid = box.grid_flow(row_major=True, columns=2, even_columns=True) grid = box.grid_flow(row_major=True, columns=2, even_columns=True)
for name, desc in self.search_extract_filters(edit_text=''): for name, desc in [
(name, desc) for idname, name, desc, *_ in self.search_extract_filters()
]:
grid.label(text=name) grid.label(text=name)
grid.label(text=desc if desc else '') grid.label(text=desc if desc else '')
@ -120,6 +116,7 @@ class ExtractDataNode(base.MaxwellSimNode):
prop_name='active_socket_set', prop_name='active_socket_set',
input_sockets={'Sim Data', 'Monitor Data'}, input_sockets={'Sim Data', 'Monitor Data'},
input_sockets_optional={'Sim Data': True, 'Monitor Data': True}, input_sockets_optional={'Sim Data': True, 'Monitor Data': True},
run_on_init=True,
) )
def on_sim_data_changed(self, input_sockets: dict): def on_sim_data_changed(self, input_sockets: dict):
if input_sockets['Sim Data'] is not None: if input_sockets['Sim Data'] is not None:
@ -130,6 +127,8 @@ class ExtractDataNode(base.MaxwellSimNode):
'Sim Data' 'Sim Data'
].monitor_data.items() ].monitor_data.items()
} }
elif self.sim_data_monitor_nametype:
self.sim_data_monitor_nametype = {}
if input_sockets['Monitor Data'] is not None: if input_sockets['Monitor Data'] is not None:
# Monitor Data Type # Monitor Data Type
@ -168,18 +167,14 @@ class ExtractDataNode(base.MaxwellSimNode):
'Htheta', 'Htheta',
'Hphi', 'Hphi',
] ]
else:
if self.monitor_data_type:
self.monitor_data_type = ''
if self.monitor_data_components:
self.monitor_data_components = []
# Invalidate Computed Property Caches # Invalidate Computed Property Caches
self.has_sim_data = bl_cache.Signal.InvalidateCache self.extract_filter = bl_cache.Signal.ResetEnumItems
self.has_monitor_data = bl_cache.Signal.InvalidateCache
# Reset Extraction Filter
## The extraction filter that was set before may not be valid anymore.
## If so, simply remove it.
if self.extract_filter not in [
el[0] for el in self.search_extract_filters(edit_text='')
]:
self.extract_filter = ''
#################### ####################
# - Output: Value # - Output: Value
@ -225,16 +220,22 @@ class ExtractDataNode(base.MaxwellSimNode):
kind=ct.FlowKind.Info, kind=ct.FlowKind.Info,
props={'monitor_data_type', 'extract_filter'}, props={'monitor_data_type', 'extract_filter'},
input_sockets={'Monitor Data'}, input_sockets={'Monitor Data'},
input_socket_kinds={'Monitor Data': ct.FlowKind.Value},
) )
def compute_extracted_data_info( def compute_extracted_data_info(
self, props: dict, input_sockets: dict self, props: dict, input_sockets: dict
) -> ct.InfoFlow: # noqa: PLR0911 ) -> ct.InfoFlow:
if input_sockets['Monitor Data'] is None or not props['extract_filter']: # Retrieve XArray
if (
input_sockets['Monitor Data'] is not None
and props['extract_filter'] != 'NONE'
):
xarr = getattr(input_sockets['Monitor Data'], props['extract_filter'])
else:
return ct.InfoFlow() return ct.InfoFlow()
xarr = getattr(input_sockets['Monitor Data'], props['extract_filter']) # Compute InfoFlow from XArray
## XYZF: Field / Permittivity / FieldProjectionCartesian
# XYZF: Field / Permittivity / FieldProjectionCartesian
if props['monitor_data_type'] in { if props['monitor_data_type'] in {
'Field', 'Field',
'Permittivity', 'Permittivity',
@ -257,7 +258,7 @@ class ExtractDataNode(base.MaxwellSimNode):
}, },
) )
# XYZT: FieldTime ## XYZT: FieldTime
if props['monitor_data_type'] == 'FieldTime': if props['monitor_data_type'] == 'FieldTime':
return ct.InfoFlow( return ct.InfoFlow(
dim_names=['x', 'y', 'z', 't'], dim_names=['x', 'y', 'z', 't'],
@ -276,7 +277,7 @@ class ExtractDataNode(base.MaxwellSimNode):
}, },
) )
# F: Flux ## F: Flux
if props['monitor_data_type'] == 'Flux': if props['monitor_data_type'] == 'Flux':
return ct.InfoFlow( return ct.InfoFlow(
dim_names=['f'], dim_names=['f'],
@ -289,7 +290,7 @@ class ExtractDataNode(base.MaxwellSimNode):
}, },
) )
# T: FluxTime ## T: FluxTime
if props['monitor_data_type'] == 'FluxTime': if props['monitor_data_type'] == 'FluxTime':
return ct.InfoFlow( return ct.InfoFlow(
dim_names=['t'], dim_names=['t'],
@ -302,7 +303,7 @@ class ExtractDataNode(base.MaxwellSimNode):
}, },
) )
# RThetaPhiF: FieldProjectionAngle ## RThetaPhiF: FieldProjectionAngle
if props['monitor_data_type'] == 'FieldProjectionAngle': if props['monitor_data_type'] == 'FieldProjectionAngle':
return ct.InfoFlow( return ct.InfoFlow(
dim_names=['r', 'theta', 'phi', 'f'], dim_names=['r', 'theta', 'phi', 'f'],
@ -328,7 +329,7 @@ class ExtractDataNode(base.MaxwellSimNode):
}, },
) )
# UxUyRF: FieldProjectionKSpace ## UxUyRF: FieldProjectionKSpace
if props['monitor_data_type'] == 'FieldProjectionKSpace': if props['monitor_data_type'] == 'FieldProjectionKSpace':
return ct.InfoFlow( return ct.InfoFlow(
dim_names=['ux', 'uy', 'r', 'f'], dim_names=['ux', 'uy', 'r', 'f'],
@ -352,7 +353,7 @@ class ExtractDataNode(base.MaxwellSimNode):
}, },
) )
# OrderxOrderyF: Diffraction ## OrderxOrderyF: Diffraction
if props['monitor_data_type'] == 'Diffraction': if props['monitor_data_type'] == 'Diffraction':
return ct.InfoFlow( return ct.InfoFlow(
dim_names=['orders_x', 'orders_y', 'f'], dim_names=['orders_x', 'orders_y', 'f'],

View File

@ -43,8 +43,8 @@ class FilterMathNode(base.MaxwellSimNode):
prop_ui=True, enum_cb=lambda self, _: self.search_operations() prop_ui=True, enum_cb=lambda self, _: self.search_operations()
) )
dim: str = bl_cache.BLField( dim: enum.Enum = bl_cache.BLField(
'', prop_ui=True, str_cb=lambda self, _, edit_text: self.search_dims(edit_text) None, prop_ui=True, enum_cb=lambda self, _: self.search_dims()
) )
dim_names: list[str] = bl_cache.BLField([]) dim_names: list[str] = bl_cache.BLField([])
@ -64,22 +64,24 @@ class FilterMathNode(base.MaxwellSimNode):
('FIX', 'del a | i≈v', 'Fix Coordinate'), ('FIX', 'del a | i≈v', 'Fix Coordinate'),
] ]
return items return [(*item, '', i) for i, item in enumerate(items)]
#################### ####################
# - Dim Search # - Dim Search
#################### ####################
def search_dims(self, edit_text: str) -> list[tuple[str, str, str]]: def search_dims(self) -> list[ct.BLEnumElement]:
if self.dim_names: if self.dim_names:
dims = [ dims = [
(dim_name, dim_name) (dim_name, dim_name, dim_name, '', i)
for dim_name in self.dim_names for i, dim_name in enumerate(self.dim_names)
if edit_text == '' or edit_text.lower() in dim_name.lower()
] ]
# Squeeze: Dimension Must Have Length=1 # Squeeze: Dimension Must Have Length=1
## We must also correct the "NUMBER" of the enum.
if self.operation == 'SQUEEZE': if self.operation == 'SQUEEZE':
return [dim for dim in dims if self.dim_lens[dim[0]] == 1] filtered_dims = [dim for dim in dims if self.dim_lens[dim[0]] == 1]
return [(*dim[:-1], i) for i, dim in enumerate(filtered_dims)]
return dims return dims
return [] return []
@ -107,6 +109,7 @@ class FilterMathNode(base.MaxwellSimNode):
input_sockets={'Data'}, input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Info}, input_socket_kinds={'Data': ct.FlowKind.Info},
input_sockets_optional={'Data': True}, input_sockets_optional={'Data': True},
run_on_init=True,
) )
def on_any_change(self, props: dict, input_sockets: dict): def on_any_change(self, props: dict, input_sockets: dict):
# Set Dimension Names from InfoFlow # Set Dimension Names from InfoFlow
@ -120,8 +123,8 @@ class FilterMathNode(base.MaxwellSimNode):
self.dim_names = [] self.dim_names = []
self.dim_lens = {} self.dim_lens = {}
# Reset String Searcher # Reset Enum
self.dim = bl_cache.Signal.ResetStrSearch self.dim = bl_cache.Signal.ResetEnumItems
@events.on_value_changed( @events.on_value_changed(
prop_name='dim', prop_name='dim',
@ -132,10 +135,7 @@ class FilterMathNode(base.MaxwellSimNode):
) )
def on_dim_change(self, props: dict, input_sockets: dict): def on_dim_change(self, props: dict, input_sockets: dict):
# Add/Remove Input Socket "Value" # Add/Remove Input Socket "Value"
if ( if props['active_socket_set'] == 'By Dim Value' and props['dim'] != 'NONE':
props['active_socket_set'] == 'By Dim Value'
and props['dim'] in input_sockets['Data'].dim_names
):
# Get Current and Wanted Socket Defs # Get Current and Wanted Socket Defs
current_socket_def = self.loose_input_sockets.get('Value') current_socket_def = self.loose_input_sockets.get('Value')
wanted_socket_def = sockets.SOCKET_DEFS[ wanted_socket_def = sockets.SOCKET_DEFS[
@ -167,7 +167,7 @@ class FilterMathNode(base.MaxwellSimNode):
# Compute Bound/Free Parameters # Compute Bound/Free Parameters
func_args = [int] if props['active_socket_set'] == 'By Dim Value' else [] func_args = [int] if props['active_socket_set'] == 'By Dim Value' else []
if props['dim']: if props['dim'] != 'NONE':
axis = info.dim_names.index(props['dim']) axis = info.dim_names.index(props['dim'])
else: else:
msg = 'Dimension cannot be empty' msg = 'Dimension cannot be empty'
@ -225,7 +225,7 @@ class FilterMathNode(base.MaxwellSimNode):
# Compute Bound/Free Parameters # Compute Bound/Free Parameters
## Empty Dimension -> Empty InfoFlow ## Empty Dimension -> Empty InfoFlow
if props['dim']: if props['dim'] != 'NONE':
axis = info.dim_names.index(props['dim']) axis = info.dim_names.index(props['dim'])
else: else:
return ct.InfoFlow() return ct.InfoFlow()
@ -272,7 +272,7 @@ class FilterMathNode(base.MaxwellSimNode):
in [ in [
('By Dim Value', 'FIX'), ('By Dim Value', 'FIX'),
] ]
and props['dim'] and props['dim'] != 'NONE'
and input_sockets['Value'] is not None and input_sockets['Value'] is not None
): ):
# Compute IDX Corresponding to Coordinate Value # Compute IDX Corresponding to Coordinate Value

View File

@ -1,8 +1,9 @@
import enum
import typing as typ import typing as typ
import bpy import bpy
from blender_maxwell.utils import logger from blender_maxwell.utils import bl_cache, image_ops, logger
from ... import contracts as ct from ... import contracts as ct
from ... import managed_objs, sockets from ... import managed_objs, sockets
@ -12,7 +13,12 @@ log = logger.get(__name__)
class VizNode(base.MaxwellSimNode): class VizNode(base.MaxwellSimNode):
"""Node for visualizing simulation data, by querying its monitors.""" """Node for visualizing simulation data, by querying its monitors.
Attributes:
colormap: Colormap to apply to 0..1 output.
"""
node_type = ct.NodeType.Viz node_type = ct.NodeType.Viz
bl_label = 'Viz' bl_label = 'Viz'
@ -34,31 +40,25 @@ class VizNode(base.MaxwellSimNode):
##################### #####################
## - Properties ## - Properties
##################### #####################
colormap: bpy.props.EnumProperty( colormap: image_ops.Colormap = bl_cache.BLField(
name='Colormap', image_ops.Colormap.Viridis, prop_ui=True
description='Colormap to apply to grayscale output',
items=[
('VIRIDIS', 'Viridis', 'Good default colormap'),
('GRAYSCALE', 'Grayscale', 'Barebones'),
],
default='VIRIDIS',
update=lambda self, context: self.on_prop_changed('colormap', context),
) )
##################### #####################
## - UI ## - UI
##################### #####################
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout): def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
col.prop(self, 'colormap') col.prop(self, self.blfields['colormap'], text='')
##################### #####################
## - Plotting ## - Plotting
##################### #####################
@events.on_show_plot( @events.on_show_plot(
managed_objs={'plot'}, managed_objs={'plot'},
props={'colormap'},
input_sockets={'Data'}, input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Array}, input_socket_kinds={'Data': ct.FlowKind.Array},
props={'colormap'}, input_sockets_optional={'Data': True},
stop_propagation=True, stop_propagation=True,
) )
def on_show_plot( def on_show_plot(
@ -67,6 +67,7 @@ class VizNode(base.MaxwellSimNode):
input_sockets: dict, input_sockets: dict,
props: dict, props: dict,
): ):
if input_sockets['Data'] is not None:
managed_objs['plot'].map_2d_to_image( managed_objs['plot'].map_2d_to_image(
input_sockets['Data'].values, input_sockets['Data'].values,
colormap=props['colormap'], colormap=props['colormap'],

View File

@ -1034,6 +1034,16 @@ class MaxwellSimNode(bpy.types.Node):
## Blender will automatically add .001 so that `self.name` is unique. ## Blender will automatically add .001 so that `self.name` is unique.
self.sim_node_name = self.name self.sim_node_name = self.name
# Event Methods
## Run any 'DataChanged' methods with 'run_on_init' set.
## -> Copying a node _arguably_ re-initializes the new node.
for event_method in [
event_method
for event_method in self.event_methods_by_event[ct.FlowEvent.DataChanged]
if event_method.callback_info.run_on_init
]:
event_method(self)
def free(self) -> None: def free(self) -> None:
"""Cleans various instance-associated data up, so the node can be cleanly deleted. """Cleans various instance-associated data up, so the node can be cleanly deleted.

View File

@ -740,7 +740,7 @@ class BLField:
str(value), str(value),
AttrType.to_name(value), AttrType.to_name(value),
AttrType.to_name(value), ## TODO: From AttrType.__doc__ AttrType.to_name(value), ## TODO: From AttrType.__doc__
AttrType.to_icon(), AttrType.to_icon(value),
i if not self._enum_many else 2**i, i if not self._enum_many else 2**i,
) )
for i, value in enumerate(list(AttrType)) for i, value in enumerate(list(AttrType))
@ -824,6 +824,12 @@ class BLField:
def __set__(self, bl_instance: BLInstance | None, value: typ.Any) -> None: def __set__(self, bl_instance: BLInstance | None, value: typ.Any) -> None:
if value == Signal.ResetEnumItems: if value == Signal.ResetEnumItems:
old_items = self._safe_enum_cb(bl_instance, None)
current_items = self._enum_cb(bl_instance, None)
# Only Change if Changes Need Making
## <Proverb. 'Fun things to say in jail'. Ca 887BCE. >
if old_items != current_items:
# Set Enum to First Item # Set Enum to First Item
## Prevents the seemingly "missing" enum element bug. ## Prevents the seemingly "missing" enum element bug.
## -> Caused by the old int still trying to hang on after. ## -> Caused by the old int still trying to hang on after.
@ -831,7 +837,7 @@ class BLField:
## -> Infinite recursion if we don't check current value. ## -> Infinite recursion if we don't check current value.
## -> May cause a hiccup (chains will trigger twice) ## -> May cause a hiccup (chains will trigger twice)
## To work, there **must** be a guaranteed-available string at 0,0. ## To work, there **must** be a guaranteed-available string at 0,0.
first_old_value = self._safe_enum_cb(bl_instance, None)[0][0] first_old_value = old_items(bl_instance, None)[0][0]
current_value = self._cached_bl_property.__get__( current_value = self._cached_bl_property.__get__(
bl_instance, bl_instance.__class__ bl_instance, bl_instance.__class__
) )
@ -847,6 +853,11 @@ class BLField:
self._cached_bl_property.__set__(bl_instance, Signal.InvalidateCache) self._cached_bl_property.__set__(bl_instance, Signal.InvalidateCache)
elif value == Signal.ResetStrSearch: elif value == Signal.ResetStrSearch:
old_items = self._safe_str_cb(bl_instance, None)
current_items = self._str_cb(bl_instance, None)
# Only Change if Changes Need Making
if old_items != current_items:
# Set String to '' # Set String to ''
## Prevents the presence of an invalid value not in the new search. ## Prevents the presence of an invalid value not in the new search.
## -> Infinite recursion if we don't check current value for ''. ## -> Infinite recursion if we don't check current value for ''.

View File

@ -0,0 +1,108 @@
"""Useful image processing operations for use in the addon."""
import enum
import typing as typ
import jax
import jax.numpy as jnp
import jaxtyping as jtyp
import matplotlib
from blender_maxwell import contracts as ct
from blender_maxwell.utils import logger
log = logger.get(__name__)
####################
# - Constants
####################
_MPL_CM = matplotlib.cm.get_cmap('viridis', 512)
VIRIDIS_COLORMAP: jtyp.Float32[jtyp.Array, '512 3'] = jnp.array(
[_MPL_CM(i)[:3] for i in range(512)]
)
class Colormap(enum.StrEnum):
"""Available colormaps.
Attributes:
Viridis: Good general-purpose colormap.
Grayscale: Simple black and white mapping.
"""
Viridis = enum.auto()
Grayscale = enum.auto()
@staticmethod
def to_name(value: typ.Self) -> str:
return {
Colormap.Viridis: 'Viridis',
Colormap.Grayscale: 'Grayscale',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> ct.BLIcon:
return ''
####################
# - Colormap: (X,Y,1 -> Value) -> (X,Y,4 -> Value)
####################
def apply_colormap(
normalized_data: jtyp.Float32[jtyp.Array, 'width height 4'],
colormap: jtyp.Float32[jtyp.Array, '512 3'],
):
# Linear interpolation between colormap points
n_colors = colormap.shape[0]
indices = normalized_data * (n_colors - 1)
lower_idx = jnp.floor(indices).astype(jnp.int32)
upper_idx = jnp.ceil(indices).astype(jnp.int32)
alpha = indices - lower_idx
lower_colors = jax.vmap(lambda i: colormap[i])(lower_idx)
upper_colors = jax.vmap(lambda i: colormap[i])(upper_idx)
return (1 - alpha)[..., None] * lower_colors + alpha[..., None] * upper_colors
@jax.jit
def rgba_image_from_2d_map__viridis(map_2d: jtyp.Float32[jtyp.Array, 'width height 4']):
amplitude = jnp.abs(map_2d)
amplitude_normalized = (amplitude - amplitude.min()) / (
amplitude.max() - amplitude.min()
)
rgb_array = apply_colormap(amplitude_normalized, VIRIDIS_COLORMAP)
alpha_channel = jnp.ones_like(amplitude_normalized)
return jnp.dstack((rgb_array, alpha_channel))
@jax.jit
def rgba_image_from_2d_map__grayscale(
map_2d: jtyp.Float32[jtyp.Array, 'width height 4'],
):
amplitude = jnp.abs(map_2d)
amplitude_normalized = (amplitude - amplitude.min()) / (
amplitude.max() - amplitude.min()
)
rgb_array = jnp.stack([amplitude_normalized] * 3, axis=-1)
alpha_channel = jnp.ones_like(amplitude_normalized)
return jnp.dstack((rgb_array, alpha_channel))
def rgba_image_from_2d_map(
map_2d: jtyp.Float32[jtyp.Array, 'width height 4'], colormap: str | None = None
):
"""RGBA Image from a map of 2D coordinates to values.
Parameters:
map_2d: The 2D value map.
Returns:
Image as a JAX array of shape (height, width, 4)
"""
if colormap == Colormap.Viridis:
return rgba_image_from_2d_map__viridis(map_2d)
if colormap == Colormap.Grayscale:
return rgba_image_from_2d_map__grayscale(map_2d)
return rgba_image_from_2d_map__grayscale(map_2d)