feat: Complete matplotlib plotting system.

The Viz node now detects the shape of the data, and presents compatible
plot options.
Not all are implemented, but a few quite important ones are.

Additionally, a number of dataflow-related bugs were investigated and
fixed. A few were truly damaging, but many simply resulted in gross
inefficiencies - we must be careful declaring BLFields that are updated
in hot loops!

Moreover, it is exceptionally easy to add more as needed, as we analyze
more and more sims.
The only limit is `matplotlib`, which is... well, yeah.

Due to the BLField work, the dynamicness of the Viz node is quite
under control, so there will not be any critical issues there.

The plotting lags (70ms total in the hot loop), but that's actually
entirely fixeable.
It's also entirely the `managed_bl_image`'s fault.
Fixing these inefficiencies will also make Tidy3D's builtin plots
near-realtime, incidentally.

We profiled the following currently:
- 25ms: Creating `fig = plt.subplots`. We can reuse fig per-managed
  image.
- 43ms: The BytesIO roundtrip, including `savefig`. We can instead use
  the Agg backend, `fig.canvas.draw()`, and a `np.frombuffer` to both
  plot directly to the memory location,
- ~3ms: Actual plotting functions in `image_ops`. They are seriously fast.
- ~0ms: Blitting pixels to the Blender image - this was optimized in
  4.1, and it shows; the time to copy the data over is essentially nothing.
main
Sofus Albert Høgsbro Rose 2024-04-23 19:27:45 +02:00
parent e7d3ecf48e
commit a3defd3c1c
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
12 changed files with 666 additions and 88 deletions

View File

