Compare commits
2 Commits
f09b58e0e7
...
e7d3ecf48e
Author | SHA1 | Date |
---|---|---|
Sofus Albert Høgsbro Rose | e7d3ecf48e | |
Sofus Albert Høgsbro Rose | 7fa6e3a3ec |
|
@ -2,13 +2,10 @@ import io
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
import bpy
|
import bpy
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import matplotlib
|
|
||||||
import matplotlib.axis as mpl_ax
|
import matplotlib.axis as mpl_ax
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from blender_maxwell.utils import logger
|
from blender_maxwell.utils import image_ops, logger
|
||||||
|
|
||||||
from .. import contracts as ct
|
from .. import contracts as ct
|
||||||
from . import base
|
from . import base
|
||||||
|
@ -18,64 +15,6 @@ log = logger.get(__name__)
|
||||||
AREA_TYPE = 'IMAGE_EDITOR'
|
AREA_TYPE = 'IMAGE_EDITOR'
|
||||||
SPACE_TYPE = 'IMAGE_EDITOR'
|
SPACE_TYPE = 'IMAGE_EDITOR'
|
||||||
|
|
||||||
# Colormap
|
|
||||||
_MPL_CM = matplotlib.cm.get_cmap('viridis', 512)
|
|
||||||
VIRIDIS_COLORMAP = jnp.array([_MPL_CM(i)[:3] for i in range(512)])
|
|
||||||
|
|
||||||
|
|
||||||
####################
|
|
||||||
# - Image Functions
|
|
||||||
####################
|
|
||||||
def apply_colormap(normalized_data, colormap):
|
|
||||||
# Linear interpolation between colormap points
|
|
||||||
n_colors = colormap.shape[0]
|
|
||||||
indices = normalized_data * (n_colors - 1)
|
|
||||||
lower_idx = jnp.floor(indices).astype(jnp.int32)
|
|
||||||
upper_idx = jnp.ceil(indices).astype(jnp.int32)
|
|
||||||
alpha = indices - lower_idx
|
|
||||||
|
|
||||||
lower_colors = jax.vmap(lambda i: colormap[i])(lower_idx)
|
|
||||||
upper_colors = jax.vmap(lambda i: colormap[i])(upper_idx)
|
|
||||||
|
|
||||||
return (1 - alpha)[..., None] * lower_colors + alpha[..., None] * upper_colors
|
|
||||||
|
|
||||||
|
|
||||||
@jax.jit
|
|
||||||
def rgba_image_from_2d_map__viridis(map_2d):
|
|
||||||
amplitude = jnp.abs(map_2d)
|
|
||||||
amplitude_normalized = (amplitude - amplitude.min()) / (
|
|
||||||
amplitude.max() - amplitude.min()
|
|
||||||
)
|
|
||||||
rgb_array = apply_colormap(amplitude_normalized, VIRIDIS_COLORMAP)
|
|
||||||
alpha_channel = jnp.ones_like(amplitude_normalized)
|
|
||||||
return jnp.dstack((rgb_array, alpha_channel))
|
|
||||||
|
|
||||||
|
|
||||||
@jax.jit
|
|
||||||
def rgba_image_from_2d_map__grayscale(map_2d):
|
|
||||||
amplitude = jnp.abs(map_2d)
|
|
||||||
amplitude_normalized = (amplitude - amplitude.min()) / (
|
|
||||||
amplitude.max() - amplitude.min()
|
|
||||||
)
|
|
||||||
rgb_array = jnp.stack([amplitude_normalized] * 3, axis=-1)
|
|
||||||
alpha_channel = jnp.ones_like(amplitude_normalized)
|
|
||||||
return jnp.dstack((rgb_array, alpha_channel))
|
|
||||||
|
|
||||||
|
|
||||||
def rgba_image_from_2d_map(map_2d, colormap: str | None = None):
|
|
||||||
"""RGBA Image from a map of 2D coordinates to values.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
map_2d: Shape (width, height, value).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image as a JAX array of shape (height, width, 4)
|
|
||||||
"""
|
|
||||||
if colormap == 'VIRIDIS':
|
|
||||||
return rgba_image_from_2d_map__viridis(map_2d)
|
|
||||||
if colormap == 'GRAYSCALE':
|
|
||||||
return rgba_image_from_2d_map__grayscale(map_2d)
|
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Managed BL Image
|
# - Managed BL Image
|
||||||
|
@ -235,7 +174,7 @@ class ManagedBLImage(base.ManagedObj):
|
||||||
self, map_2d, colormap: str | None = 'VIRIDIS', bl_select: bool = False
|
self, map_2d, colormap: str | None = 'VIRIDIS', bl_select: bool = False
|
||||||
):
|
):
|
||||||
self.data_to_image(
|
self.data_to_image(
|
||||||
lambda _: rgba_image_from_2d_map(map_2d, colormap=colormap),
|
lambda _: image_ops.rgba_image_from_2d_map(map_2d, colormap=colormap),
|
||||||
bl_select=bl_select,
|
bl_select=bl_select,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
|
import enum
|
||||||
import bpy
|
import bpy
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
@ -15,7 +16,12 @@ log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ExtractDataNode(base.MaxwellSimNode):
|
class ExtractDataNode(base.MaxwellSimNode):
|
||||||
"""Node for extracting data from particular objects."""
|
"""Node for extracting data from particular objects.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
extract_filter: Identifier for data to extract from the input.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
node_type = ct.NodeType.ExtractData
|
node_type = ct.NodeType.ExtractData
|
||||||
bl_label = 'Extract'
|
bl_label = 'Extract'
|
||||||
|
@ -32,11 +38,10 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
####################
|
####################
|
||||||
# - Properties
|
# - Properties
|
||||||
####################
|
####################
|
||||||
extract_filter: bpy.props.StringProperty(
|
extract_filter: enum.Enum = bl_cache.BLField(
|
||||||
name='Extract Filter',
|
None,
|
||||||
description='Data to extract from the input',
|
prop_ui=True,
|
||||||
search=lambda self, _, edit_text: self.search_extract_filters(edit_text),
|
enum_cb=lambda self, _: self.search_extract_filters(),
|
||||||
update=lambda self, context: self.on_prop_changed('extract_filter', context),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sim Data
|
# Sim Data
|
||||||
|
@ -49,41 +54,30 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
####################
|
####################
|
||||||
# - Computed Properties
|
# - Computed Properties
|
||||||
####################
|
####################
|
||||||
@bl_cache.cached_bl_property(persist=False)
|
@property
|
||||||
def has_sim_data(self) -> bool:
|
def has_sim_data(self) -> bool:
|
||||||
return (
|
return self.active_socket_set == 'Sim Data' and self.sim_data_monitor_nametype
|
||||||
self.active_socket_set == 'Sim Data'
|
|
||||||
and self.inputs['Sim Data'].is_linked
|
|
||||||
and self.sim_data_monitor_nametype
|
|
||||||
)
|
|
||||||
|
|
||||||
@bl_cache.cached_bl_property(persist=False)
|
@property
|
||||||
def has_monitor_data(self) -> bool:
|
def has_monitor_data(self) -> bool:
|
||||||
return (
|
return self.active_socket_set == 'Monitor Data' and self.monitor_data_type
|
||||||
self.active_socket_set == 'Monitor Data'
|
|
||||||
and self.inputs['Monitor Data'].is_linked
|
|
||||||
and self.monitor_data_type
|
|
||||||
)
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Extraction Filter Search
|
# - Extraction Filter Search
|
||||||
####################
|
####################
|
||||||
def search_extract_filters(self, edit_text: str) -> list[tuple[str, str, str]]:
|
def search_extract_filters(self) -> list[ct.BLEnumElement]:
|
||||||
if self.has_sim_data:
|
if self.has_sim_data:
|
||||||
return [
|
return [
|
||||||
(
|
(monitor_name, monitor_name, monitor_type.removesuffix('Data'), '', i)
|
||||||
monitor_name,
|
for i, (monitor_name, monitor_type) in enumerate(
|
||||||
monitor_type.removesuffix('Data'),
|
self.sim_data_monitor_nametype.items()
|
||||||
)
|
)
|
||||||
for monitor_name, monitor_type in self.sim_data_monitor_nametype.items()
|
|
||||||
if edit_text == '' or edit_text.lower() in monitor_name.lower()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.has_monitor_data:
|
if self.has_monitor_data:
|
||||||
return [
|
return [
|
||||||
(component_name, f'ℂ {component_name[1]}-Pol')
|
(component_name, component_name, f'ℂ {component_name[1]}-Pol', '', i)
|
||||||
for component_name in self.monitor_data_components
|
for i, component_name in enumerate(self.monitor_data_components)
|
||||||
if (edit_text == '' or edit_text.lower() in component_name.lower())
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
@ -92,7 +86,7 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
# - UI
|
# - UI
|
||||||
####################
|
####################
|
||||||
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
||||||
col.prop(self, 'extract_filter', text='')
|
col.prop(self, self.blfields['extract_filter'], text='')
|
||||||
|
|
||||||
def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
||||||
if self.has_sim_data or self.has_monitor_data:
|
if self.has_sim_data or self.has_monitor_data:
|
||||||
|
@ -108,7 +102,9 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
row = col.row()
|
row = col.row()
|
||||||
box = row.box()
|
box = row.box()
|
||||||
grid = box.grid_flow(row_major=True, columns=2, even_columns=True)
|
grid = box.grid_flow(row_major=True, columns=2, even_columns=True)
|
||||||
for name, desc in self.search_extract_filters(edit_text=''):
|
for name, desc in [
|
||||||
|
(name, desc) for idname, name, desc, *_ in self.search_extract_filters()
|
||||||
|
]:
|
||||||
grid.label(text=name)
|
grid.label(text=name)
|
||||||
grid.label(text=desc if desc else '')
|
grid.label(text=desc if desc else '')
|
||||||
|
|
||||||
|
@ -120,6 +116,7 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
prop_name='active_socket_set',
|
prop_name='active_socket_set',
|
||||||
input_sockets={'Sim Data', 'Monitor Data'},
|
input_sockets={'Sim Data', 'Monitor Data'},
|
||||||
input_sockets_optional={'Sim Data': True, 'Monitor Data': True},
|
input_sockets_optional={'Sim Data': True, 'Monitor Data': True},
|
||||||
|
run_on_init=True,
|
||||||
)
|
)
|
||||||
def on_sim_data_changed(self, input_sockets: dict):
|
def on_sim_data_changed(self, input_sockets: dict):
|
||||||
if input_sockets['Sim Data'] is not None:
|
if input_sockets['Sim Data'] is not None:
|
||||||
|
@ -130,6 +127,8 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
'Sim Data'
|
'Sim Data'
|
||||||
].monitor_data.items()
|
].monitor_data.items()
|
||||||
}
|
}
|
||||||
|
elif self.sim_data_monitor_nametype:
|
||||||
|
self.sim_data_monitor_nametype = {}
|
||||||
|
|
||||||
if input_sockets['Monitor Data'] is not None:
|
if input_sockets['Monitor Data'] is not None:
|
||||||
# Monitor Data Type
|
# Monitor Data Type
|
||||||
|
@ -168,18 +167,14 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
'Htheta',
|
'Htheta',
|
||||||
'Hphi',
|
'Hphi',
|
||||||
]
|
]
|
||||||
|
else:
|
||||||
|
if self.monitor_data_type:
|
||||||
|
self.monitor_data_type = ''
|
||||||
|
if self.monitor_data_components:
|
||||||
|
self.monitor_data_components = []
|
||||||
|
|
||||||
# Invalidate Computed Property Caches
|
# Invalidate Computed Property Caches
|
||||||
self.has_sim_data = bl_cache.Signal.InvalidateCache
|
self.extract_filter = bl_cache.Signal.ResetEnumItems
|
||||||
self.has_monitor_data = bl_cache.Signal.InvalidateCache
|
|
||||||
|
|
||||||
# Reset Extraction Filter
|
|
||||||
## The extraction filter that was set before may not be valid anymore.
|
|
||||||
## If so, simply remove it.
|
|
||||||
if self.extract_filter not in [
|
|
||||||
el[0] for el in self.search_extract_filters(edit_text='')
|
|
||||||
]:
|
|
||||||
self.extract_filter = ''
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Output: Value
|
# - Output: Value
|
||||||
|
@ -225,16 +220,22 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
kind=ct.FlowKind.Info,
|
kind=ct.FlowKind.Info,
|
||||||
props={'monitor_data_type', 'extract_filter'},
|
props={'monitor_data_type', 'extract_filter'},
|
||||||
input_sockets={'Monitor Data'},
|
input_sockets={'Monitor Data'},
|
||||||
|
input_socket_kinds={'Monitor Data': ct.FlowKind.Value},
|
||||||
)
|
)
|
||||||
def compute_extracted_data_info(
|
def compute_extracted_data_info(
|
||||||
self, props: dict, input_sockets: dict
|
self, props: dict, input_sockets: dict
|
||||||
) -> ct.InfoFlow: # noqa: PLR0911
|
) -> ct.InfoFlow:
|
||||||
if input_sockets['Monitor Data'] is None or not props['extract_filter']:
|
# Retrieve XArray
|
||||||
|
if (
|
||||||
|
input_sockets['Monitor Data'] is not None
|
||||||
|
and props['extract_filter'] != 'NONE'
|
||||||
|
):
|
||||||
|
xarr = getattr(input_sockets['Monitor Data'], props['extract_filter'])
|
||||||
|
else:
|
||||||
return ct.InfoFlow()
|
return ct.InfoFlow()
|
||||||
|
|
||||||
xarr = getattr(input_sockets['Monitor Data'], props['extract_filter'])
|
# Compute InfoFlow from XArray
|
||||||
|
## XYZF: Field / Permittivity / FieldProjectionCartesian
|
||||||
# XYZF: Field / Permittivity / FieldProjectionCartesian
|
|
||||||
if props['monitor_data_type'] in {
|
if props['monitor_data_type'] in {
|
||||||
'Field',
|
'Field',
|
||||||
'Permittivity',
|
'Permittivity',
|
||||||
|
@ -257,7 +258,7 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# XYZT: FieldTime
|
## XYZT: FieldTime
|
||||||
if props['monitor_data_type'] == 'FieldTime':
|
if props['monitor_data_type'] == 'FieldTime':
|
||||||
return ct.InfoFlow(
|
return ct.InfoFlow(
|
||||||
dim_names=['x', 'y', 'z', 't'],
|
dim_names=['x', 'y', 'z', 't'],
|
||||||
|
@ -276,7 +277,7 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# F: Flux
|
## F: Flux
|
||||||
if props['monitor_data_type'] == 'Flux':
|
if props['monitor_data_type'] == 'Flux':
|
||||||
return ct.InfoFlow(
|
return ct.InfoFlow(
|
||||||
dim_names=['f'],
|
dim_names=['f'],
|
||||||
|
@ -289,7 +290,7 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# T: FluxTime
|
## T: FluxTime
|
||||||
if props['monitor_data_type'] == 'FluxTime':
|
if props['monitor_data_type'] == 'FluxTime':
|
||||||
return ct.InfoFlow(
|
return ct.InfoFlow(
|
||||||
dim_names=['t'],
|
dim_names=['t'],
|
||||||
|
@ -302,7 +303,7 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# RThetaPhiF: FieldProjectionAngle
|
## RThetaPhiF: FieldProjectionAngle
|
||||||
if props['monitor_data_type'] == 'FieldProjectionAngle':
|
if props['monitor_data_type'] == 'FieldProjectionAngle':
|
||||||
return ct.InfoFlow(
|
return ct.InfoFlow(
|
||||||
dim_names=['r', 'theta', 'phi', 'f'],
|
dim_names=['r', 'theta', 'phi', 'f'],
|
||||||
|
@ -328,7 +329,7 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# UxUyRF: FieldProjectionKSpace
|
## UxUyRF: FieldProjectionKSpace
|
||||||
if props['monitor_data_type'] == 'FieldProjectionKSpace':
|
if props['monitor_data_type'] == 'FieldProjectionKSpace':
|
||||||
return ct.InfoFlow(
|
return ct.InfoFlow(
|
||||||
dim_names=['ux', 'uy', 'r', 'f'],
|
dim_names=['ux', 'uy', 'r', 'f'],
|
||||||
|
@ -352,7 +353,7 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# OrderxOrderyF: Diffraction
|
## OrderxOrderyF: Diffraction
|
||||||
if props['monitor_data_type'] == 'Diffraction':
|
if props['monitor_data_type'] == 'Diffraction':
|
||||||
return ct.InfoFlow(
|
return ct.InfoFlow(
|
||||||
dim_names=['orders_x', 'orders_y', 'f'],
|
dim_names=['orders_x', 'orders_y', 'f'],
|
||||||
|
|
|
@ -43,8 +43,8 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
prop_ui=True, enum_cb=lambda self, _: self.search_operations()
|
prop_ui=True, enum_cb=lambda self, _: self.search_operations()
|
||||||
)
|
)
|
||||||
|
|
||||||
dim: str = bl_cache.BLField(
|
dim: enum.Enum = bl_cache.BLField(
|
||||||
'', prop_ui=True, str_cb=lambda self, _, edit_text: self.search_dims(edit_text)
|
None, prop_ui=True, enum_cb=lambda self, _: self.search_dims()
|
||||||
)
|
)
|
||||||
|
|
||||||
dim_names: list[str] = bl_cache.BLField([])
|
dim_names: list[str] = bl_cache.BLField([])
|
||||||
|
@ -64,22 +64,24 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
('FIX', 'del a | i≈v', 'Fix Coordinate'),
|
('FIX', 'del a | i≈v', 'Fix Coordinate'),
|
||||||
]
|
]
|
||||||
|
|
||||||
return items
|
return [(*item, '', i) for i, item in enumerate(items)]
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Dim Search
|
# - Dim Search
|
||||||
####################
|
####################
|
||||||
def search_dims(self, edit_text: str) -> list[tuple[str, str, str]]:
|
def search_dims(self) -> list[ct.BLEnumElement]:
|
||||||
if self.dim_names:
|
if self.dim_names:
|
||||||
dims = [
|
dims = [
|
||||||
(dim_name, dim_name)
|
(dim_name, dim_name, dim_name, '', i)
|
||||||
for dim_name in self.dim_names
|
for i, dim_name in enumerate(self.dim_names)
|
||||||
if edit_text == '' or edit_text.lower() in dim_name.lower()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Squeeze: Dimension Must Have Length=1
|
# Squeeze: Dimension Must Have Length=1
|
||||||
|
## We must also correct the "NUMBER" of the enum.
|
||||||
if self.operation == 'SQUEEZE':
|
if self.operation == 'SQUEEZE':
|
||||||
return [dim for dim in dims if self.dim_lens[dim[0]] == 1]
|
filtered_dims = [dim for dim in dims if self.dim_lens[dim[0]] == 1]
|
||||||
|
return [(*dim[:-1], i) for i, dim in enumerate(filtered_dims)]
|
||||||
|
|
||||||
return dims
|
return dims
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -107,6 +109,7 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
input_sockets={'Data'},
|
input_sockets={'Data'},
|
||||||
input_socket_kinds={'Data': ct.FlowKind.Info},
|
input_socket_kinds={'Data': ct.FlowKind.Info},
|
||||||
input_sockets_optional={'Data': True},
|
input_sockets_optional={'Data': True},
|
||||||
|
run_on_init=True,
|
||||||
)
|
)
|
||||||
def on_any_change(self, props: dict, input_sockets: dict):
|
def on_any_change(self, props: dict, input_sockets: dict):
|
||||||
# Set Dimension Names from InfoFlow
|
# Set Dimension Names from InfoFlow
|
||||||
|
@ -120,8 +123,8 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
self.dim_names = []
|
self.dim_names = []
|
||||||
self.dim_lens = {}
|
self.dim_lens = {}
|
||||||
|
|
||||||
# Reset String Searcher
|
# Reset Enum
|
||||||
self.dim = bl_cache.Signal.ResetStrSearch
|
self.dim = bl_cache.Signal.ResetEnumItems
|
||||||
|
|
||||||
@events.on_value_changed(
|
@events.on_value_changed(
|
||||||
prop_name='dim',
|
prop_name='dim',
|
||||||
|
@ -132,10 +135,7 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
)
|
)
|
||||||
def on_dim_change(self, props: dict, input_sockets: dict):
|
def on_dim_change(self, props: dict, input_sockets: dict):
|
||||||
# Add/Remove Input Socket "Value"
|
# Add/Remove Input Socket "Value"
|
||||||
if (
|
if props['active_socket_set'] == 'By Dim Value' and props['dim'] != 'NONE':
|
||||||
props['active_socket_set'] == 'By Dim Value'
|
|
||||||
and props['dim'] in input_sockets['Data'].dim_names
|
|
||||||
):
|
|
||||||
# Get Current and Wanted Socket Defs
|
# Get Current and Wanted Socket Defs
|
||||||
current_socket_def = self.loose_input_sockets.get('Value')
|
current_socket_def = self.loose_input_sockets.get('Value')
|
||||||
wanted_socket_def = sockets.SOCKET_DEFS[
|
wanted_socket_def = sockets.SOCKET_DEFS[
|
||||||
|
@ -167,7 +167,7 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
|
|
||||||
# Compute Bound/Free Parameters
|
# Compute Bound/Free Parameters
|
||||||
func_args = [int] if props['active_socket_set'] == 'By Dim Value' else []
|
func_args = [int] if props['active_socket_set'] == 'By Dim Value' else []
|
||||||
if props['dim']:
|
if props['dim'] != 'NONE':
|
||||||
axis = info.dim_names.index(props['dim'])
|
axis = info.dim_names.index(props['dim'])
|
||||||
else:
|
else:
|
||||||
msg = 'Dimension cannot be empty'
|
msg = 'Dimension cannot be empty'
|
||||||
|
@ -225,7 +225,7 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
|
|
||||||
# Compute Bound/Free Parameters
|
# Compute Bound/Free Parameters
|
||||||
## Empty Dimension -> Empty InfoFlow
|
## Empty Dimension -> Empty InfoFlow
|
||||||
if props['dim']:
|
if props['dim'] != 'NONE':
|
||||||
axis = info.dim_names.index(props['dim'])
|
axis = info.dim_names.index(props['dim'])
|
||||||
else:
|
else:
|
||||||
return ct.InfoFlow()
|
return ct.InfoFlow()
|
||||||
|
@ -272,7 +272,7 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
in [
|
in [
|
||||||
('By Dim Value', 'FIX'),
|
('By Dim Value', 'FIX'),
|
||||||
]
|
]
|
||||||
and props['dim']
|
and props['dim'] != 'NONE'
|
||||||
and input_sockets['Value'] is not None
|
and input_sockets['Value'] is not None
|
||||||
):
|
):
|
||||||
# Compute IDX Corresponding to Coordinate Value
|
# Compute IDX Corresponding to Coordinate Value
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
|
import enum
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
import bpy
|
import bpy
|
||||||
|
|
||||||
from blender_maxwell.utils import logger
|
from blender_maxwell.utils import bl_cache, image_ops, logger
|
||||||
|
|
||||||
from ... import contracts as ct
|
from ... import contracts as ct
|
||||||
from ... import managed_objs, sockets
|
from ... import managed_objs, sockets
|
||||||
|
@ -12,7 +13,12 @@ log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
class VizNode(base.MaxwellSimNode):
|
class VizNode(base.MaxwellSimNode):
|
||||||
"""Node for visualizing simulation data, by querying its monitors."""
|
"""Node for visualizing simulation data, by querying its monitors.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
colormap: Colormap to apply to 0..1 output.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
node_type = ct.NodeType.Viz
|
node_type = ct.NodeType.Viz
|
||||||
bl_label = 'Viz'
|
bl_label = 'Viz'
|
||||||
|
@ -34,31 +40,25 @@ class VizNode(base.MaxwellSimNode):
|
||||||
#####################
|
#####################
|
||||||
## - Properties
|
## - Properties
|
||||||
#####################
|
#####################
|
||||||
colormap: bpy.props.EnumProperty(
|
colormap: image_ops.Colormap = bl_cache.BLField(
|
||||||
name='Colormap',
|
image_ops.Colormap.Viridis, prop_ui=True
|
||||||
description='Colormap to apply to grayscale output',
|
|
||||||
items=[
|
|
||||||
('VIRIDIS', 'Viridis', 'Good default colormap'),
|
|
||||||
('GRAYSCALE', 'Grayscale', 'Barebones'),
|
|
||||||
],
|
|
||||||
default='VIRIDIS',
|
|
||||||
update=lambda self, context: self.on_prop_changed('colormap', context),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
## - UI
|
## - UI
|
||||||
#####################
|
#####################
|
||||||
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
|
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
|
||||||
col.prop(self, 'colormap')
|
col.prop(self, self.blfields['colormap'], text='')
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
## - Plotting
|
## - Plotting
|
||||||
#####################
|
#####################
|
||||||
@events.on_show_plot(
|
@events.on_show_plot(
|
||||||
managed_objs={'plot'},
|
managed_objs={'plot'},
|
||||||
|
props={'colormap'},
|
||||||
input_sockets={'Data'},
|
input_sockets={'Data'},
|
||||||
input_socket_kinds={'Data': ct.FlowKind.Array},
|
input_socket_kinds={'Data': ct.FlowKind.Array},
|
||||||
props={'colormap'},
|
input_sockets_optional={'Data': True},
|
||||||
stop_propagation=True,
|
stop_propagation=True,
|
||||||
)
|
)
|
||||||
def on_show_plot(
|
def on_show_plot(
|
||||||
|
@ -67,6 +67,7 @@ class VizNode(base.MaxwellSimNode):
|
||||||
input_sockets: dict,
|
input_sockets: dict,
|
||||||
props: dict,
|
props: dict,
|
||||||
):
|
):
|
||||||
|
if input_sockets['Data'] is not None:
|
||||||
managed_objs['plot'].map_2d_to_image(
|
managed_objs['plot'].map_2d_to_image(
|
||||||
input_sockets['Data'].values,
|
input_sockets['Data'].values,
|
||||||
colormap=props['colormap'],
|
colormap=props['colormap'],
|
||||||
|
|
|
@ -1034,6 +1034,16 @@ class MaxwellSimNode(bpy.types.Node):
|
||||||
## Blender will automatically add .001 so that `self.name` is unique.
|
## Blender will automatically add .001 so that `self.name` is unique.
|
||||||
self.sim_node_name = self.name
|
self.sim_node_name = self.name
|
||||||
|
|
||||||
|
# Event Methods
|
||||||
|
## Run any 'DataChanged' methods with 'run_on_init' set.
|
||||||
|
## -> Copying a node _arguably_ re-initializes the new node.
|
||||||
|
for event_method in [
|
||||||
|
event_method
|
||||||
|
for event_method in self.event_methods_by_event[ct.FlowEvent.DataChanged]
|
||||||
|
if event_method.callback_info.run_on_init
|
||||||
|
]:
|
||||||
|
event_method(self)
|
||||||
|
|
||||||
def free(self) -> None:
|
def free(self) -> None:
|
||||||
"""Cleans various instance-associated data up, so the node can be cleanly deleted.
|
"""Cleans various instance-associated data up, so the node can be cleanly deleted.
|
||||||
|
|
||||||
|
|
|
@ -740,7 +740,7 @@ class BLField:
|
||||||
str(value),
|
str(value),
|
||||||
AttrType.to_name(value),
|
AttrType.to_name(value),
|
||||||
AttrType.to_name(value), ## TODO: From AttrType.__doc__
|
AttrType.to_name(value), ## TODO: From AttrType.__doc__
|
||||||
AttrType.to_icon(),
|
AttrType.to_icon(value),
|
||||||
i if not self._enum_many else 2**i,
|
i if not self._enum_many else 2**i,
|
||||||
)
|
)
|
||||||
for i, value in enumerate(list(AttrType))
|
for i, value in enumerate(list(AttrType))
|
||||||
|
@ -824,6 +824,12 @@ class BLField:
|
||||||
|
|
||||||
def __set__(self, bl_instance: BLInstance | None, value: typ.Any) -> None:
|
def __set__(self, bl_instance: BLInstance | None, value: typ.Any) -> None:
|
||||||
if value == Signal.ResetEnumItems:
|
if value == Signal.ResetEnumItems:
|
||||||
|
old_items = self._safe_enum_cb(bl_instance, None)
|
||||||
|
current_items = self._enum_cb(bl_instance, None)
|
||||||
|
|
||||||
|
# Only Change if Changes Need Making
|
||||||
|
## <Proverb. 'Fun things to say in jail'. Ca 887BCE. >
|
||||||
|
if old_items != current_items:
|
||||||
# Set Enum to First Item
|
# Set Enum to First Item
|
||||||
## Prevents the seemingly "missing" enum element bug.
|
## Prevents the seemingly "missing" enum element bug.
|
||||||
## -> Caused by the old int still trying to hang on after.
|
## -> Caused by the old int still trying to hang on after.
|
||||||
|
@ -831,7 +837,7 @@ class BLField:
|
||||||
## -> Infinite recursion if we don't check current value.
|
## -> Infinite recursion if we don't check current value.
|
||||||
## -> May cause a hiccup (chains will trigger twice)
|
## -> May cause a hiccup (chains will trigger twice)
|
||||||
## To work, there **must** be a guaranteed-available string at 0,0.
|
## To work, there **must** be a guaranteed-available string at 0,0.
|
||||||
first_old_value = self._safe_enum_cb(bl_instance, None)[0][0]
|
first_old_value = old_items(bl_instance, None)[0][0]
|
||||||
current_value = self._cached_bl_property.__get__(
|
current_value = self._cached_bl_property.__get__(
|
||||||
bl_instance, bl_instance.__class__
|
bl_instance, bl_instance.__class__
|
||||||
)
|
)
|
||||||
|
@ -847,6 +853,11 @@ class BLField:
|
||||||
self._cached_bl_property.__set__(bl_instance, Signal.InvalidateCache)
|
self._cached_bl_property.__set__(bl_instance, Signal.InvalidateCache)
|
||||||
|
|
||||||
elif value == Signal.ResetStrSearch:
|
elif value == Signal.ResetStrSearch:
|
||||||
|
old_items = self._safe_str_cb(bl_instance, None)
|
||||||
|
current_items = self._str_cb(bl_instance, None)
|
||||||
|
|
||||||
|
# Only Change if Changes Need Making
|
||||||
|
if old_items != current_items:
|
||||||
# Set String to ''
|
# Set String to ''
|
||||||
## Prevents the presence of an invalid value not in the new search.
|
## Prevents the presence of an invalid value not in the new search.
|
||||||
## -> Infinite recursion if we don't check current value for ''.
|
## -> Infinite recursion if we don't check current value for ''.
|
||||||
|
|
|
@ -0,0 +1,108 @@
|
||||||
|
"""Useful image processing operations for use in the addon."""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import typing as typ
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jaxtyping as jtyp
|
||||||
|
import matplotlib
|
||||||
|
|
||||||
|
from blender_maxwell import contracts as ct
|
||||||
|
from blender_maxwell.utils import logger
|
||||||
|
|
||||||
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Constants
|
||||||
|
####################
|
||||||
|
_MPL_CM = matplotlib.cm.get_cmap('viridis', 512)
|
||||||
|
VIRIDIS_COLORMAP: jtyp.Float32[jtyp.Array, '512 3'] = jnp.array(
|
||||||
|
[_MPL_CM(i)[:3] for i in range(512)]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Colormap(enum.StrEnum):
|
||||||
|
"""Available colormaps.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
Viridis: Good general-purpose colormap.
|
||||||
|
Grayscale: Simple black and white mapping.
|
||||||
|
"""
|
||||||
|
|
||||||
|
Viridis = enum.auto()
|
||||||
|
Grayscale = enum.auto()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_name(value: typ.Self) -> str:
|
||||||
|
return {
|
||||||
|
Colormap.Viridis: 'Viridis',
|
||||||
|
Colormap.Grayscale: 'Grayscale',
|
||||||
|
}[value]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_icon(value: typ.Self) -> ct.BLIcon:
|
||||||
|
return ''
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Colormap: (X,Y,1 -> Value) -> (X,Y,4 -> Value)
|
||||||
|
####################
|
||||||
|
def apply_colormap(
|
||||||
|
normalized_data: jtyp.Float32[jtyp.Array, 'width height 4'],
|
||||||
|
colormap: jtyp.Float32[jtyp.Array, '512 3'],
|
||||||
|
):
|
||||||
|
# Linear interpolation between colormap points
|
||||||
|
n_colors = colormap.shape[0]
|
||||||
|
indices = normalized_data * (n_colors - 1)
|
||||||
|
lower_idx = jnp.floor(indices).astype(jnp.int32)
|
||||||
|
upper_idx = jnp.ceil(indices).astype(jnp.int32)
|
||||||
|
alpha = indices - lower_idx
|
||||||
|
|
||||||
|
lower_colors = jax.vmap(lambda i: colormap[i])(lower_idx)
|
||||||
|
upper_colors = jax.vmap(lambda i: colormap[i])(upper_idx)
|
||||||
|
|
||||||
|
return (1 - alpha)[..., None] * lower_colors + alpha[..., None] * upper_colors
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def rgba_image_from_2d_map__viridis(map_2d: jtyp.Float32[jtyp.Array, 'width height 4']):
|
||||||
|
amplitude = jnp.abs(map_2d)
|
||||||
|
amplitude_normalized = (amplitude - amplitude.min()) / (
|
||||||
|
amplitude.max() - amplitude.min()
|
||||||
|
)
|
||||||
|
rgb_array = apply_colormap(amplitude_normalized, VIRIDIS_COLORMAP)
|
||||||
|
alpha_channel = jnp.ones_like(amplitude_normalized)
|
||||||
|
return jnp.dstack((rgb_array, alpha_channel))
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def rgba_image_from_2d_map__grayscale(
|
||||||
|
map_2d: jtyp.Float32[jtyp.Array, 'width height 4'],
|
||||||
|
):
|
||||||
|
amplitude = jnp.abs(map_2d)
|
||||||
|
amplitude_normalized = (amplitude - amplitude.min()) / (
|
||||||
|
amplitude.max() - amplitude.min()
|
||||||
|
)
|
||||||
|
rgb_array = jnp.stack([amplitude_normalized] * 3, axis=-1)
|
||||||
|
alpha_channel = jnp.ones_like(amplitude_normalized)
|
||||||
|
return jnp.dstack((rgb_array, alpha_channel))
|
||||||
|
|
||||||
|
|
||||||
|
def rgba_image_from_2d_map(
|
||||||
|
map_2d: jtyp.Float32[jtyp.Array, 'width height 4'], colormap: str | None = None
|
||||||
|
):
|
||||||
|
"""RGBA Image from a map of 2D coordinates to values.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
map_2d: The 2D value map.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image as a JAX array of shape (height, width, 4)
|
||||||
|
"""
|
||||||
|
if colormap == Colormap.Viridis:
|
||||||
|
return rgba_image_from_2d_map__viridis(map_2d)
|
||||||
|
if colormap == Colormap.Grayscale:
|
||||||
|
return rgba_image_from_2d_map__grayscale(map_2d)
|
||||||
|
|
||||||
|
return rgba_image_from_2d_map__grayscale(map_2d)
|
Loading…
Reference in New Issue