From dfeb65feec1d567dffa13ac6cb7fc5dd81412621 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sofus=20Albert=20H=C3=B8gsbro=20Rose?= Date: Wed, 17 Apr 2024 16:03:15 +0200 Subject: [PATCH] feat: Math nodes (non-working) --- pyproject.toml | 6 +- requirements-dev.lock | 8 + requirements.lock | 8 + .../maxwell_sim_nodes/contracts/data_flows.py | 316 ++++++++---------- .../contracts/socket_colors.py | 1 + .../contracts/socket_shapes.py | 1 + .../contracts/socket_types.py | 1 + .../managed_objs/managed_bl_image.py | 28 +- .../nodes/analysis/__init__.py | 4 +- .../nodes/analysis/extract_data.py | 24 +- .../nodes/analysis/math/__init__.py | 14 + .../nodes/analysis/math/filter_math.py | 121 +++++++ .../nodes/analysis/math/map_math.py | 164 +++++++++ .../nodes/analysis/math/operate_math.py | 138 ++++++++ .../nodes/analysis/math/reduce_math.py | 135 ++++++++ .../maxwell_sim_nodes/nodes/analysis/viz.py | 14 +- .../maxwell_sim_nodes/sockets/base.py | 65 +++- .../sockets/basic/__init__.py | 6 +- .../maxwell_sim_nodes/sockets/basic/expr.py | 80 +++++ .../maxwell_sim_nodes/sockets/basic/string.py | 3 + src/blender_maxwell/utils/jarray.py | 63 ++++ 21 files changed, 974 insertions(+), 226 deletions(-) create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/reduce_math.py create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/expr.py create mode 100644 src/blender_maxwell/utils/jarray.py diff --git a/pyproject.toml b/pyproject.toml index 1911238..af27e37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,10 @@ dependencies = [ "networkx==3.2.*", "rich==12.5.*", "rtree==1.2.*", + "jax[cpu]==0.4.26", + "msgspec[toml]==0.18.6", + "numba==0.59.1", + "jaxtyping==0.2.28", # Pin Blender 4.1.0-Compatible Versions ## The dependency resolver will report if anything is wonky. "urllib3==1.26.8", @@ -22,8 +26,6 @@ dependencies = [ "idna==3.3", "charset-normalizer==2.0.10", "certifi==2021.10.8", - "jax[cpu]>=0.4.26", - "msgspec[toml]>=0.18.6", ] readme = "README.md" requires-python = "~= 3.11" diff --git a/requirements-dev.lock b/requirements-dev.lock index dca0820..057e4f6 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -49,11 +49,14 @@ importlib-metadata==6.11.0 jax==0.4.26 jaxlib==0.4.26 # via jax +jaxtyping==0.2.28 jmespath==1.0.1 # via boto3 # via botocore kiwisolver==1.4.5 # via matplotlib +llvmlite==0.42.0 + # via numba locket==1.0.0 # via partd matplotlib==3.8.3 @@ -65,13 +68,16 @@ mpmath==1.3.0 # via sympy msgspec==0.18.6 networkx==3.2 +numba==0.59.1 numpy==1.24.3 # via contourpy # via h5py # via jax # via jaxlib + # via jaxtyping # via matplotlib # via ml-dtypes + # via numba # via opt-einsum # via scipy # via shapely @@ -142,6 +148,8 @@ toolz==0.12.1 # via dask # via partd trimesh==4.2.0 +typeguard==2.13.3 + # via jaxtyping types-pyyaml==6.0.12.20240311 # via responses typing-extensions==4.10.0 diff --git a/requirements.lock b/requirements.lock index 367c92c..7ace541 100644 --- a/requirements.lock +++ b/requirements.lock @@ -48,11 +48,14 @@ importlib-metadata==6.11.0 jax==0.4.26 jaxlib==0.4.26 # via jax +jaxtyping==0.2.28 jmespath==1.0.1 # via boto3 # via botocore kiwisolver==1.4.5 # via matplotlib +llvmlite==0.42.0 + # via numba locket==1.0.0 # via partd matplotlib==3.8.3 @@ -64,13 +67,16 @@ mpmath==1.3.0 # via sympy msgspec==0.18.6 networkx==3.2 +numba==0.59.1 numpy==1.24.3 # via contourpy # via h5py # via jax # via jaxlib + # via jaxtyping # via matplotlib # via ml-dtypes + # via numba # via opt-einsum # via scipy # via shapely @@ -140,6 +146,8 @@ toolz==0.12.1 # via dask # via partd trimesh==4.2.0 +typeguard==2.13.3 + # via jaxtyping types-pyyaml==6.0.12.20240311 # via responses typing-extensions==4.10.0 diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/data_flows.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/data_flows.py index d912cc6..8b96029 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/data_flows.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/data_flows.py @@ -4,8 +4,9 @@ import functools import typing as typ from types import MappingProxyType -# import colour ## TODO -import numpy as np +import jax +import jax.numpy as jnp +import numba import sympy as sp import sympy.physics.units as spu import typing_extensions as typx @@ -15,66 +16,46 @@ from ....utils import sci_constants as constants from .socket_types import SocketType -class DataFlowKind(enum.StrEnum): - """Defines a shape/kind of data that may flow through a node tree. +class FlowKind(enum.StrEnum): + """Defines a kind of data that can flow between nodes. - Since a node socket may define one of each, we can support several related kinds of data flow through the same node-graph infrastructure. + Each node link can be thought to contain **multiple pipelines for data to flow along**. + Each pipeline is cached incrementally, and independently, of the others. + Thus, the same socket can easily support several kinds of related data flow at the same time. Attributes: - Value: A value without any unknown symbols. - - Basic types aka. float, int, list, string, etc. . - - Exotic (immutable-ish) types aka. numpy array, KDTree, etc. . - - A usable constructed object, ex. a `tidy3d.Box`. - - Expressions (`sp.Expr`) that don't have unknown variables. - - Lazy sequences aka. generators, with all data bound. - SpectralValue: A value defined along a spectral range. - - {`np.array` - - LazyValue: An object which, when given new data, can make many values. - - An `sp.Expr`, which might need `simplify`ing, `jax` JIT'ing, unit cancellations, variable substitutions, etc. before use. - - Lazy objects, for which all parameters aren't yet known. - - A computational graph aka. `aesara`, which may even need to be handled before - - Capabilities: A `ValueCapability` object providing compatibility. - - # Value Data Flow - Simply passing values is the simplest and easiest use case. - - This doesn't mean it's "dumb" - ex. a `sp.Expr` might, before use, have `simplify`, rewriting, unit cancellation, etc. run. - All of this is okay, as long as there is no *introduction of new data* ex. variable substitutions. - - - # Lazy Value Data Flow - By passing (essentially) functions, one supports: - - **Lightness**: While lazy values can be made expensive to construct, they will generally not be nearly as heavy to handle when trying to work with ex. operations on voxel arrays. - - **Performance**: Parameterizing ex. `sp.Expr` with variables allows one to build very optimized functions, which can make ex. node graph updates very fast if the only operation run is the `jax` JIT'ed function (aka. GPU accelerated) generated from the final full expression. - - **Numerical Stability**: Libraries like `aesara` build a computational graph, which can be automatically rewritten to avoid many obvious conditioning / cancellation errors. - - **Lazy Output**: The goal of a node-graph may not be the definition of a single value, but rather, a parameterized expression for generating *many values* with known properties. This is especially interesting for use cases where one wishes to build an optimization step using nodes. - - - # Capability Passing - By being able to pass "capabilities" next to other kinds of values, nodes can quickly determine whether a given link is valid without having to actually compute it. - - - # Lazy Parameter Value - When using parameterized LazyValues, one may wish to independently pass parameter values through the graph, so they can be inserted into the final (cached) high-performance expression without. - - The advantage of using a different data flow would be changing this kind of value would ONLY invalidate lazy parameter value caches, which would allow an incredibly fast path of getting the value into the lazy expression for high-performance computation. - - Implementation TBD - though, ostensibly, one would have a "parameter" node which both would only provide a LazyValue (aka. a symbolic variable), but would also be able to provide a LazyParamValue, which would be a particular value of some kind (probably via the `value` of some other node socket). + Capabilities: Describes a socket's linkeability with other sockets. + Links between sockets with incompatible capabilities will be rejected. + This doesn't need to be defined normally, as there is a default. + However, in some cases, defining it manually to control linkeability more granularly may be desirable. + Value: A generic object, which is "directly usable". + This should be chosen when a more specific flow kind doesn't apply. + Array: An object with dimensions, and possibly a unit. + Whenever a `Value` is defined, a single-element `list` will also be generated by default as `Array` + However, for any other array-like variants (or sockets that only represent array-like objects), `Array` should be defined manually. + LazyValueFunc: A composable function. + Can be used to represent computations for which all data is not yet known, or for which just-in-time compilation can drastically increase performance. + LazyArrayRange: An object that generates an `Array` from range information (start/stop/step/spacing). + This should be used instead of `Array` whenever possible. + Param: An object providing data to complete `Lazy` data. + For example, + Info: An object providing context about other flows. + For example, """ Capabilities = enum.auto() # Values Value = enum.auto() - ValueArray = enum.auto() - ValueSpectrum = enum.auto() + Array = enum.auto() # Lazy LazyValue = enum.auto() - LazyValueRange = enum.auto() - LazyValueSpectrum = enum.auto() + LazyArrayRange = enum.auto() + + # Auxiliary + Param = enum.auto() + Info = enum.auto() @classmethod def scale_to_unit_system(cls, kind: typ.Self, value, socket_type, unit_system): @@ -85,7 +66,7 @@ class DataFlowKind(enum.StrEnum): unit_system[socket_type], ) ) - if kind == cls.LazyValueRange: + if kind == cls.LazyArrayRange: return value.rescale_to_unit(unit_system[socket_type]) msg = 'Tried to scale unknown kind' @@ -93,12 +74,12 @@ class DataFlowKind(enum.StrEnum): #################### -# - Data Structures: Capabilities +# - Capabilities #################### @dataclasses.dataclass(frozen=True, kw_only=True) -class DataCapabilities: +class CapabilitiesFlow: socket_type: SocketType - active_kind: DataFlowKind + active_kind: FlowKind is_universal: bool = False @@ -110,13 +91,16 @@ class DataCapabilities: #################### -# - Data Structures: Non-Lazy +# - Value #################### -DataValue: typ.TypeAlias = typ.Any +ValueFlow: typ.TypeAlias = typ.Any +#################### +# - Value Array +#################### @dataclasses.dataclass(frozen=True, kw_only=True) -class DataValueArray: +class ArrayFlow: """A simple, flat array of values with an optionally-attached unit. Attributes: @@ -125,69 +109,105 @@ class DataValueArray: None if unitless. """ - values: typ.Sequence[DataValue] + values: jax.Array unit: spu.Quantity | None +#################### +# - Lazy Value Func +#################### +LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], ValueFlow] + + @dataclasses.dataclass(frozen=True, kw_only=True) -class DataValueSpectrum: - """A numerical representation of a spectral distribution. +class LazyValueFuncFlow: + r"""Encapsulates a lazily evaluated data value as a composable function with bound and free arguments. + + - **Bound Args**: Arguments that are realized when **defining** the lazy value. + Both positional values and keyword values are supported. + - **Free Args**: Arguments that are specified when evaluating the lazy value. + Both positional values and keyword values are supported. + + The **root function** is encapsulated using `from_function`, and must accept arguments in the following order: + + $$ + f_0:\ \ \ \ (\underbrace{b_1, b_2, ...}_{\text{Bound}}\ ,\ \underbrace{r_1, r_2, ...}_{\text{Free}}) \to \text{output}_0 + $$ + + Subsequent **composed functions** are encapsulated from the _root function_, and are created with `root_function.compose`. + They must accept arguments in the following order: + + $$ + f_k:\ \ \ \ (\underbrace{b_1, b_2, ...}_{\text{Bound}}\ ,\ \text{output}_{k-1} ,\ \underbrace{r_p, r_{p+1}, ...}_{\text{Free}}) \to \text{output}_k + $$ Attributes: - wls: A 1D `numpy` float array of wavelength values. - wls_unit: The unit of wavelengths, as length dimension. - values: A 1D `numpy` float array of values corresponding to wavelength values. - values_unit: The unit of the value, as arbitrary dimension. - freqs_unit: The unit of the value, as arbitrary dimension. + function: The function to be lazily evaluated. + bound_args: Arguments that will be packaged into function, which can't be later modifier. + func_kwargs: Arguments to be specified by the user at the time of use. + supports_jax: Whether the contained `self.function` can be compiled with JAX's JIT compiler. + supports_numba: Whether the contained `self.function` can be compiled with Numba's JIT compiler. """ - # Wavelength - wls: np.array - wls_unit: spu.Quantity + func: LazyFunction + func_kwargs: dict[str, type] + supports_jax: bool = False + supports_numba: bool = False - # Value - values: np.array - values_unit: spu.Quantity + @staticmethod + def from_func( + func: LazyFunction, + supports_jax: bool = False, + supports_numba: bool = False, + **func_kwargs: dict[str, type], + ) -> typ.Self: + return LazyValueFuncFlow( + func=func, + func_kwargs=func_kwargs, + supports_jax=supports_jax, + supports_numba=supports_numba, + ) - # Frequency - freqs_unit: spu.Quantity = spu.hertz + # Composition + def compose_within( + self, + enclosing_func: LazyFunction, + supports_jax: bool = False, + supports_numba: bool = False, + **enclosing_func_kwargs: dict[str, type], + ) -> typ.Self: + return LazyValueFuncFlow( + function=lambda **kwargs: enclosing_func( + self.func(**{k: v for k, v in kwargs if k in self.func_kwargs}), + **kwargs, + ), + func_kwargs=self.func_kwargs | enclosing_func_kwargs, + supports_jax=self.supports_jax and supports_jax, + supports_numba=self.supports_numba and supports_numba, + ) @functools.cached_property - def freqs(self) -> np.array: - """The spectral frequencies, computed from the wavelengths. + def func_jax(self) -> LazyFunction: + if self.supports_jax: + return jax.jit(self.func) - Frequencies are NOT reversed, so as to preserve the by-index mapping to `DataValueSpectrum.values`. + msg = 'Can\'t express LazyValueFuncFlow as JAX function (using jax.jit), since "self.supports_jax" is False' + raise ValueError(msg) - Returns: - Frequencies, as a unitless `numpy` array. - Use `DataValueSpectrum.wls_unit` to interpret this return value. - """ - unitless_speed_of_light = spux.sympy_to_python( - spux.scale_to_unit( - constants.vac_speed_of_light, (self.wl_unit / self.freq_unit) - ) - ) - return unitless_speed_of_light / self.wls + @functools.cached_property + def func_numba(self) -> LazyFunction: + if self.supports_numba: + return numba.jit(self.func) - # TODO: Colour Library - # def as_colour_sd(self) -> colour.SpectralDistribution: - # """Returns the `colour` representation of this spectral distribution, ideal for plotting and colorimetric analysis.""" - # return colour.SpectralDistribution(data=self.values, domain=self.wls) + msg = 'Can\'t express LazyValueFuncFlow as Numba function (using numba.jit), since "self.supports_numba" is False' + raise ValueError(msg) #################### -# - Data Structures: Lazy +# - Lazy Array Range #################### @dataclasses.dataclass(frozen=True, kw_only=True) -class LazyDataValue: - callback: typ.Callable[[...], [DataValue]] - - def realize(self, *args: list[DataValue]) -> DataValue: - return self.callback(*args) - - -@dataclasses.dataclass(frozen=True, kw_only=True) -class LazyDataValueRange: +class LazyArrayRangeFlow: symbols: set[sp.Symbol] start: sp.Basic @@ -200,7 +220,7 @@ class LazyDataValueRange: def rescale_to_unit(self, unit: spu.Quantity) -> typ.Self: if self.has_unit: - return LazyDataValueRange( + return LazyArrayRangeFlow( symbols=self.symbols, has_unit=self.has_unit, unit=unit, @@ -219,7 +239,7 @@ class LazyDataValueRange: reverse: bool = False, ) -> typ.Self: """Call a function on both bounds (start and stop), creating a new `LazyDataValueRange`.""" - return LazyDataValueRange( + return LazyArrayRangeFlow( symbols=self.symbols, has_unit=self.has_unit, unit=self.unit, @@ -234,8 +254,8 @@ class LazyDataValueRange: ) def realize( - self, symbol_values: dict[sp.Symbol, DataValue] = MappingProxyType({}) - ) -> DataValueArray: + self, symbol_values: dict[sp.Symbol, ValueFlow] = MappingProxyType({}) + ) -> ArrayFlow: # Realize Symbols if not self.has_unit: start = spux.sympy_to_python(self.start.subs(symbol_values)) @@ -250,85 +270,25 @@ class LazyDataValueRange: # Return Linspace / Logspace if self.scaling == 'lin': - return DataValueArray( - values=np.linspace(start, stop, self.steps), unit=self.unit + return ArrayFlow( + values=jnp.linspace(start, stop, self.steps), unit=self.unit ) if self.scaling == 'geom': - return DataValueArray(np.geomspace(start, stop, self.steps), self.unit) + return ArrayFlow(jnp.geomspace(start, stop, self.steps), self.unit) if self.scaling == 'log': - return DataValueArray(np.logspace(start, stop, self.steps), self.unit) + return ArrayFlow(jnp.logspace(start, stop, self.steps), self.unit) - raise NotImplementedError + msg = f'ArrayFlow scaling method {self.scaling} is unsupported' + raise RuntimeError(msg) -@dataclasses.dataclass(frozen=True, kw_only=True) -class LazyDataValueSpectrum: - wl_unit: spu.Quantity - value_unit: spu.Quantity - value_expr: sp.Expr - - symbols: tuple[sp.Symbol, ...] = () - freq_symbol: sp.Symbol = sp.Symbol('lamda') # noqa: RUF009 - - def rescale_to_unit(self, unit: spu.Quantity) -> typ.Self: - raise NotImplementedError - - @functools.cached_property - def as_func(self) -> typ.Callable[[DataValue, ...], DataValue]: - """Generates an optimized function for numerical evaluation of the spectral expression.""" - return sp.lambdify([self.freq_symbol, *self.symbols], self.value_expr) - - def realize( - self, wl_range: DataValueArray, symbol_values: tuple[DataValue, ...] - ) -> DataValueSpectrum: - r"""Realizes the parameterized spectral function as a numerical spectral distribution. - - Parameters: - wl_range: The lazy wavelength range to build the concrete spectral distribution with. - symbol_values: Numerical values for each symbol, in the same order as defined in `LazyDataValueSpectrum.symbols`. - The wavelength symbol ($\lambda$ by default) always goes first. - _This is used to call the spectral function using the output of `.as_func()`._ - - Returns: - The concrete, numerical spectral distribution. - """ - return DataValueSpectrum( - wls=wl_range.values, - wls_unit=self.wl_unit, - values=self.as_func(*list(symbol_values.values())), - values_unit=self.value_unit, - ) +#################### +# - Param +#################### +ParamFlow: typ.TypeAlias = dict[str, typ.Any] -# -# -##################### -## - Data Pipeline -##################### -# @dataclasses.dataclass(frozen=True, kw_only=True) -# class DataPipelineDim: -# unit: spu.Quantity | None -# -# class DataPipelineDimType(enum.StrEnum): -# # Map Inputs -# Time = enum.auto() -# Freq = enum.auto() -# Space3D = enum.auto() -# DiffOrder = enum.auto() -# -# # Map Inputs -# Power = enum.auto() -# EVec = enum.auto() -# HVec = enum.auto() -# RelPerm = enum.auto() -# -# -# @dataclasses.dataclass(frozen=True, kw_only=True) -# class LazyDataPipeline: -# dims: list[DataPipelineDim] -# -# def _callable(self): -# """JITs the current pipeline of functions with `jax`.""" -# -# def __call__(self): -# pass +#################### +# - Lazy Value Func +#################### +InfoFlow: typ.TypeAlias = dict[str, typ.Any] diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_colors.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_colors.py index dab36ea..f85e511 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_colors.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_colors.py @@ -7,6 +7,7 @@ SOCKET_COLORS = { ST.Bool: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey ST.String: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey ST.FilePath: (0.6, 0.6, 0.6, 1.0), # Medium Grey + ST.Expr: (0.5, 0.5, 0.5, 1.0), # Medium Grey # Number ST.IntegerNumber: (0.5, 0.5, 1.0, 1.0), # Light Blue ST.RationalNumber: (0.4, 0.4, 0.9, 1.0), # Medium Light Blue diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_shapes.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_shapes.py index 15f39e1..a2f10a0 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_shapes.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_shapes.py @@ -6,6 +6,7 @@ SOCKET_SHAPES = { ST.Bool: 'CIRCLE', ST.String: 'CIRCLE', ST.FilePath: 'CIRCLE', + ST.Expr: 'CIRCLE', # Number ST.IntegerNumber: 'CIRCLE', ST.RationalNumber: 'CIRCLE', diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_types.py index 5225e45..82f490e 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_types.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_types.py @@ -14,6 +14,7 @@ class SocketType(BlenderTypeEnum): String = enum.auto() FilePath = enum.auto() Color = enum.auto() + Expr = enum.auto() # Number IntegerNumber = enum.auto() diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py index e9d6141..2bea3fd 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py @@ -38,8 +38,8 @@ def apply_colormap(normalized_data, colormap): @jax.jit -def rgba_image_from_xyzf__viridis(xyz_freq): - amplitude = jnp.abs(jnp.squeeze(xyz_freq)) +def rgba_image_from_2d_map__viridis(map_2d): + amplitude = jnp.abs(map_2d) amplitude_normalized = (amplitude - amplitude.min()) / ( amplitude.max() - amplitude.min() ) @@ -49,8 +49,8 @@ def rgba_image_from_xyzf__viridis(xyz_freq): @jax.jit -def rgba_image_from_xyzf__grayscale(xyz_freq): - amplitude = jnp.abs(jnp.squeeze(xyz_freq)) +def rgba_image_from_2d_map__grayscale(map_2d): + amplitude = jnp.abs(map_2d) amplitude_normalized = (amplitude - amplitude.min()) / ( amplitude.max() - amplitude.min() ) @@ -59,21 +59,19 @@ def rgba_image_from_xyzf__grayscale(xyz_freq): return jnp.dstack((rgb_array, alpha_channel)) -def rgba_image_from_xyzf(xyz_freq, colormap: str | None = None): - """RGBA Image from Squeezable XYZ-Freq w/fixed freq. +def rgba_image_from_2d_map(map_2d, colormap: str | None = None): + """RGBA Image from a map of 2D coordinates to values. Parameters: - xyz_freq: Shape (xlen, ylen, zlen), one dimension has length 1. - width_px: Pixel width to resize the image to. - height: Pixel height to resize the image to. + map_2d: Shape (width, height, value). Returns: - Image as a JAX array of shape (height, width, 3) + Image as a JAX array of shape (height, width, 4) """ if colormap == 'VIRIDIS': - return rgba_image_from_xyzf__viridis(xyz_freq) + return rgba_image_from_2d_map__viridis(map_2d) if colormap == 'GRAYSCALE': - return rgba_image_from_xyzf__grayscale(xyz_freq) + return rgba_image_from_2d_map__grayscale(map_2d) class ManagedBLImage(base.ManagedObj): @@ -227,11 +225,11 @@ class ManagedBLImage(base.ManagedObj): #################### # - Special Methods #################### - def xyzf_to_image( - self, xyz_freq, colormap: str | None = 'VIRIDIS', bl_select: bool = False + def map_2d_to_image( + self, map_2d, colormap: str | None = 'VIRIDIS', bl_select: bool = False ): self.data_to_image( - lambda _: rgba_image_from_xyzf(xyz_freq, colormap=colormap), + lambda _: rgba_image_from_2d_map(map_2d, colormap=colormap), bl_select=bl_select, ) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/__init__.py index 0c6cd6e..17bb0d3 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/__init__.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/__init__.py @@ -1,10 +1,12 @@ -from . import extract_data, viz +from . import extract_data, viz, math BL_REGISTER = [ *extract_data.BL_REGISTER, *viz.BL_REGISTER, + *math.BL_REGISTER, ] BL_NODES = { **extract_data.BL_NODES, **viz.BL_NODES, + **math.BL_NODES, } diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py index ed0a5a8..8b27c8e 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py @@ -1,8 +1,11 @@ import typing as typ import bpy +import jax.numpy as jnp +import sympy.physics.units as spu + +from blender_maxwell.utils import jarray, logger -from .....utils import logger from ... import contracts as ct from ... import sockets from .. import base, events @@ -229,8 +232,10 @@ class ExtractDataNode(base.MaxwellSimNode): @events.computes_output_socket( 'Data', props={'sim_data__monitor_name', 'field_data__component'}, + input_sockets={'Field Data'}, + input_sockets_optional={'Field Data': True}, ) - def compute_extracted_data(self, props: dict): + def compute_extracted_data(self, props: dict, input_sockets: dict): if self.active_socket_set == 'Sim Data': if ( CACHE_SIM_DATA.get(self.instance_id) is None @@ -242,12 +247,21 @@ class ExtractDataNode(base.MaxwellSimNode): return sim_data.monitor_data[props['sim_data__monitor_name']] elif self.active_socket_set == 'Field Data': # noqa: RET505 - field_data = self._compute_input('Field Data') - return getattr(field_data, props['field_data__component']) + xarr = getattr(input_sockets['Field Data'], props['field_data__component']) + + return jarray.JArray.from_xarray( + xarr, + dim_units={ + 'x': spu.um, + 'y': spu.um, + 'z': spu.um, + 'f': spu.hertz, + }, + ) elif self.active_socket_set == 'Flux Data': flux_data = self._compute_input('Flux Data') - return flux_data.flux + return jnp.array(flux_data.flux) msg = f'Tried to get data from unknown output socket in "{self.bl_label}"' raise RuntimeError(msg) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py new file mode 100644 index 0000000..63f7ce3 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py @@ -0,0 +1,14 @@ +from . import map_math, filter_math, reduce_math, operate_math + +BL_REGISTER = [ + *map_math.BL_REGISTER, + *filter_math.BL_REGISTER, + *reduce_math.BL_REGISTER, + *operate_math.BL_REGISTER, +] +BL_NODES = { + **map_math.BL_NODES, + **filter_math.BL_NODES, + **reduce_math.BL_NODES, + **operate_math.BL_NODES, +} diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py new file mode 100644 index 0000000..3867e67 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py @@ -0,0 +1,121 @@ +import functools +import typing as typ + +import bpy +import jax +import jax.numpy as jnp + +from blender_maxwell.utils import logger + +from .... import contracts as ct +from .... import sockets +from ... import base, events + +log = logger.get(__name__) + + +# @functools.partial(jax.jit, static_argnames=('fixed_axis', 'fixed_axis_value')) +# jax.jit +def fix_axis(data, fixed_axis: int, fixed_axis_value: float): + log.critical(data.shape) + # Select Values of Fixed Axis + fixed_axis_values = data[ + tuple(slice(None) if i == fixed_axis else 0 for i in range(data.ndim)) + ] + log.critical(fixed_axis_values) + + # Compute Nearest Index on Fixed Axis + idx_of_nearest = jnp.argmin(jnp.abs(fixed_axis_values - fixed_axis_value)) + log.critical(idx_of_nearest) + + # Select Values along Fixed Axis Value + return jnp.take(data, idx_of_nearest, axis=fixed_axis) + + +class FilterMathNode(base.MaxwellSimNode): + node_type = ct.NodeType.FilterMath + bl_label = 'Filter Math' + + input_sockets: typ.ClassVar = { + 'Data': sockets.AnySocketDef(), + } + input_socket_sets: typ.ClassVar = { + 'By Axis Value': { + 'Axis': sockets.IntegerNumberSocketDef(), + 'Value': sockets.RealNumberSocketDef(), + }, + 'By Axis': { + 'Axis': sockets.IntegerNumberSocketDef(), + }, + ## TODO: bool arrays for comparison/switching/sparse 0-setting/etc. . + } + output_sockets: typ.ClassVar = { + 'Data': sockets.AnySocketDef(), + } + + #################### + # - Properties + #################### + operation: bpy.props.EnumProperty( + name='Op', + description='Operation to reduce the input axis with', + items=lambda self, _: self.search_operations(), + update=lambda self, context: self.sync_prop('operation', context), + ) + + def search_operations(self) -> list[tuple[str, str, str]]: + items = [] + if self.active_socket_set == 'By Axis Value': + items += [ + ('FIX', 'Fix Coordinate', '(*, N, *) -> (*, *)'), + ] + if self.active_socket_set == 'By Axis': + items += [ + ('SQUEEZE', 'Squeeze', '(*, 1, *) -> (*, *)'), + ] + else: + items += [('NONE', 'None', 'No operations...')] + + return items + + def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: + if self.active_socket_set != 'Axis Expr': + layout.prop(self, 'operation') + + #################### + # - Compute + #################### + @events.computes_output_socket( + 'Data', + props={'operation', 'active_socket_set'}, + input_sockets={'Data', 'Axis', 'Value'}, + input_sockets_optional={'Axis': True, 'Value': True}, + ) + def compute_data(self, props: dict, input_sockets: dict): + if not hasattr(input_sockets['Data'], 'shape'): + msg = 'Input socket "Data" must be an N-D Array (with a "shape" attribute)' + raise ValueError(msg) + + # By Axis Value + if props['active_socket_set'] == 'By Axis Value': + if props['operation'] == 'FIX': + return fix_axis( + input_sockets['Data'], input_sockets['Axis'], input_sockets['Value'] + ) + + # By Axis + if props['active_socket_set'] == 'By Axis': + if props['operation'] == 'SQUEEZE': + return jnp.squeeze(input_sockets['Data'], axis=input_sockets['Axis']) + + msg = 'Operation invalid' + raise ValueError(msg) + + +#################### +# - Blender Registration +#################### +BL_REGISTER = [ + FilterMathNode, +] +BL_NODES = {ct.NodeType.FilterMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)} diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py new file mode 100644 index 0000000..fd1d6a2 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py @@ -0,0 +1,164 @@ +import typing as typ + +import bpy +import jax +import jax.numpy as jnp +import sympy as sp + +from blender_maxwell.utils import logger + +from .... import contracts as ct +from .... import sockets +from ... import base, events + +log = logger.get(__name__) + + +class MapMathNode(base.MaxwellSimNode): + node_type = ct.NodeType.MapMath + bl_label = 'Map Math' + + input_sockets: typ.ClassVar = { + 'Data': sockets.AnySocketDef(), + } + input_socket_sets: typ.ClassVar = { + 'By Element': {}, + 'By Vector': {}, + 'By Matrix': {}, + 'Expr': { + 'Mapper': sockets.ExprSocketDef( + symbols=[sp.Symbol('x')], + default_expr=sp.Symbol('x'), + ), + }, + } + output_sockets: typ.ClassVar = { + 'Data': sockets.AnySocketDef(), + } + + #################### + # - Properties + #################### + operation: bpy.props.EnumProperty( + name='Op', + description='Operation to apply to the input', + items=lambda self, _: self.search_operations(), + update=lambda self, context: self.sync_prop('operation', context), + ) + + def search_operations(self) -> list[tuple[str, str, str]]: + items = [] + if self.active_socket_set == 'By Element': + items += [ + # General + ('REAL', 'real', 'ℝ(L) (by el)'), + ('IMAG', 'imag', 'Im(L) (by el)'), + ('ABS', 'abs', '|L| (by el)'), + ('SQ', 'square', 'L^2 (by el)'), + ('SQRT', 'sqrt', 'sqrt(L) (by el)'), + ('INV_SQRT', '1/sqrt', '1/sqrt(L) (by el)'), + # Trigonometry + ('COS', 'cos', 'cos(L) (by el)'), + ('SIN', 'sin', 'sin(L) (by el)'), + ('TAN', 'tan', 'tan(L) (by el)'), + ('ACOS', 'acos', 'acos(L) (by el)'), + ('ASIN', 'asin', 'asin(L) (by el)'), + ('ATAN', 'atan', 'atan(L) (by el)'), + ] + elif self.active_socket_set in 'By Vector': + items += [ + ('NORM_2', '2-Norm', '||L||_2 (by Vec)'), + ] + elif self.active_socket_set == 'By Matrix': + items += [ + # Matrix -> Number + ('DET', 'Determinant', 'det(L) (by Mat)'), + ('COND', 'Condition', 'κ(L) (by Mat)'), + ('NORM_FRO', 'Frobenius Norm', '||L||_F (by Mat)'), + ('RANK', 'Rank', 'rank(L) (by Mat)'), + # Matrix -> Array + ('DIAG', 'Diagonal', 'diag(L) (by Mat)'), + ('EIG_VALS', 'Eigenvalues', 'eigvals(L) (by Mat)'), + ('SVD_VALS', 'SVD', 'svd(L) -> diag(Σ) (by Mat)'), + # Matrix -> Matrix + ('INV', 'Invert', 'L^(-1) (by Mat)'), + ('TRA', 'Transpose', 'L^T (by Mat)'), + # Matrix -> Matrices + ('QR', 'QR', 'L -> Q·R (by Mat)'), + ('CHOL', 'Cholesky', 'L -> L·Lh (by Mat)'), + ('SVD', 'SVD', 'L -> U·Σ·Vh (by Mat)'), + ] + else: + items += ['EXPR_EL', 'Expr (by el)', 'Expression-defined (by el)'] + return items + + def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: + if self.active_socket_set not in {'Expr (Element)'}: + layout.prop(self, 'operation') + + #################### + # - Compute + #################### + @events.computes_output_socket( + 'Data', + props={'active_socket_set', 'operation'}, + input_sockets={'Data', 'Mapper'}, + input_socket_kinds={'Mapper': ct.DataFlowKind.LazyValue}, + input_sockets_optional={'Mapper': True}, + ) + def compute_data(self, props: dict, input_sockets: dict): + mapping_func: typ.Callable[[jax.Array], jax.Array] = { + 'By Element': { + 'REAL': lambda data: jnp.real(data), + 'IMAG': lambda data: jnp.imag(data), + 'ABS': lambda data: jnp.abs(data), + 'SQ': lambda data: jnp.square(data), + 'SQRT': lambda data: jnp.sqrt(data), + 'INV_SQRT': lambda data: 1 / jnp.sqrt(data), + 'COS': lambda data: jnp.cos(data), + 'SIN': lambda data: jnp.sin(data), + 'TAN': lambda data: jnp.tan(data), + 'ACOS': lambda data: jnp.acos(data), + 'ASIN': lambda data: jnp.asin(data), + 'ATAN': lambda data: jnp.atan(data), + 'SINC': lambda data: jnp.sinc(data), + }, + 'By Vector': { + 'NORM_2': lambda data: jnp.norm(data, ord=2, axis=-1), + }, + 'By Matrix': { + # Matrix -> Number + 'DET': lambda data: jnp.linalg.det(data), + 'COND': lambda data: jnp.linalg.cond(data), + 'NORM_FRO': lambda data: jnp.linalg.matrix_norm(data, ord='fro'), + 'RANK': lambda data: jnp.linalg.matrix_rank(data), + # Matrix -> Vec + 'DIAG': lambda data: jnp.diag(data), + 'EIG_VALS': lambda data: jnp.eigvals(data), + 'SVD_VALS': lambda data: jnp.svdvals(data), + # Matrix -> Matrix + 'INV': lambda data: jnp.inv(data), + 'TRA': lambda data: jnp.matrix_transpose(data), + # Matrix -> Matrices + 'QR': lambda data: jnp.inv(data), + 'CHOL': lambda data: jnp.linalg.cholesky(data), + 'SVD': lambda data: jnp.linalg.svd(data), + }, + 'By El (Expr)': { + 'EXPR_EL': lambda data: input_sockets['Mapper'](data), + }, + }[props['active_socket_set']][props['operation']] + + # Compose w/Lazy Root Function Data + return input_sockets['Data'].compose( + function=mapping_func, + ) + + +#################### +# - Blender Registration +#################### +BL_REGISTER = [ + MapMathNode, +] +BL_NODES = {ct.NodeType.MapMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)} diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py new file mode 100644 index 0000000..4b26d63 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py @@ -0,0 +1,138 @@ +import typing as typ + +import bpy +import jax.numpy as jnp + +from blender_maxwell.utils import logger + +from .... import contracts as ct +from .... import sockets +from ... import base, events + +log = logger.get(__name__) + + +class OperateMathNode(base.MaxwellSimNode): + node_type = ct.NodeType.OperateMath + bl_label = 'Operate Math' + + input_socket_sets: typ.ClassVar = { + 'Elementwise': { + 'Data L': sockets.AnySocketDef(), + 'Data R': sockets.AnySocketDef(), + }, + ## TODO: Filter-array building operations + 'Vec-Vec': { + 'Data L': sockets.AnySocketDef(), + 'Data R': sockets.AnySocketDef(), + }, + 'Mat-Vec': { + 'Data L': sockets.AnySocketDef(), + 'Data R': sockets.AnySocketDef(), + }, + } + output_sockets: typ.ClassVar = { + 'Data': sockets.AnySocketDef(), + } + + #################### + # - Properties + #################### + operation: bpy.props.EnumProperty( + name='Op', + description='Operation to apply to the two inputs', + items=lambda self, _: self.search_operations(), + update=lambda self, context: self.sync_prop('operation', context), + ) + + def search_operations(self) -> list[tuple[str, str, str]]: + items = [] + if self.active_socket_set == 'Elementwise': + items = [ + ('ADD', 'Add', 'L + R (by el)'), + ('SUB', 'Subtract', 'L - R (by el)'), + ('MUL', 'Multiply', 'L · R (by el)'), + ('DIV', 'Divide', 'L ÷ R (by el)'), + ('POW', 'Power', 'L^R (by el)'), + ('FMOD', 'Trunc Modulo', 'fmod(L,R) (by el)'), + ('ATAN2', 'atan2', 'atan2(L,R) (by el)'), + ('HEAVISIDE', 'Heaviside', '{0|L<0 1|L>0 R|L=0} (by el)'), + ] + elif self.active_socket_set in 'Vec | Vec': + items = [ + ('DOT', 'Dot', 'L · R'), + ('CROSS', 'Cross', 'L x R (by last-axis'), + ] + elif self.active_socket_set == 'Mat | Vec': + items = [ + ('DOT', 'Dot', 'L · R'), + ('LIN_SOLVE', 'Lin Solve', 'Lx = R -> x (by last-axis of R)'), + ('LSQ_SOLVE', 'LSq Solve', 'Lx = R ~> x (by last-axis of R)'), + ] + return items + + def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: + layout.prop(self, 'operation') + + #################### + # - Properties + #################### + @events.computes_output_socket( + 'Data', + props={'operation'}, + input_sockets={'Data L', 'Data R'}, + ) + def compute_data(self, props: dict, input_sockets: dict): + if self.active_socket_set == 'Elementwise': + # Element-Wise Arithmetic + if props['operation'] == 'ADD': + return input_sockets['Data L'] + input_sockets['Data R'] + if props['operation'] == 'SUB': + return input_sockets['Data L'] - input_sockets['Data R'] + if props['operation'] == 'MUL': + return input_sockets['Data L'] * input_sockets['Data R'] + if props['operation'] == 'DIV': + return input_sockets['Data L'] / input_sockets['Data R'] + + # Element-Wise Arithmetic + if props['operation'] == 'POW': + return input_sockets['Data L'] ** input_sockets['Data R'] + + # Binary Trigonometry + if props['operation'] == 'ATAN2': + return jnp.atan2(input_sockets['Data L'], input_sockets['Data R']) + + # Special Functions + if props['operation'] == 'HEAVISIDE': + return jnp.heaviside(input_sockets['Data L'], input_sockets['Data R']) + + # Linear Algebra + if self.active_socket_set in {'Vec-Vec', 'Mat-Vec'}: + if props['operation'] == 'DOT': + return jnp.dot(input_sockets['Data L'], input_sockets['Data R']) + + elif self.active_socket_set == 'Vec-Vec': + if props['operation'] == 'CROSS': + return jnp.cross(input_sockets['Data L'], input_sockets['Data R']) + + elif self.active_socket_set == 'Mat-Vec': + if props['operation'] == 'LIN_SOLVE': + return jnp.linalg.lstsq( + input_sockets['Data L'], input_sockets['Data R'] + ) + if props['operation'] == 'LSQ_SOLVE': + return jnp.linalg.solve( + input_sockets['Data L'], input_sockets['Data R'] + ) + + msg = 'Invalid operation' + raise ValueError(msg) + + +#################### +# - Blender Registration +#################### +BL_REGISTER = [ + OperateMathNode, +] +BL_NODES = {ct.NodeType.OperateMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)} diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/reduce_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/reduce_math.py new file mode 100644 index 0000000..c4f2283 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/reduce_math.py @@ -0,0 +1,135 @@ +import typing as typ + +import bpy +import jax.numpy as jnp +import sympy as sp + +from blender_maxwell.utils import logger + +from .... import contracts as ct +from .... import sockets +from ... import base, events + +log = logger.get(__name__) + + +class ReduceMathNode(base.MaxwellSimNode): + node_type = ct.NodeType.ReduceMath + bl_label = 'Reduce Math' + + input_sockets: typ.ClassVar = { + 'Data': sockets.AnySocketDef(), + 'Axis': sockets.IntegerNumberSocketDef(), + } + input_socket_sets: typ.ClassVar = { + 'By Axis': { + 'Axis': sockets.IntegerNumberSocketDef(), + }, + 'Expr': { + 'Reducer': sockets.ExprSocketDef( + symbols=[sp.Symbol('a'), sp.Symbol('b')], + default_expr=sp.Symbol('a') + sp.Symbol('b'), + ), + 'Axis': sockets.IntegerNumberSocketDef(), + }, + } + output_sockets: typ.ClassVar = { + 'Data': sockets.AnySocketDef(), + } + + #################### + # - Properties + #################### + operation: bpy.props.EnumProperty( + name='Op', + description='Operation to reduce the input axis with', + items=lambda self, _: self.search_operations(), + update=lambda self, context: self.sync_prop('operation', context), + ) + + def search_operations(self) -> list[tuple[str, str, str]]: + items = [] + if self.active_socket_set == 'By Axis': + items += [ + # Accumulation + ('SUM', 'Sum', 'sum(*, N, *) -> (*, 1, *)'), + ('PROD', 'Prod', 'prod(*, N, *) -> (*, 1, *)'), + ('MIN', 'Axis-Min', '(*, N, *) -> (*, 1, *)'), + ('MAX', 'Axis-Max', '(*, N, *) -> (*, 1, *)'), + ('P2P', 'Peak-to-Peak', '(*, N, *) -> (*, 1 *)'), + # Stats + ('MEAN', 'Mean', 'mean(*, N, *) -> (*, 1, *)'), + ('MEDIAN', 'Median', 'median(*, N, *) -> (*, 1, *)'), + ('STDDEV', 'Std Dev', 'stddev(*, N, *) -> (*, 1, *)'), + ('VARIANCE', 'Variance', 'var(*, N, *) -> (*, 1, *)'), + # Dimension Reduction + ('SQUEEZE', 'Squeeze', '(*, 1, *) -> (*, *)'), + ] + else: + items += [('NONE', 'None', 'No operations...')] + + return items + + def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: + if self.active_socket_set != 'Axis Expr': + layout.prop(self, 'operation') + + #################### + # - Compute + #################### + @events.computes_output_socket( + 'Data', + props={'operation'}, + input_sockets={'Data', 'Axis', 'Reducer'}, + input_socket_kinds={'Reducer': ct.DataFlowKind.LazyValue}, + input_sockets_optional={'Reducer': True}, + ) + def compute_data(self, props: dict, input_sockets: dict): + if not hasattr(input_sockets['Data'], 'shape'): + msg = 'Input socket "Data" must be an N-D Array (with a "shape" attribute)' + raise ValueError(msg) + + if self.active_socket_set == 'Axis Expr': + ufunc = jnp.ufunc(input_sockets['Reducer'], nin=2, nout=1) + return ufunc.reduce(input_sockets['Data'], axis=input_sockets['Axis']) + + if self.active_socket_set == 'By Axis': + ## Dimension Reduction + # ('SQUEEZE', 'Squeeze', '(*, 1, *) -> (*, *)'), + # Accumulation + if props['operation'] == 'SUM': + return jnp.sum(input_sockets['Data'], axis=input_sockets['Axis']) + if props['operation'] == 'PROD': + return jnp.prod(input_sockets['Data'], axis=input_sockets['Axis']) + if props['operation'] == 'MIN': + return jnp.min(input_sockets['Data'], axis=input_sockets['Axis']) + if props['operation'] == 'MAX': + return jnp.max(input_sockets['Data'], axis=input_sockets['Axis']) + if props['operation'] == 'P2P': + return jnp.p2p(input_sockets['Data'], axis=input_sockets['Axis']) + + # Stats + if props['operation'] == 'MEAN': + return jnp.mean(input_sockets['Data'], axis=input_sockets['Axis']) + if props['operation'] == 'MEDIAN': + return jnp.median(input_sockets['Data'], axis=input_sockets['Axis']) + if props['operation'] == 'STDDEV': + return jnp.std(input_sockets['Data'], axis=input_sockets['Axis']) + if props['operation'] == 'VARIANCE': + return jnp.var(input_sockets['Data'], axis=input_sockets['Axis']) + + # Dimension Reduction + if props['operation'] == 'SQUEEZE': + return jnp.squeeze(input_sockets['Data'], axis=input_sockets['Axis']) + + msg = 'Operation invalid' + raise ValueError(msg) + + +#################### +# - Blender Registration +#################### +BL_REGISTER = [ + ReduceMathNode, +] +BL_NODES = {ct.NodeType.ReduceMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)} diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py index ca17501..ef6013b 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py @@ -20,7 +20,7 @@ class VizNode(base.MaxwellSimNode): #################### # - Sockets #################### - input_sockets = { + input_sockets: typ.ClassVar = { 'Data': sockets.AnySocketDef(), 'Freq': sockets.PhysicalFreqSocketDef(), } @@ -72,20 +72,12 @@ class VizNode(base.MaxwellSimNode): props: dict, unit_systems: dict, ): - selected_data = jnp.array( - input_sockets['Data'].sel(f=input_sockets['Freq'], method='nearest') - ) - - managed_objs['plot'].xyzf_to_image( - selected_data, + managed_objs['plot'].map_2d_to_image( + input_sockets['Data'].as_bound_jax_func(), colormap=props['colormap'], bl_select=True, ) - # @events.on_init() - # def on_init(self): - # self.on_changed_inputs() - #################### # - Blender Registration diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py index 8db6141..8006a25 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py @@ -42,6 +42,15 @@ class SocketDef(pyd.BaseModel, abc.ABC): # - SocketDef #################### class MaxwellSimSocket(bpy.types.NodeSocket): + """A specialized Blender socket for nodes in a Maxwell simulation. + + Attributes: + instance_id: A unique ID attached to a particular socket instance. + Guaranteed to be unchanged so long as the socket lives. + Used as a socket-specific cache index. + locked: The lock-state of a particular socket, which determines the socket's user editability + """ + # Fundamentals socket_type: ct.SocketType bl_label: str @@ -73,21 +82,53 @@ class MaxwellSimSocket(bpy.types.NodeSocket): #################### # - Initialization #################### - def __init_subclass__(cls, **kwargs: typ.Any): - super().__init_subclass__(**kwargs) + @classmethod + def set_prop( + cls, + prop_name: str, + prop: bpy.types.Property, + no_update: bool = False, + update_with_name: str | None = None, + **kwargs, + ) -> None: + """Adds a Blender property to a class via `__annotations__`, so it initializes with any subclass. - # Setup Blender ID for Node - if not hasattr(cls, 'socket_type'): - msg = f"Socket class {cls} does not define 'socket_type'" - raise ValueError(msg) - cls.bl_idname = str(cls.socket_type.value) + Notes: + - Blender properties can't be set within `__init_subclass__` simply by adding attributes to the class; they must be added as type annotations. + - Must be called **within** `__init_subclass__`. - # Setup Locked Property for Node - cls.__annotations__['locked'] = bpy.props.BoolProperty( - name='Locked State', - description="The lock-state of a particular socket, which determines the socket's user editability", - default=False, + Parameters: + name: The name of the property to set. + prop: The `bpy.types.Property` to instantiate and attach.. + no_update: Don't attach a `self.sync_prop()` callback to the property's `update`. + """ + _update_with_name = prop_name if update_with_name is None else update_with_name + extra_kwargs = ( + { + 'update': lambda self, context: self.sync_prop( + _update_with_name, context + ), + } + if not no_update + else {} ) + cls.__annotations__[prop_name] = prop( + **kwargs, + **extra_kwargs, + ) + + def __init_subclass__(cls, **kwargs: typ.Any): + log.debug('Initializing Socket: %s', cls.socket_type) + super().__init_subclass__(**kwargs) + # cls._assert_attrs_valid() + + # Socket Properties + ## Identifiers + cls.bl_idname: str = str(cls.socket_type.value) + cls.set_prop('instance_id', bpy.props.StringProperty, no_update=True) + + ## Special States + cls.set_prop('locked', bpy.props.BoolProperty, no_update=True, default=False) # Setup Style cls.socket_color = ct.SOCKET_COLORS[cls.socket_type] diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/__init__.py index 1c39ca6..8e3dea7 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/__init__.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/__init__.py @@ -1,11 +1,12 @@ from . import any as any_socket from . import bool as bool_socket -from . import file_path, string +from . import expr, file_path, string AnySocketDef = any_socket.AnySocketDef BoolSocketDef = bool_socket.BoolSocketDef -FilePathSocketDef = file_path.FilePathSocketDef StringSocketDef = string.StringSocketDef +FilePathSocketDef = file_path.FilePathSocketDef +ExprSocketDef = expr.ExprSocketDef BL_REGISTER = [ @@ -13,4 +14,5 @@ BL_REGISTER = [ *bool_socket.BL_REGISTER, *string.BL_REGISTER, *file_path.BL_REGISTER, + *expr.BL_REGISTER, ] diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/expr.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/expr.py new file mode 100644 index 0000000..49dc74c --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/expr.py @@ -0,0 +1,80 @@ +import bpy +import sympy as sp + +from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils.pydantic_sympy import SympyExpr + +from ... import bl_cache +from ... import contracts as ct +from .. import base + + +class ExprBLSocket(base.MaxwellSimSocket): + socket_type = ct.SocketType.Expr + bl_label = 'Expr' + + #################### + # - Properties + #################### + raw_value: bpy.props.StringProperty( + name='Expr', + description='Represents a symbolic expression', + default='', + update=(lambda self, context: self.sync_prop('raw_value', context)), + ) + + symbols: list[sp.Symbol] = bl_cache.BLField([]) + ## TODO: Way of assigning assumptions to symbols. + ## TODO: Dynamic add/remove of symbols + + #################### + # - Socket UI + #################### + def draw_value(self, col: bpy.types.UILayout) -> None: + col.prop(self, 'raw_value', text='') + + #################### + # - Computation of Default Value + #################### + @property + def value(self) -> sp.Expr: + return sp.sympify( + self.raw_value, + strict=False, + convert_xor=True, + ).subs(spux.ALL_UNIT_SYMBOLS) + + @value.setter + def value(self, value: str) -> None: + self.raw_value = str(value) + + @property + def lazy_value(self) -> sp.Expr: + return ct.LazyDataValue.from_function( + sp.lambdify(self.symbols, self.value, 'jax'), + free_args=(tuple(str(sym) for sym in self.symbols), frozenset()), + supports_jax=True, + ) + + +#################### +# - Socket Configuration +#################### +class ExprSocketDef(base.SocketDef): + socket_type: ct.SocketType = ct.SocketType.Expr + + _x = sp.Symbol('x', real=True) + symbols: list[SympyExpr] = [_x] + default_expr: SympyExpr = _x + + def init(self, bl_socket: ExprBLSocket) -> None: + bl_socket.value = self.default_expr + bl_socket.symbols = self.symbols + + +#################### +# - Blender Registration +#################### +BL_REGISTER = [ + ExprBLSocket, +] diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/string.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/string.py index 1c1b469..97963f9 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/string.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/string.py @@ -57,3 +57,6 @@ class StringSocketDef(base.SocketDef): BL_REGISTER = [ StringBLSocket, ] + + + diff --git a/src/blender_maxwell/utils/jarray.py b/src/blender_maxwell/utils/jarray.py new file mode 100644 index 0000000..9ff2ccf --- /dev/null +++ b/src/blender_maxwell/utils/jarray.py @@ -0,0 +1,63 @@ +import dataclasses +import typing as typ +from types import MappingProxyType + +import jax +import jax.numpy as jnp +import pandas as pd + +# import jaxtyping as jtyp +import sympy.physics.units as spu +import xarray + +from . import extra_sympy_units as spux +from . import logger + +log = logger.get(__name__) + +DimName: typ.TypeAlias = str +Number: typ.TypeAlias = int | float | complex +NumberRange: typ.TypeAlias = jax.Array + + +@dataclasses.dataclass(kw_only=True) +class JArray: + """Very simple wrapper for JAX arrays, which includes information about the dimension names and bounds.""" + + array: jax.Array + dims: dict[DimName, NumberRange] + dim_units: dict[DimName, spu.Quantity] + + #################### + # - Constructor + #################### + @classmethod + def from_xarray( + cls, + xarr: xarray.DataArray, + dim_units: dict[DimName, spu.Quantity] = MappingProxyType({}), + sort_axis: int = -1, + ) -> typ.Self: + return cls( + array=jnp.sort(jnp.array(xarr.data), axis=sort_axis), + dims={ + dim_name: jnp.array(xarr.get_index(dim_name).values) + for dim_name in xarr.dims + }, + dim_units={dim_name: dim_units.get(dim_name) for dim_name in xarr.dims}, + ) + + def idx(self, dim_name: DimName, dim_value: Number) -> int: + found_idx = jnp.searchsorted(self.dims[dim_name], dim_value) + if found_idx == 0: + return found_idx + if found_idx == len(self.dims[dim_name]): + return found_idx - 1 + + left = self.dims[dim_name][found_idx - 1] + right = self.dims[dim_name][found_idx - 1] + return found_idx - 1 if (dim_value - left) <= (right - dim_value) else found_idx + + @property + def dtype(self) -> jnp.dtype: + return self.array.dtype