@ -126,6 +126,10 @@ class ArrayFlow:
def __len__(self) -> int:
return len(self.values)
@functools.cached_property
def mathtype(self) -> spux.MathType:
return spux.MathType.from_pytype(type(self.values.item(0)))
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
"""Find the index of the value that is closest to the given value.
@ -437,6 +441,27 @@ class LazyArrayRangeFlow:
key=lambda sym: sym.name,
)
@functools.cached_property
def mathtype(self) -> spux.MathType:
# Get Start Mathtype
if isinstance(self.start, spux.SympyType):
start_mathtype = spux.MathType.from_expr(self.start)
else:
start_mathtype = spux.MathType.from_pytype(self.start)
# Get Stop Mathtype
if isinstance(self.stop, spux.SympyType):
stop_mathtype = spux.MathType.from_expr(type(self.stop))
else:
stop_mathtype = spux.MathType.from_pytype(type(self.stop))
# Check Equal
if start_mathtype != stop_mathtype:
msg = "Mathtypes of start and stop don't agree. Please fix!"
raise ValueError(msg)
return start_mathtype
def __len__(self):
return self.steps
@ -688,4 +713,31 @@ class InfoFlow:
default_factory=dict
) ## TODO: Rename to dim_idxs
## TODO: Validation, esp. length of dims. Pydantic?
@functools.cached_property
def dim_lens(self) -> dict[str, int]:
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
@functools.cached_property
def dim_mathtypes(self) -> dict[str, int]:
return {
dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items()
}
@functools.cached_property
def dim_units(self) -> dict[str, int]:
return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()}
@functools.cached_property
def dim_idx_arrays(self) -> list[ArrayFlow]:
return [
dim_idx.realize().values
if isinstance(dim_idx, LazyArrayRangeFlow)
else dim_idx.values
for dim_idx in self.dim_idx.values()
]
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
# Output Information
output_names: list[str] = dataclasses.field(default_factory=list)
output_mathtypes: dict[str, spux.MathType] = dataclasses.field(default_factory=dict)
output_units: dict[str, spux.Unit | None] = dataclasses.field(default_factory=dict)

View File

@ -1,4 +1,5 @@
import io
import time
import typing as typ
import bpy
@ -205,61 +206,57 @@ class ManagedBLImage(base.ManagedObj):
dpi: int | None = None,
bl_select: bool = False,
):
# time_start = time.perf_counter()
times = [time.perf_counter()]
import matplotlib.pyplot as plt
# log.debug('Imported PyPlot (%f)', time.perf_counter() - time_start)
times.append(time.perf_counter() - times[0])
# Compute Plot Dimensions
aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = (
self.gen_image_geometry(width_inches, height_inches, dpi)
)
# log.debug('Computed MPL Geometry (%f)', time.perf_counter() - time_start)
times.append(time.perf_counter() - times[0])
# log.debug(
# 'Creating MPL Axes (aspect=%f, width=%f, height=%f)',
# aspect_ratio,
# _width_inches,
# _height_inches,
# )
# Create MPL Figure, Axes, and Compute Figure Geometry
fig, ax = plt.subplots(
figsize=[_width_inches, _height_inches],
dpi=_dpi,
)
# log.debug('Created MPL Axes (%f)', time.perf_counter() - time_start)
times.append(time.perf_counter() - times[0])
ax.set_aspect(aspect_ratio)
times.append(time.perf_counter() - times[0])
cmp_width_px, cmp_height_px = fig.canvas.get_width_height()
## Use computed pixel w/h to preempt off-by-one size errors.
times.append(time.perf_counter() - times[0])
ax.set_aspect('auto') ## Workaround aspect-ratio bugs
# log.debug('Set MPL Aspect (%f)', time.perf_counter() - time_start)
times.append(time.perf_counter() - times[0])
# Plot w/User Parameter
func_plotter(ax)
# log.debug('User Plot Function (%f)', time.perf_counter() - time_start)
times.append(time.perf_counter() - times[0])
# Save Figure to BytesIO
with io.BytesIO() as buff:
# log.debug('Made BytesIO (%f)', time.perf_counter() - time_start)
fig.savefig(buff, format='raw', dpi=dpi)
# log.debug('Saved Figure to BytesIO (%f)', time.perf_counter() - time_start)
times.append(time.perf_counter() - times[0])
buff.seek(0)
image_data = np.frombuffer(
buff.getvalue(),
dtype=np.uint8,
).reshape([cmp_height_px, cmp_width_px, -1])
# log.debug('Set Image Data (%f)', time.perf_counter() - time_start)
times.append(time.perf_counter() - times[0])
image_data = np.flipud(image_data).astype(np.float32) / 255
# log.debug('Flipped Image Data (%f)', time.perf_counter() - time_start)
times.append(time.perf_counter() - times[0])
plt.close(fig)
# Optimized Write to Blender Image
bl_image = self.bl_image(cmp_width_px, cmp_height_px, 'RGBA', 'uint8')
# log.debug('Made BL Image (%f)', time.perf_counter() - time_start)
times.append(time.perf_counter() - times[0])
bl_image.pixels.foreach_set(image_data.ravel())
# log.debug('Set BL Image Pixels (%f)', time.perf_counter() - time_start)
times.append(time.perf_counter() - times[0])
bl_image.update()
# log.debug('Updated BL Image (%f)', time.perf_counter() - time_start)
times.append(time.perf_counter() - times[0])
if bl_select:
self.bl_select()
times.append(time.perf_counter() - times[0])
#log.critical('Timing of MPL Plot: %s', str(times))

View File

@ -1,12 +1,13 @@
import enum
import typing as typ
import enum
import bpy
import jax
import jax.numpy as jnp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from ... import contracts as ct
from ... import sockets
@ -45,11 +46,13 @@ class ExtractDataNode(base.MaxwellSimNode):
)
# Sim Data
sim_data_monitor_nametype: dict[str, str] = bl_cache.BLField({})
sim_data_monitor_nametype: dict[str, str] = bl_cache.BLField(
{}, use_prop_update=False
)
# Monitor Data
monitor_data_type: str = bl_cache.BLField('')
monitor_data_components: list[str] = bl_cache.BLField([])
monitor_data_type: str = bl_cache.BLField('', use_prop_update=False)
monitor_data_components: list[str] = bl_cache.BLField([], use_prop_update=False)
####################
# - Computed Properties
@ -177,7 +180,7 @@ class ExtractDataNode(base.MaxwellSimNode):
self.extract_filter = bl_cache.Signal.ResetEnumItems
####################
# - Output: Value
# - Output: Sim Data -> Monitor Data
####################
@events.computes_output_socket(
'Monitor Data',
@ -186,34 +189,52 @@ class ExtractDataNode(base.MaxwellSimNode):
input_sockets={'Sim Data'},
)
def compute_monitor_data(self, props: dict, input_sockets: dict):
if input_sockets['Sim Data'] is not None and props['extract_filter'] != 'NONE':
return input_sockets['Sim Data'].monitor_data[props['extract_filter']]
return None
####################
# - Output: Monitor Data -> Data
####################
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Array,
props={'extract_filter'},
input_sockets={'Monitor Data'},
input_socket_kinds={'Monitor Data': ct.FlowKind.Value},
)
def compute_data(self, props: dict, input_sockets: dict) -> jax.Array:
xarray_data = getattr(input_sockets['Monitor Data'], props['extract_filter'])
return jnp.array(xarray_data.data) ## TODO: Can it be done without a copy?
def compute_data(self, props: dict, input_sockets: dict) -> jax.Array | None:
if (
input_sockets['Monitor Data'] is not None
and props['extract_filter'] != 'NONE'
):
xarray_data = getattr(
input_sockets['Monitor Data'], props['extract_filter']
)
return jnp.array(xarray_data.data)
## TODO: Let the array itself have its output unit too!
return None
####################
# - Output: LazyValueFunc
####################
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.LazyValueFunc,
output_sockets={'Data'},
output_socket_kinds={'Data': ct.FlowKind.Array},
)
def compute_extracted_data_lazy(self, output_sockets: dict) -> ct.LazyValueFuncFlow:
def compute_extracted_data_lazy(
self, output_sockets: dict
) -> ct.LazyValueFuncFlow | None:
if output_sockets['Data'] is not None:
return ct.LazyValueFuncFlow(
func=lambda: output_sockets['Data'], supports_jax=True
)
return None
####################
# - Output: Info
# - Auxiliary: Monitor Data -> Data
####################
@events.computes_output_socket(
'Data',
@ -221,6 +242,7 @@ class ExtractDataNode(base.MaxwellSimNode):
props={'monitor_data_type', 'extract_filter'},
input_sockets={'Monitor Data'},
input_socket_kinds={'Monitor Data': ct.FlowKind.Value},
input_sockets_optional={'Monitor Data': True},
)
def compute_extracted_data_info(
self, props: dict, input_sockets: dict
@ -234,12 +256,16 @@ class ExtractDataNode(base.MaxwellSimNode):
else:
return ct.InfoFlow()
info_output_names = {
'output_names': [props['extract_filter']],
}
# Compute InfoFlow from XArray
## XYZF: Field / Permittivity / FieldProjectionCartesian
if props['monitor_data_type'] in {
'Field',
'Permittivity',
'FieldProjectionCartesian',
#'FieldProjectionCartesian',
}:
return ct.InfoFlow(
dim_names=['x', 'y', 'z', 'f'],
@ -256,6 +282,13 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True,
),
},
**info_output_names,
output_mathtypes={props['extract_filter']: spux.MathType.Complex},
output_units={
props['extract_filter']: spu.volt / spu.micrometer
if props['monitor_data_type'] == 'Field'
else None
},
)
## XYZT: FieldTime
@ -275,6 +308,17 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True,
),
},
**info_output_names,
output_mathtypes={props['extract_filter']: spux.MathType.Complex},
output_units={
props['extract_filter']: (
spu.volt / spu.micrometer
if props['extract_filter'].startswith('E')
else spu.ampere / spu.micrometer
)
if props['monitor_data_type'] == 'Field'
else None
},
)
## F: Flux
@ -288,6 +332,9 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True,
),
},
**info_output_names,
output_mathtypes={props['extract_filter']: spux.MathType.Real},
output_units={props['extract_filter']: spu.watt},
)
## T: FluxTime
@ -301,6 +348,9 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True,
),
},
**info_output_names,
output_mathtypes={props['extract_filter']: spux.MathType.Real},
output_units={props['extract_filter']: spu.watt},
)
## RThetaPhiF: FieldProjectionAngle
@ -327,6 +377,15 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True,
),
},
**info_output_names,
output_mathtypes={props['extract_filter']: spux.MathType.Real},
output_units={
props['extract_filter']: (
spu.volt / spu.micrometer
if props['extract_filter'].startswith('E')
else spu.ampere / spu.micrometer
)
},
)
## UxUyRF: FieldProjectionKSpace
@ -351,6 +410,15 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True,
),
},
**info_output_names,
output_mathtypes={props['extract_filter']: spux.MathType.Real},
output_units={
props['extract_filter']: (
spu.volt / spu.micrometer
if props['extract_filter'].startswith('E')
else spu.ampere / spu.micrometer
)
},
)
## OrderxOrderyF: Diffraction
@ -372,6 +440,15 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True,
),
},
**info_output_names,
output_mathtypes={props['extract_filter']: spux.MathType.Real},
output_units={
props['extract_filter']: (
spu.volt / spu.micrometer
if props['extract_filter'].startswith('E')
else spu.ampere / spu.micrometer
)
},
)
msg = f'Unsupported Monitor Data Type {props["monitor_data_type"]} in "FlowKind.Info" of "{self.bl_label}"'

View File

@ -47,8 +47,9 @@ class FilterMathNode(base.MaxwellSimNode):
None, prop_ui=True, enum_cb=lambda self, _: self.search_dims()
)
dim_names: list[str] = bl_cache.BLField([])
dim_lens: dict[str, int] = bl_cache.BLField({})
@property
def _info(self) -> ct.InfoFlow:
return self._compute_input('Data', kind=ct.FlowKind.Info)
####################
# - Operation Search
@ -70,16 +71,16 @@ class FilterMathNode(base.MaxwellSimNode):
# - Dim Search
####################
def search_dims(self) -> list[ct.BLEnumElement]:
if self.dim_names:
if (info := self._info).dim_names:
dims = [
(dim_name, dim_name, dim_name, '', i)
for i, dim_name in enumerate(self.dim_names)
for i, dim_name in enumerate(info.dim_names)
]
# Squeeze: Dimension Must Have Length=1
## We must also correct the "NUMBER" of the enum.
if self.operation == 'SQUEEZE':
filtered_dims = [dim for dim in dims if self.dim_lens[dim[0]] == 1]
filtered_dims = [dim for dim in dims if info.dim_lens[dim[0]] == 1]
return [(*dim[:-1], i) for i, dim in enumerate(filtered_dims)]
return dims
@ -90,7 +91,7 @@ class FilterMathNode(base.MaxwellSimNode):
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='')
if self.dim_names:
if self._info.dim_names:
layout.prop(self, self.blfields['dim'], text='')
####################
@ -98,52 +99,49 @@ class FilterMathNode(base.MaxwellSimNode):
####################
@events.on_value_changed(
prop_name='active_socket_set',
run_on_init=True,
)
def on_socket_set_changed(self):
self.operation = bl_cache.Signal.ResetEnumItems
@events.on_value_changed(
socket_name={'Data'},
prop_name={'active_socket_set'},
socket_name='Data',
prop_name='active_socket_set',
props={'active_socket_set'},
input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Info},
input_sockets_optional={'Data': True},
run_on_init=True,
# run_on_init=True,
)
def on_any_change(self, props: dict, input_sockets: dict):
# Set Dimension Names from InfoFlow
if input_sockets['Data'].dim_names:
self.dim_names = input_sockets['Data'].dim_names
self.dim_lens = {
dim_name: len(dim_idx)
for dim_name, dim_idx in input_sockets['Data'].dim_idx.items()
}
else:
self.dim_names = []
self.dim_lens = {}
# Reset Enum
self.dim = bl_cache.Signal.ResetEnumItems
@events.on_value_changed(
socket_name='Data',
prop_name='dim',
props={'active_socket_set', 'dim'},
input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Info},
input_sockets_optional={'Data': True},
# run_on_init=True,
)
def on_dim_change(self, props: dict, input_sockets: dict):
# Add/Remove Input Socket "Value"
if props['active_socket_set'] == 'By Dim Value' and props['dim'] != 'NONE':
if (
input_sockets['Data'] != ct.InfoFlow()
and props['active_socket_set'] == 'By Dim Value'
and props['dim'] != 'NONE'
):
# Get Current and Wanted Socket Defs
current_socket_def = self.loose_input_sockets.get('Value')
current_bl_socket = self.loose_input_sockets.get('Value')
wanted_socket_def = sockets.SOCKET_DEFS[
ct.unit_to_socket_type(input_sockets['Data'].dim_idx[props['dim']].unit)
]
# Determine Whether to Declare New Loose Input SOcket
if current_socket_def is None or current_socket_def != wanted_socket_def:
if (
current_bl_socket is None
or sockets.SOCKET_DEFS[current_bl_socket.socket_type]
!= wanted_socket_def
):
self.loose_input_sockets = {
'Value': wanted_socket_def(),
}
@ -225,7 +223,7 @@ class FilterMathNode(base.MaxwellSimNode):
# Compute Bound/Free Parameters
## Empty Dimension -> Empty InfoFlow
if props['dim'] != 'NONE':
if input_sockets['Data'] != ct.InfoFlow() and props['dim'] != 'NONE':
axis = info.dim_names.index(props['dim'])
else:
return ct.InfoFlow()
@ -243,6 +241,9 @@ class FilterMathNode(base.MaxwellSimNode):
for dim_name, dim_idx in info.dim_idx.items()
if dim_name != props['dim']
},
output_names=info.output_names,
output_mathtypes=info.output_mathtypes,
output_units=info.output_units,
)
# Fallback to Empty InfoFlow

View File

@ -6,7 +6,8 @@ import jax
import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import logger, bl_cache
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct
from .... import sockets
@ -143,7 +144,7 @@ class MapMathNode(base.MaxwellSimNode):
'SINC': lambda data: jnp.sinc(data),
},
'By Vector': {
'NORM_2': lambda data: jnp.norm(data, ord=2, axis=-1),
'NORM_2': lambda data: jnp.linalg.norm(data, ord=2, axis=-1),
},
'By Matrix': {
# Matrix -> Number
@ -196,11 +197,34 @@ class MapMathNode(base.MaxwellSimNode):
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Info,
props={'active_socket_set', 'operation'},
input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Info},
)
def compute_data_info(self, input_sockets: dict) -> ct.InfoFlow:
return input_sockets['Data']
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
info = input_sockets['Data']
# Complex -> Real
if props['active_socket_set'] == 'By Element' and props['operation'] in [
'REAL',
'IMAG',
'ABS',
]:
return ct.InfoFlow(
dim_names=info.dim_names,
dim_idx=info.dim_idx,
output_names=info.output_names,
output_mathtypes={
output_name: (
spux.MathType.Real
if output_mathtype == spux.MathType.Complex
else output_mathtype
)
for output_name, output_mathtype in info.output_mathtypes.items()
},
output_units=info.output_units,
)
return info
@events.computes_output_socket(
'Data',

View File

@ -2,8 +2,11 @@ import enum
import typing as typ
import bpy
import jaxtyping as jtyp
import matplotlib.axis as mpl_ax
from blender_maxwell.utils import bl_cache, image_ops, logger
from blender_maxwell.utils import extra_sympy_units as spux
from ... import contracts as ct
from ... import managed_objs, sockets
@ -12,9 +15,168 @@ from .. import base, events
log = logger.get(__name__)
class VizMode(enum.StrEnum):
"""Available visualization modes.
**NOTE**: >1D output dimensions currently have no viz.
Plots for `() -> `:
- Hist1D: Bin-summed distribution.
- BoxPlot1D: Box-plot describing the distribution.
Plots for `() -> `:
- BoxPlots1D: Side-by-side boxplots w/equal y axis.
Plots for `() -> `:
- Curve2D: Standard line-curve w/smooth interpolation
- Points2D: Scatterplot of individual points.
- Bar: Value to height of a barplot.
Plots for `(, ) -> `:
- Curves2D: Layered Curve2Ds with unique colors.
- FilledCurves2D: Layered Curve2Ds with filled space between.
Plots for `(, ) -> `:
- Heatmap2D: Colormapped image with value at each pixel.
Plots for `(, , ) -> `:
- SqueezedHeatmap2D: 3D-embeddable heatmap for when one of the axes is 1.
- Heatmap3D: Colormapped field with value at each voxel.
"""
Hist1D = enum.auto()
BoxPlot1D = enum.auto()
Curve2D = enum.auto()
Points2D = enum.auto()
Bar = enum.auto()
Curves2D = enum.auto()
FilledCurves2D = enum.auto()
Heatmap2D = enum.auto()
SqueezedHeatmap2D = enum.auto()
Heatmap3D = enum.auto()
@staticmethod
def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None:
valid_viz_modes = {
((), (spux.MathType.Real,)): [VizMode.Hist1D, VizMode.BoxPlot1D],
((spux.MathType.Integer), (spux.MathType.Real)): [
VizMode.Hist1D,
VizMode.BoxPlot1D,
],
((spux.MathType.Real,), (spux.MathType.Real,)): [
VizMode.Curve2D,
VizMode.Points2D,
VizMode.Bar,
],
((spux.MathType.Real, spux.MathType.Integer), (spux.MathType.Real,)): [
VizMode.Curves2D,
VizMode.FilledCurves2D,
],
((spux.MathType.Real, spux.MathType.Real), (spux.MathType.Real,)): [
VizMode.Heatmap2D,
],
(
(spux.MathType.Real, spux.MathType.Real, spux.MathType.Real),
(spux.MathType.Real,),
): [VizMode.SqueezedHeatmap2D, VizMode.Heatmap3D],
}.get(
(
tuple(info.dim_mathtypes.values()),
tuple(info.output_mathtypes.values()),
)
)
if valid_viz_modes is None:
return []
return valid_viz_modes
@staticmethod
def to_plotter(
value: typ.Self,
) -> typ.Callable[
[jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None
]:
return {
VizMode.Hist1D: image_ops.plot_hist_1d,
VizMode.BoxPlot1D: image_ops.plot_box_plot_1d,
VizMode.Curve2D: image_ops.plot_curve_2d,
VizMode.Points2D: image_ops.plot_points_2d,
VizMode.Bar: image_ops.plot_bar,
VizMode.Curves2D: image_ops.plot_curves_2d,
VizMode.FilledCurves2D: image_ops.plot_filled_curves_2d,
VizMode.Heatmap2D: image_ops.plot_heatmap_2d,
# NO PLOTTER: VizMode.SqueezedHeatmap2D
# NO PLOTTER: VizMode.Heatmap3D
}[value]
@staticmethod
def to_name(value: typ.Self) -> str:
return {
VizMode.Hist1D: 'Histogram',
VizMode.BoxPlot1D: 'Box Plot',
VizMode.Curve2D: 'Curve',
VizMode.Points2D: 'Points',
VizMode.Bar: 'Bar',
VizMode.Curves2D: 'Curves',
VizMode.FilledCurves2D: 'Filled Curves',
VizMode.Heatmap2D: 'Heatmap',
VizMode.SqueezedHeatmap2D: 'Heatmap (Squeezed)',
VizMode.Heatmap3D: 'Heatmap (3D)',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> ct.BLIcon:
return ''
class VizTarget(enum.StrEnum):
"""Available visualization targets."""
Plot2D = enum.auto()
Pixels = enum.auto()
PixelsPlane = enum.auto()
Voxels = enum.auto()
@staticmethod
def valid_targets_for(viz_mode: VizMode) -> list[typ.Self] | None:
return {
'NONE': [],
VizMode.Hist1D: [VizTarget.Plot2D],
VizMode.BoxPlot1D: [VizTarget.Plot2D],
VizMode.Curve2D: [VizTarget.Plot2D],
VizMode.Points2D: [VizTarget.Plot2D],
VizMode.Bar: [VizTarget.Plot2D],
VizMode.Curves2D: [VizTarget.Plot2D],
VizMode.FilledCurves2D: [VizTarget.Plot2D],
VizMode.Heatmap2D: [VizTarget.Plot2D, VizTarget.Pixels],
VizMode.SqueezedHeatmap2D: [VizTarget.Pixels, VizTarget.PixelsPlane],
VizMode.Heatmap3D: [VizTarget.Voxels],
}[viz_mode]
@staticmethod
def to_name(value: typ.Self) -> str:
return {
VizTarget.Plot2D: 'Image (Plot)',
VizTarget.Pixels: 'Image (Pixels)',
VizTarget.PixelsPlane: 'Image (Plane)',
VizTarget.Voxels: '3D Field',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> ct.BLIcon:
return ''
class VizNode(base.MaxwellSimNode):
"""Node for visualizing simulation data, by querying its monitors.
Auto-detects the correct plot type based on the input data:
Attributes:
colormap: Colormap to apply to 0..1 output.
@ -40,24 +202,91 @@ class VizNode(base.MaxwellSimNode):
#####################
## - Properties
#####################
viz_mode: enum.Enum = bl_cache.BLField(
prop_ui=True, enum_cb=lambda self, _: self.search_modes()
)
viz_target: enum.Enum = bl_cache.BLField(
prop_ui=True, enum_cb=lambda self, _: self.search_targets()
)
# Mode-Dependent Properties
colormap: image_ops.Colormap = bl_cache.BLField(
image_ops.Colormap.Viridis, prop_ui=True
)
#####################
## - Mode Searcher
#####################
@property
def _info(self) -> ct.InfoFlow:
return self._compute_input('Data', kind=ct.FlowKind.Info)
def search_modes(self) -> list[ct.BLEnumElement]:
info = self._info
return [
(
viz_mode,
VizMode.to_name(viz_mode),
VizMode.to_name(viz_mode),
VizMode.to_icon(viz_mode),
i,
)
for i, viz_mode in enumerate(VizMode.valid_modes_for(info))
]
#####################
## - Target Searcher
#####################
def search_targets(self) -> list[ct.BLEnumElement]:
return [
(
viz_target,
VizTarget.to_name(viz_target),
VizTarget.to_name(viz_target),
VizTarget.to_icon(viz_target),
i,
)
for i, viz_target in enumerate(VizTarget.valid_targets_for(self.viz_mode))
]
#####################
## - UI
#####################
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
col.prop(self, self.blfields['viz_mode'], text='')
col.prop(self, self.blfields['viz_target'], text='')
if self.viz_target in [VizTarget.Pixels, VizTarget.PixelsPlane]:
col.prop(self, self.blfields['colormap'], text='')
####################
# - Events
####################
@events.on_value_changed(
socket_name='Data',
input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Info},
input_sockets_optional={'Data': True},
run_on_init=True,
)
def on_socket_set_changed(self, input_sockets: dict):
self.viz_mode = bl_cache.Signal.ResetEnumItems
self.viz_target = bl_cache.Signal.ResetEnumItems
@events.on_value_changed(
prop_name='viz_mode',
# run_on_init=True,
)
def on_viz_mode_changed(self):
self.viz_target = bl_cache.Signal.ResetEnumItems
#####################
## - Plotting
#####################
@events.on_show_plot(
managed_objs={'plot'},
props={'colormap'},
props={'viz_mode', 'viz_target', 'colormap'},
input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Array},
input_socket_kinds={'Data': {ct.FlowKind.Array, ct.FlowKind.Info}},
input_sockets_optional={'Data': True},
stop_propagation=True,
)
@ -67,13 +296,32 @@ class VizNode(base.MaxwellSimNode):
input_sockets: dict,
props: dict,
):
if input_sockets['Data'] is not None:
array_flow = input_sockets['Data'][ct.FlowKind.Array]
info = input_sockets['Data'][ct.FlowKind.Info]
if input_sockets['Data'] is None:
return
if props['viz_target'] == VizTarget.Plot2D:
managed_objs['plot'].mpl_plot_to_image(
lambda ax: VizMode.to_plotter(props['viz_mode'])(
array_flow.values, info, ax
),
bl_select=True,
)
if props['viz_target'] == VizTarget.Pixels:
managed_objs['plot'].map_2d_to_image(
input_sockets['Data'].values,
array_flow.values,
colormap=props['colormap'],
bl_select=True,
)
if props['viz_target'] == VizTarget.PixelsPlane:
raise NotImplementedError
if props['viz_target'] == VizTarget.Voxels:
raise NotImplementedError
####################
# - Blender Registration

View File

@ -843,10 +843,12 @@ class MaxwellSimNode(bpy.types.Node):
# Propagate Event to All Sockets in "Trigger Direction"
## The trigger chain goes node/socket/node/socket/...
if not stop_propagation:
triggered_sockets = self._bl_sockets(
direc=ct.FlowEvent.flow_direction[event]
)
direc = ct.FlowEvent.flow_direction[event]
triggered_sockets = self._bl_sockets(direc=direc)
for bl_socket in triggered_sockets:
if direc == 'output' and not bl_socket.is_linked:
continue
# log.critical(
# '![%s] Propagating: (%s, %s)',
# self.sim_node_name,

View File

@ -505,7 +505,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket):
socket_name=self.name,
socket_kinds=socket_kinds,
)
else:
self.node.trigger_event(
event, socket_name=self.name, socket_kinds=socket_kinds
)

View File

@ -49,14 +49,14 @@ class DataBLSocket(base.MaxwellSimSocket):
columns=3,
row_major=True,
even_columns=True,
#even_rows=True,
# even_rows=True,
align=True,
)
# Grid Header
#grid.label(text='Dim')
#grid.label(text='Len')
#grid.label(text='Unit')
# grid.label(text='Dim')
# grid.label(text='Len')
# grid.label(text='Unit')
# Dimension Names
for dim_name in info.dim_names:

View File

@ -828,7 +828,6 @@ class BLField:
current_items = self._enum_cb(bl_instance, None)
# Only Change if Changes Need Making
## <Proverb. 'Fun things to say in jail'. Ca 887BCE. >
if old_items != current_items:
# Set Enum to First Item
## Prevents the seemingly "missing" enum element bug.
@ -837,7 +836,7 @@ class BLField:
## -> Infinite recursion if we don't check current value.
## -> May cause a hiccup (chains will trigger twice)
## To work, there **must** be a guaranteed-available string at 0,0.
first_old_value = old_items(bl_instance, None)[0][0]
first_old_value = old_items[0][0]
current_value = self._cached_bl_property.__get__(
bl_instance, bl_instance.__class__
)

View File

@ -10,9 +10,11 @@ Attributes:
Should be used via the `ConstrSympyExpr`, which also adds expression validation.
"""
import enum
import itertools
import typing as typ
import jax.numpy as jnp
import pydantic as pyd
import sympy as sp
import sympy.physics.units as spu
@ -22,6 +24,54 @@ from pydantic_core import core_schema as pyd_core_schema
SympyType = sp.Basic | sp.Expr | sp.MatrixBase | sp.MutableDenseMatrix | spu.Quantity
class MathType(enum.StrEnum):
Bool = enum.auto()
Integer = enum.auto()
Rational = enum.auto()
Real = enum.auto()
Complex = enum.auto()
@staticmethod
def from_expr(sp_obj: SympyType) -> type:
if isinstance(sp_obj, sp.logic.boolalg.Boolean):
return MathType.Bool
if sp_obj.is_integer:
return MathType.Integer
if sp_obj.is_rational or sp_obj.is_real:
return MathType.Real
if sp_obj.is_complex:
return MathType.Complex
msg = "Can't determine MathType from sympy object: {sp_obj}"
raise ValueError(msg)
@staticmethod
def from_pytype(dtype) -> type:
return {
bool: MathType.Bool,
int: MathType.Integer,
float: MathType.Real,
complex: MathType.Complex,
#jnp.int32: MathType.Integer,
#jnp.int64: MathType.Integer,
#jnp.float32: MathType.Real,
#jnp.float64: MathType.Real,
#jnp.complex64: MathType.Complex,
#jnp.complex128: MathType.Complex,
#jnp.bool_: MathType.Bool,
}[dtype]
@staticmethod
def to_dtype(value: typ.Self) -> type:
return {
MathType.Bool: bool,
MathType.Integer: int,
MathType.Rational: float,
MathType.Real: float,
MathType.Complex: complex,
}[value]
####################
# - Units
####################

View File

@ -1,12 +1,14 @@
"""Useful image processing operations for use in the addon."""
import enum
import time
import typing as typ
import jax
import jax.numpy as jnp
import jaxtyping as jtyp
import matplotlib
import matplotlib.axis as mpl_ax
from blender_maxwell import contracts as ct
from blender_maxwell.utils import logger
@ -106,3 +108,129 @@ def rgba_image_from_2d_map(
return rgba_image_from_2d_map__grayscale(map_2d)
return rgba_image_from_2d_map__grayscale(map_2d)
####################
# - Plotters
####################
# () ->
def plot_hist_1d(
data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis
) -> None:
y_name = info.output_names[0]
y_unit = info.output_units[y_name]
ax.hist(data, bins=30, alpha=0.75)
ax.set_title('Histogram')
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
# () ->
def plot_box_plot_1d(
data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis
) -> None:
x_name = info.dim_names[0]
y_name = info.output_names[0]
y_unit = info.output_units[y_name]
ax.boxplot(data)
ax.set_title('Box Plot')
ax.set_xlabel(f'{x_name}')
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
# () ->
def plot_curve_2d(
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
) -> None:
times = [time.perf_counter()]
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_names[0]
y_unit = info.output_units[y_name]
times.append(time.perf_counter() - times[0])
ax.plot(info.dim_idx_arrays[0], data)
times.append(time.perf_counter() - times[0])
ax.set_title('2D Curve')
times.append(time.perf_counter() - times[0])
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
times.append(time.perf_counter() - times[0])
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
times.append(time.perf_counter() - times[0])
# log.critical('Timing of Curve2D: %s', str(times))
def plot_points_2d(
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_names[0]
y_unit = info.output_units[y_name]
ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6)
ax.set_title('2D Points')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_names[0]
y_unit = info.output_units[y_name]
ax.bar(info.dim_idx_arrays[0], data, alpha=0.7)
ax.set_title('2D Bar')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
# (, ) ->
def plot_curves_2d(
data: jtyp.Float32[jtyp.Array, 'x_size categories'], info, ax: mpl_ax.Axis
) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_names[0]
y_unit = info.output_units[y_name]
for category in range(data.shape[1]):
ax.plot(data[:, 0], data[:, 1])
ax.set_title('2D Curves')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
ax.legend()
def plot_filled_curves_2d(
data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis
) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_names[0]
y_unit = info.output_units[y_name]
ax.fill_between(info.dim_arrays[0], data[:, 0], info.dim_arrays[0], data[:, 1])
ax.set_title('2D Curves')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
# (, ) ->
def plot_heatmap_2d(
data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis
) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.dim_names[1]
y_unit = info.dim_units[y_name]
heatmap = ax.imshow(data, aspect='auto', interpolation='none')
ax.figure.colorbar(heatmap, ax=ax)
ax.set_title('Heatmap')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))