feat: Math nodes (non-working)

main
Sofus Albert Høgsbro Rose 2024-04-17 16:03:15 +02:00
parent 568fc449e8
commit dfeb65feec
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
21 changed files with 974 additions and 226 deletions

View File

@ -14,6 +14,10 @@ dependencies = [
"networkx==3.2.*", "networkx==3.2.*",
"rich==12.5.*", "rich==12.5.*",
"rtree==1.2.*", "rtree==1.2.*",
"jax[cpu]==0.4.26",
"msgspec[toml]==0.18.6",
"numba==0.59.1",
"jaxtyping==0.2.28",
# Pin Blender 4.1.0-Compatible Versions # Pin Blender 4.1.0-Compatible Versions
## The dependency resolver will report if anything is wonky. ## The dependency resolver will report if anything is wonky.
"urllib3==1.26.8", "urllib3==1.26.8",
@ -22,8 +26,6 @@ dependencies = [
"idna==3.3", "idna==3.3",
"charset-normalizer==2.0.10", "charset-normalizer==2.0.10",
"certifi==2021.10.8", "certifi==2021.10.8",
"jax[cpu]>=0.4.26",
"msgspec[toml]>=0.18.6",
] ]
readme = "README.md" readme = "README.md"
requires-python = "~= 3.11" requires-python = "~= 3.11"

View File

@ -49,11 +49,14 @@ importlib-metadata==6.11.0
jax==0.4.26 jax==0.4.26
jaxlib==0.4.26 jaxlib==0.4.26
# via jax # via jax
jaxtyping==0.2.28
jmespath==1.0.1 jmespath==1.0.1
# via boto3 # via boto3
# via botocore # via botocore
kiwisolver==1.4.5 kiwisolver==1.4.5
# via matplotlib # via matplotlib
llvmlite==0.42.0
# via numba
locket==1.0.0 locket==1.0.0
# via partd # via partd
matplotlib==3.8.3 matplotlib==3.8.3
@ -65,13 +68,16 @@ mpmath==1.3.0
# via sympy # via sympy
msgspec==0.18.6 msgspec==0.18.6
networkx==3.2 networkx==3.2
numba==0.59.1
numpy==1.24.3 numpy==1.24.3
# via contourpy # via contourpy
# via h5py # via h5py
# via jax # via jax
# via jaxlib # via jaxlib
# via jaxtyping
# via matplotlib # via matplotlib
# via ml-dtypes # via ml-dtypes
# via numba
# via opt-einsum # via opt-einsum
# via scipy # via scipy
# via shapely # via shapely
@ -142,6 +148,8 @@ toolz==0.12.1
# via dask # via dask
# via partd # via partd
trimesh==4.2.0 trimesh==4.2.0
typeguard==2.13.3
# via jaxtyping
types-pyyaml==6.0.12.20240311 types-pyyaml==6.0.12.20240311
# via responses # via responses
typing-extensions==4.10.0 typing-extensions==4.10.0

View File

@ -48,11 +48,14 @@ importlib-metadata==6.11.0
jax==0.4.26 jax==0.4.26
jaxlib==0.4.26 jaxlib==0.4.26
# via jax # via jax
jaxtyping==0.2.28
jmespath==1.0.1 jmespath==1.0.1
# via boto3 # via boto3
# via botocore # via botocore
kiwisolver==1.4.5 kiwisolver==1.4.5
# via matplotlib # via matplotlib
llvmlite==0.42.0
# via numba
locket==1.0.0 locket==1.0.0
# via partd # via partd
matplotlib==3.8.3 matplotlib==3.8.3
@ -64,13 +67,16 @@ mpmath==1.3.0
# via sympy # via sympy
msgspec==0.18.6 msgspec==0.18.6
networkx==3.2 networkx==3.2
numba==0.59.1
numpy==1.24.3 numpy==1.24.3
# via contourpy # via contourpy
# via h5py # via h5py
# via jax # via jax
# via jaxlib # via jaxlib
# via jaxtyping
# via matplotlib # via matplotlib
# via ml-dtypes # via ml-dtypes
# via numba
# via opt-einsum # via opt-einsum
# via scipy # via scipy
# via shapely # via shapely
@ -140,6 +146,8 @@ toolz==0.12.1
# via dask # via dask
# via partd # via partd
trimesh==4.2.0 trimesh==4.2.0
typeguard==2.13.3
# via jaxtyping
types-pyyaml==6.0.12.20240311 types-pyyaml==6.0.12.20240311
# via responses # via responses
typing-extensions==4.10.0 typing-extensions==4.10.0

View File

@ -4,8 +4,9 @@ import functools
import typing as typ import typing as typ
from types import MappingProxyType from types import MappingProxyType
# import colour ## TODO import jax
import numpy as np import jax.numpy as jnp
import numba
import sympy as sp import sympy as sp
import sympy.physics.units as spu import sympy.physics.units as spu
import typing_extensions as typx import typing_extensions as typx
@ -15,66 +16,46 @@ from ....utils import sci_constants as constants
from .socket_types import SocketType from .socket_types import SocketType
class DataFlowKind(enum.StrEnum): class FlowKind(enum.StrEnum):
"""Defines a shape/kind of data that may flow through a node tree. """Defines a kind of data that can flow between nodes.
Since a node socket may define one of each, we can support several related kinds of data flow through the same node-graph infrastructure. Each node link can be thought to contain **multiple pipelines for data to flow along**.
Each pipeline is cached incrementally, and independently, of the others.
Thus, the same socket can easily support several kinds of related data flow at the same time.
Attributes: Attributes:
Value: A value without any unknown symbols. Capabilities: Describes a socket's linkeability with other sockets.
- Basic types aka. float, int, list, string, etc. . Links between sockets with incompatible capabilities will be rejected.
- Exotic (immutable-ish) types aka. numpy array, KDTree, etc. . This doesn't need to be defined normally, as there is a default.
- A usable constructed object, ex. a `tidy3d.Box`. However, in some cases, defining it manually to control linkeability more granularly may be desirable.
- Expressions (`sp.Expr`) that don't have unknown variables. Value: A generic object, which is "directly usable".
- Lazy sequences aka. generators, with all data bound. This should be chosen when a more specific flow kind doesn't apply.
SpectralValue: A value defined along a spectral range. Array: An object with dimensions, and possibly a unit.
- {`np.array` Whenever a `Value` is defined, a single-element `list` will also be generated by default as `Array`
However, for any other array-like variants (or sockets that only represent array-like objects), `Array` should be defined manually.
LazyValue: An object which, when given new data, can make many values. LazyValueFunc: A composable function.
- An `sp.Expr`, which might need `simplify`ing, `jax` JIT'ing, unit cancellations, variable substitutions, etc. before use. Can be used to represent computations for which all data is not yet known, or for which just-in-time compilation can drastically increase performance.
- Lazy objects, for which all parameters aren't yet known. LazyArrayRange: An object that generates an `Array` from range information (start/stop/step/spacing).
- A computational graph aka. `aesara`, which may even need to be handled before This should be used instead of `Array` whenever possible.
Param: An object providing data to complete `Lazy` data.
Capabilities: A `ValueCapability` object providing compatibility. For example,
Info: An object providing context about other flows.
# Value Data Flow For example,
Simply passing values is the simplest and easiest use case.
This doesn't mean it's "dumb" - ex. a `sp.Expr` might, before use, have `simplify`, rewriting, unit cancellation, etc. run.
All of this is okay, as long as there is no *introduction of new data* ex. variable substitutions.
# Lazy Value Data Flow
By passing (essentially) functions, one supports:
- **Lightness**: While lazy values can be made expensive to construct, they will generally not be nearly as heavy to handle when trying to work with ex. operations on voxel arrays.
- **Performance**: Parameterizing ex. `sp.Expr` with variables allows one to build very optimized functions, which can make ex. node graph updates very fast if the only operation run is the `jax` JIT'ed function (aka. GPU accelerated) generated from the final full expression.
- **Numerical Stability**: Libraries like `aesara` build a computational graph, which can be automatically rewritten to avoid many obvious conditioning / cancellation errors.
- **Lazy Output**: The goal of a node-graph may not be the definition of a single value, but rather, a parameterized expression for generating *many values* with known properties. This is especially interesting for use cases where one wishes to build an optimization step using nodes.
# Capability Passing
By being able to pass "capabilities" next to other kinds of values, nodes can quickly determine whether a given link is valid without having to actually compute it.
# Lazy Parameter Value
When using parameterized LazyValues, one may wish to independently pass parameter values through the graph, so they can be inserted into the final (cached) high-performance expression without.
The advantage of using a different data flow would be changing this kind of value would ONLY invalidate lazy parameter value caches, which would allow an incredibly fast path of getting the value into the lazy expression for high-performance computation.
Implementation TBD - though, ostensibly, one would have a "parameter" node which both would only provide a LazyValue (aka. a symbolic variable), but would also be able to provide a LazyParamValue, which would be a particular value of some kind (probably via the `value` of some other node socket).
""" """
Capabilities = enum.auto() Capabilities = enum.auto()
# Values # Values
Value = enum.auto() Value = enum.auto()
ValueArray = enum.auto() Array = enum.auto()
ValueSpectrum = enum.auto()
# Lazy # Lazy
LazyValue = enum.auto() LazyValue = enum.auto()
LazyValueRange = enum.auto() LazyArrayRange = enum.auto()
LazyValueSpectrum = enum.auto()
# Auxiliary
Param = enum.auto()
Info = enum.auto()
@classmethod @classmethod
def scale_to_unit_system(cls, kind: typ.Self, value, socket_type, unit_system): def scale_to_unit_system(cls, kind: typ.Self, value, socket_type, unit_system):
@ -85,7 +66,7 @@ class DataFlowKind(enum.StrEnum):
unit_system[socket_type], unit_system[socket_type],
) )
) )
if kind == cls.LazyValueRange: if kind == cls.LazyArrayRange:
return value.rescale_to_unit(unit_system[socket_type]) return value.rescale_to_unit(unit_system[socket_type])
msg = 'Tried to scale unknown kind' msg = 'Tried to scale unknown kind'
@ -93,12 +74,12 @@ class DataFlowKind(enum.StrEnum):
#################### ####################
# - Data Structures: Capabilities # - Capabilities
#################### ####################
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class DataCapabilities: class CapabilitiesFlow:
socket_type: SocketType socket_type: SocketType
active_kind: DataFlowKind active_kind: FlowKind
is_universal: bool = False is_universal: bool = False
@ -110,13 +91,16 @@ class DataCapabilities:
#################### ####################
# - Data Structures: Non-Lazy # - Value
#################### ####################
DataValue: typ.TypeAlias = typ.Any ValueFlow: typ.TypeAlias = typ.Any
####################
# - Value Array
####################
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class DataValueArray: class ArrayFlow:
"""A simple, flat array of values with an optionally-attached unit. """A simple, flat array of values with an optionally-attached unit.
Attributes: Attributes:
@ -125,69 +109,105 @@ class DataValueArray:
None if unitless. None if unitless.
""" """
values: typ.Sequence[DataValue] values: jax.Array
unit: spu.Quantity | None unit: spu.Quantity | None
####################
# - Lazy Value Func
####################
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], ValueFlow]
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class DataValueSpectrum: class LazyValueFuncFlow:
"""A numerical representation of a spectral distribution. r"""Encapsulates a lazily evaluated data value as a composable function with bound and free arguments.
- **Bound Args**: Arguments that are realized when **defining** the lazy value.
Both positional values and keyword values are supported.
- **Free Args**: Arguments that are specified when evaluating the lazy value.
Both positional values and keyword values are supported.
The **root function** is encapsulated using `from_function`, and must accept arguments in the following order:
$$
f_0:\ \ \ \ (\underbrace{b_1, b_2, ...}_{\text{Bound}}\ ,\ \underbrace{r_1, r_2, ...}_{\text{Free}}) \to \text{output}_0
$$
Subsequent **composed functions** are encapsulated from the _root function_, and are created with `root_function.compose`.
They must accept arguments in the following order:
$$
f_k:\ \ \ \ (\underbrace{b_1, b_2, ...}_{\text{Bound}}\ ,\ \text{output}_{k-1} ,\ \underbrace{r_p, r_{p+1}, ...}_{\text{Free}}) \to \text{output}_k
$$
Attributes: Attributes:
wls: A 1D `numpy` float array of wavelength values. function: The function to be lazily evaluated.
wls_unit: The unit of wavelengths, as length dimension. bound_args: Arguments that will be packaged into function, which can't be later modifier.
values: A 1D `numpy` float array of values corresponding to wavelength values. func_kwargs: Arguments to be specified by the user at the time of use.
values_unit: The unit of the value, as arbitrary dimension. supports_jax: Whether the contained `self.function` can be compiled with JAX's JIT compiler.
freqs_unit: The unit of the value, as arbitrary dimension. supports_numba: Whether the contained `self.function` can be compiled with Numba's JIT compiler.
""" """
# Wavelength func: LazyFunction
wls: np.array func_kwargs: dict[str, type]
wls_unit: spu.Quantity supports_jax: bool = False
supports_numba: bool = False
# Value @staticmethod
values: np.array def from_func(
values_unit: spu.Quantity func: LazyFunction,
supports_jax: bool = False,
supports_numba: bool = False,
**func_kwargs: dict[str, type],
) -> typ.Self:
return LazyValueFuncFlow(
func=func,
func_kwargs=func_kwargs,
supports_jax=supports_jax,
supports_numba=supports_numba,
)
# Frequency # Composition
freqs_unit: spu.Quantity = spu.hertz def compose_within(
self,
enclosing_func: LazyFunction,
supports_jax: bool = False,
supports_numba: bool = False,
**enclosing_func_kwargs: dict[str, type],
) -> typ.Self:
return LazyValueFuncFlow(
function=lambda **kwargs: enclosing_func(
self.func(**{k: v for k, v in kwargs if k in self.func_kwargs}),
**kwargs,
),
func_kwargs=self.func_kwargs | enclosing_func_kwargs,
supports_jax=self.supports_jax and supports_jax,
supports_numba=self.supports_numba and supports_numba,
)
@functools.cached_property @functools.cached_property
def freqs(self) -> np.array: def func_jax(self) -> LazyFunction:
"""The spectral frequencies, computed from the wavelengths. if self.supports_jax:
return jax.jit(self.func)
Frequencies are NOT reversed, so as to preserve the by-index mapping to `DataValueSpectrum.values`. msg = 'Can\'t express LazyValueFuncFlow as JAX function (using jax.jit), since "self.supports_jax" is False'
raise ValueError(msg)
Returns: @functools.cached_property
Frequencies, as a unitless `numpy` array. def func_numba(self) -> LazyFunction:
Use `DataValueSpectrum.wls_unit` to interpret this return value. if self.supports_numba:
""" return numba.jit(self.func)
unitless_speed_of_light = spux.sympy_to_python(
spux.scale_to_unit(
constants.vac_speed_of_light, (self.wl_unit / self.freq_unit)
)
)
return unitless_speed_of_light / self.wls
# TODO: Colour Library msg = 'Can\'t express LazyValueFuncFlow as Numba function (using numba.jit), since "self.supports_numba" is False'
# def as_colour_sd(self) -> colour.SpectralDistribution: raise ValueError(msg)
# """Returns the `colour` representation of this spectral distribution, ideal for plotting and colorimetric analysis."""
# return colour.SpectralDistribution(data=self.values, domain=self.wls)
#################### ####################
# - Data Structures: Lazy # - Lazy Array Range
#################### ####################
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class LazyDataValue: class LazyArrayRangeFlow:
callback: typ.Callable[[...], [DataValue]]
def realize(self, *args: list[DataValue]) -> DataValue:
return self.callback(*args)
@dataclasses.dataclass(frozen=True, kw_only=True)
class LazyDataValueRange:
symbols: set[sp.Symbol] symbols: set[sp.Symbol]
start: sp.Basic start: sp.Basic
@ -200,7 +220,7 @@ class LazyDataValueRange:
def rescale_to_unit(self, unit: spu.Quantity) -> typ.Self: def rescale_to_unit(self, unit: spu.Quantity) -> typ.Self:
if self.has_unit: if self.has_unit:
return LazyDataValueRange( return LazyArrayRangeFlow(
symbols=self.symbols, symbols=self.symbols,
has_unit=self.has_unit, has_unit=self.has_unit,
unit=unit, unit=unit,
@ -219,7 +239,7 @@ class LazyDataValueRange:
reverse: bool = False, reverse: bool = False,
) -> typ.Self: ) -> typ.Self:
"""Call a function on both bounds (start and stop), creating a new `LazyDataValueRange`.""" """Call a function on both bounds (start and stop), creating a new `LazyDataValueRange`."""
return LazyDataValueRange( return LazyArrayRangeFlow(
symbols=self.symbols, symbols=self.symbols,
has_unit=self.has_unit, has_unit=self.has_unit,
unit=self.unit, unit=self.unit,
@ -234,8 +254,8 @@ class LazyDataValueRange:
) )
def realize( def realize(
self, symbol_values: dict[sp.Symbol, DataValue] = MappingProxyType({}) self, symbol_values: dict[sp.Symbol, ValueFlow] = MappingProxyType({})
) -> DataValueArray: ) -> ArrayFlow:
# Realize Symbols # Realize Symbols
if not self.has_unit: if not self.has_unit:
start = spux.sympy_to_python(self.start.subs(symbol_values)) start = spux.sympy_to_python(self.start.subs(symbol_values))
@ -250,85 +270,25 @@ class LazyDataValueRange:
# Return Linspace / Logspace # Return Linspace / Logspace
if self.scaling == 'lin': if self.scaling == 'lin':
return DataValueArray( return ArrayFlow(
values=np.linspace(start, stop, self.steps), unit=self.unit values=jnp.linspace(start, stop, self.steps), unit=self.unit
) )
if self.scaling == 'geom': if self.scaling == 'geom':
return DataValueArray(np.geomspace(start, stop, self.steps), self.unit) return ArrayFlow(jnp.geomspace(start, stop, self.steps), self.unit)
if self.scaling == 'log': if self.scaling == 'log':
return DataValueArray(np.logspace(start, stop, self.steps), self.unit) return ArrayFlow(jnp.logspace(start, stop, self.steps), self.unit)
raise NotImplementedError msg = f'ArrayFlow scaling method {self.scaling} is unsupported'
raise RuntimeError(msg)
@dataclasses.dataclass(frozen=True, kw_only=True) ####################
class LazyDataValueSpectrum: # - Param
wl_unit: spu.Quantity ####################
value_unit: spu.Quantity ParamFlow: typ.TypeAlias = dict[str, typ.Any]
value_expr: sp.Expr
symbols: tuple[sp.Symbol, ...] = ()
freq_symbol: sp.Symbol = sp.Symbol('lamda') # noqa: RUF009
def rescale_to_unit(self, unit: spu.Quantity) -> typ.Self:
raise NotImplementedError
@functools.cached_property
def as_func(self) -> typ.Callable[[DataValue, ...], DataValue]:
"""Generates an optimized function for numerical evaluation of the spectral expression."""
return sp.lambdify([self.freq_symbol, *self.symbols], self.value_expr)
def realize(
self, wl_range: DataValueArray, symbol_values: tuple[DataValue, ...]
) -> DataValueSpectrum:
r"""Realizes the parameterized spectral function as a numerical spectral distribution.
Parameters:
wl_range: The lazy wavelength range to build the concrete spectral distribution with.
symbol_values: Numerical values for each symbol, in the same order as defined in `LazyDataValueSpectrum.symbols`.
The wavelength symbol ($\lambda$ by default) always goes first.
_This is used to call the spectral function using the output of `.as_func()`._
Returns:
The concrete, numerical spectral distribution.
"""
return DataValueSpectrum(
wls=wl_range.values,
wls_unit=self.wl_unit,
values=self.as_func(*list(symbol_values.values())),
values_unit=self.value_unit,
)
# ####################
# # - Lazy Value Func
##################### ####################
## - Data Pipeline InfoFlow: typ.TypeAlias = dict[str, typ.Any]
#####################
# @dataclasses.dataclass(frozen=True, kw_only=True)
# class DataPipelineDim:
# unit: spu.Quantity | None
#
# class DataPipelineDimType(enum.StrEnum):
# # Map Inputs
# Time = enum.auto()
# Freq = enum.auto()
# Space3D = enum.auto()
# DiffOrder = enum.auto()
#
# # Map Inputs
# Power = enum.auto()
# EVec = enum.auto()
# HVec = enum.auto()
# RelPerm = enum.auto()
#
#
# @dataclasses.dataclass(frozen=True, kw_only=True)
# class LazyDataPipeline:
# dims: list[DataPipelineDim]
#
# def _callable(self):
# """JITs the current pipeline of functions with `jax`."""
#
# def __call__(self):
# pass

View File

@ -7,6 +7,7 @@ SOCKET_COLORS = {
ST.Bool: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey ST.Bool: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey
ST.String: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey ST.String: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey
ST.FilePath: (0.6, 0.6, 0.6, 1.0), # Medium Grey ST.FilePath: (0.6, 0.6, 0.6, 1.0), # Medium Grey
ST.Expr: (0.5, 0.5, 0.5, 1.0), # Medium Grey
# Number # Number
ST.IntegerNumber: (0.5, 0.5, 1.0, 1.0), # Light Blue ST.IntegerNumber: (0.5, 0.5, 1.0, 1.0), # Light Blue
ST.RationalNumber: (0.4, 0.4, 0.9, 1.0), # Medium Light Blue ST.RationalNumber: (0.4, 0.4, 0.9, 1.0), # Medium Light Blue

View File

@ -6,6 +6,7 @@ SOCKET_SHAPES = {
ST.Bool: 'CIRCLE', ST.Bool: 'CIRCLE',
ST.String: 'CIRCLE', ST.String: 'CIRCLE',
ST.FilePath: 'CIRCLE', ST.FilePath: 'CIRCLE',
ST.Expr: 'CIRCLE',
# Number # Number
ST.IntegerNumber: 'CIRCLE', ST.IntegerNumber: 'CIRCLE',
ST.RationalNumber: 'CIRCLE', ST.RationalNumber: 'CIRCLE',

View File

@ -14,6 +14,7 @@ class SocketType(BlenderTypeEnum):
String = enum.auto() String = enum.auto()
FilePath = enum.auto() FilePath = enum.auto()
Color = enum.auto() Color = enum.auto()
Expr = enum.auto()
# Number # Number
IntegerNumber = enum.auto() IntegerNumber = enum.auto()

View File

@ -38,8 +38,8 @@ def apply_colormap(normalized_data, colormap):
@jax.jit @jax.jit
def rgba_image_from_xyzf__viridis(xyz_freq): def rgba_image_from_2d_map__viridis(map_2d):
amplitude = jnp.abs(jnp.squeeze(xyz_freq)) amplitude = jnp.abs(map_2d)
amplitude_normalized = (amplitude - amplitude.min()) / ( amplitude_normalized = (amplitude - amplitude.min()) / (
amplitude.max() - amplitude.min() amplitude.max() - amplitude.min()
) )
@ -49,8 +49,8 @@ def rgba_image_from_xyzf__viridis(xyz_freq):
@jax.jit @jax.jit
def rgba_image_from_xyzf__grayscale(xyz_freq): def rgba_image_from_2d_map__grayscale(map_2d):
amplitude = jnp.abs(jnp.squeeze(xyz_freq)) amplitude = jnp.abs(map_2d)
amplitude_normalized = (amplitude - amplitude.min()) / ( amplitude_normalized = (amplitude - amplitude.min()) / (
amplitude.max() - amplitude.min() amplitude.max() - amplitude.min()
) )
@ -59,21 +59,19 @@ def rgba_image_from_xyzf__grayscale(xyz_freq):
return jnp.dstack((rgb_array, alpha_channel)) return jnp.dstack((rgb_array, alpha_channel))
def rgba_image_from_xyzf(xyz_freq, colormap: str | None = None): def rgba_image_from_2d_map(map_2d, colormap: str | None = None):
"""RGBA Image from Squeezable XYZ-Freq w/fixed freq. """RGBA Image from a map of 2D coordinates to values.
Parameters: Parameters:
xyz_freq: Shape (xlen, ylen, zlen), one dimension has length 1. map_2d: Shape (width, height, value).
width_px: Pixel width to resize the image to.
height: Pixel height to resize the image to.
Returns: Returns:
Image as a JAX array of shape (height, width, 3) Image as a JAX array of shape (height, width, 4)
""" """
if colormap == 'VIRIDIS': if colormap == 'VIRIDIS':
return rgba_image_from_xyzf__viridis(xyz_freq) return rgba_image_from_2d_map__viridis(map_2d)
if colormap == 'GRAYSCALE': if colormap == 'GRAYSCALE':
return rgba_image_from_xyzf__grayscale(xyz_freq) return rgba_image_from_2d_map__grayscale(map_2d)
class ManagedBLImage(base.ManagedObj): class ManagedBLImage(base.ManagedObj):
@ -227,11 +225,11 @@ class ManagedBLImage(base.ManagedObj):
#################### ####################
# - Special Methods # - Special Methods
#################### ####################
def xyzf_to_image( def map_2d_to_image(
self, xyz_freq, colormap: str | None = 'VIRIDIS', bl_select: bool = False self, map_2d, colormap: str | None = 'VIRIDIS', bl_select: bool = False
): ):
self.data_to_image( self.data_to_image(
lambda _: rgba_image_from_xyzf(xyz_freq, colormap=colormap), lambda _: rgba_image_from_2d_map(map_2d, colormap=colormap),
bl_select=bl_select, bl_select=bl_select,
) )

View File

@ -1,10 +1,12 @@
from . import extract_data, viz from . import extract_data, viz, math
BL_REGISTER = [ BL_REGISTER = [
*extract_data.BL_REGISTER, *extract_data.BL_REGISTER,
*viz.BL_REGISTER, *viz.BL_REGISTER,
*math.BL_REGISTER,
] ]
BL_NODES = { BL_NODES = {
**extract_data.BL_NODES, **extract_data.BL_NODES,
**viz.BL_NODES, **viz.BL_NODES,
**math.BL_NODES,
} }

View File

@ -1,8 +1,11 @@
import typing as typ import typing as typ
import bpy import bpy
import jax.numpy as jnp
import sympy.physics.units as spu
from blender_maxwell.utils import jarray, logger
from .....utils import logger
from ... import contracts as ct from ... import contracts as ct
from ... import sockets from ... import sockets
from .. import base, events from .. import base, events
@ -229,8 +232,10 @@ class ExtractDataNode(base.MaxwellSimNode):
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Data',
props={'sim_data__monitor_name', 'field_data__component'}, props={'sim_data__monitor_name', 'field_data__component'},
input_sockets={'Field Data'},
input_sockets_optional={'Field Data': True},
) )
def compute_extracted_data(self, props: dict): def compute_extracted_data(self, props: dict, input_sockets: dict):
if self.active_socket_set == 'Sim Data': if self.active_socket_set == 'Sim Data':
if ( if (
CACHE_SIM_DATA.get(self.instance_id) is None CACHE_SIM_DATA.get(self.instance_id) is None
@ -242,12 +247,21 @@ class ExtractDataNode(base.MaxwellSimNode):
return sim_data.monitor_data[props['sim_data__monitor_name']] return sim_data.monitor_data[props['sim_data__monitor_name']]
elif self.active_socket_set == 'Field Data': # noqa: RET505 elif self.active_socket_set == 'Field Data': # noqa: RET505
field_data = self._compute_input('Field Data') xarr = getattr(input_sockets['Field Data'], props['field_data__component'])
return getattr(field_data, props['field_data__component'])
return jarray.JArray.from_xarray(
xarr,
dim_units={
'x': spu.um,
'y': spu.um,
'z': spu.um,
'f': spu.hertz,
},
)
elif self.active_socket_set == 'Flux Data': elif self.active_socket_set == 'Flux Data':
flux_data = self._compute_input('Flux Data') flux_data = self._compute_input('Flux Data')
return flux_data.flux return jnp.array(flux_data.flux)
msg = f'Tried to get data from unknown output socket in "{self.bl_label}"' msg = f'Tried to get data from unknown output socket in "{self.bl_label}"'
raise RuntimeError(msg) raise RuntimeError(msg)

View File

@ -0,0 +1,14 @@
from . import map_math, filter_math, reduce_math, operate_math
BL_REGISTER = [
*map_math.BL_REGISTER,
*filter_math.BL_REGISTER,
*reduce_math.BL_REGISTER,
*operate_math.BL_REGISTER,
]
BL_NODES = {
**map_math.BL_NODES,
**filter_math.BL_NODES,
**reduce_math.BL_NODES,
**operate_math.BL_NODES,
}

View File

@ -0,0 +1,121 @@
import functools
import typing as typ
import bpy
import jax
import jax.numpy as jnp
from blender_maxwell.utils import logger
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
# @functools.partial(jax.jit, static_argnames=('fixed_axis', 'fixed_axis_value'))
# jax.jit
def fix_axis(data, fixed_axis: int, fixed_axis_value: float):
log.critical(data.shape)
# Select Values of Fixed Axis
fixed_axis_values = data[
tuple(slice(None) if i == fixed_axis else 0 for i in range(data.ndim))
]
log.critical(fixed_axis_values)
# Compute Nearest Index on Fixed Axis
idx_of_nearest = jnp.argmin(jnp.abs(fixed_axis_values - fixed_axis_value))
log.critical(idx_of_nearest)
# Select Values along Fixed Axis Value
return jnp.take(data, idx_of_nearest, axis=fixed_axis)
class FilterMathNode(base.MaxwellSimNode):
node_type = ct.NodeType.FilterMath
bl_label = 'Filter Math'
input_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
}
input_socket_sets: typ.ClassVar = {
'By Axis Value': {
'Axis': sockets.IntegerNumberSocketDef(),
'Value': sockets.RealNumberSocketDef(),
},
'By Axis': {
'Axis': sockets.IntegerNumberSocketDef(),
},
## TODO: bool arrays for comparison/switching/sparse 0-setting/etc. .
}
output_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
}
####################
# - Properties
####################
operation: bpy.props.EnumProperty(
name='Op',
description='Operation to reduce the input axis with',
items=lambda self, _: self.search_operations(),
update=lambda self, context: self.sync_prop('operation', context),
)
def search_operations(self) -> list[tuple[str, str, str]]:
items = []
if self.active_socket_set == 'By Axis Value':
items += [
('FIX', 'Fix Coordinate', '(*, N, *) -> (*, *)'),
]
if self.active_socket_set == 'By Axis':
items += [
('SQUEEZE', 'Squeeze', '(*, 1, *) -> (*, *)'),
]
else:
items += [('NONE', 'None', 'No operations...')]
return items
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
if self.active_socket_set != 'Axis Expr':
layout.prop(self, 'operation')
####################
# - Compute
####################
@events.computes_output_socket(
'Data',
props={'operation', 'active_socket_set'},
input_sockets={'Data', 'Axis', 'Value'},
input_sockets_optional={'Axis': True, 'Value': True},
)
def compute_data(self, props: dict, input_sockets: dict):
if not hasattr(input_sockets['Data'], 'shape'):
msg = 'Input socket "Data" must be an N-D Array (with a "shape" attribute)'
raise ValueError(msg)
# By Axis Value
if props['active_socket_set'] == 'By Axis Value':
if props['operation'] == 'FIX':
return fix_axis(
input_sockets['Data'], input_sockets['Axis'], input_sockets['Value']
)
# By Axis
if props['active_socket_set'] == 'By Axis':
if props['operation'] == 'SQUEEZE':
return jnp.squeeze(input_sockets['Data'], axis=input_sockets['Axis'])
msg = 'Operation invalid'
raise ValueError(msg)
####################
# - Blender Registration
####################
BL_REGISTER = [
FilterMathNode,
]
BL_NODES = {ct.NodeType.FilterMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}

View File

@ -0,0 +1,164 @@
import typing as typ
import bpy
import jax
import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import logger
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
class MapMathNode(base.MaxwellSimNode):
node_type = ct.NodeType.MapMath
bl_label = 'Map Math'
input_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
}
input_socket_sets: typ.ClassVar = {
'By Element': {},
'By Vector': {},
'By Matrix': {},
'Expr': {
'Mapper': sockets.ExprSocketDef(
symbols=[sp.Symbol('x')],
default_expr=sp.Symbol('x'),
),
},
}
output_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
}
####################
# - Properties
####################
operation: bpy.props.EnumProperty(
name='Op',
description='Operation to apply to the input',
items=lambda self, _: self.search_operations(),
update=lambda self, context: self.sync_prop('operation', context),
)
def search_operations(self) -> list[tuple[str, str, str]]:
items = []
if self.active_socket_set == 'By Element':
items += [
# General
('REAL', 'real', '(L) (by el)'),
('IMAG', 'imag', 'Im(L) (by el)'),
('ABS', 'abs', '|L| (by el)'),
('SQ', 'square', 'L^2 (by el)'),
('SQRT', 'sqrt', 'sqrt(L) (by el)'),
('INV_SQRT', '1/sqrt', '1/sqrt(L) (by el)'),
# Trigonometry
('COS', 'cos', 'cos(L) (by el)'),
('SIN', 'sin', 'sin(L) (by el)'),
('TAN', 'tan', 'tan(L) (by el)'),
('ACOS', 'acos', 'acos(L) (by el)'),
('ASIN', 'asin', 'asin(L) (by el)'),
('ATAN', 'atan', 'atan(L) (by el)'),
]
elif self.active_socket_set in 'By Vector':
items += [
('NORM_2', '2-Norm', '||L||_2 (by Vec)'),
]
elif self.active_socket_set == 'By Matrix':
items += [
# Matrix -> Number
('DET', 'Determinant', 'det(L) (by Mat)'),
('COND', 'Condition', 'κ(L) (by Mat)'),
('NORM_FRO', 'Frobenius Norm', '||L||_F (by Mat)'),
('RANK', 'Rank', 'rank(L) (by Mat)'),
# Matrix -> Array
('DIAG', 'Diagonal', 'diag(L) (by Mat)'),
('EIG_VALS', 'Eigenvalues', 'eigvals(L) (by Mat)'),
('SVD_VALS', 'SVD', 'svd(L) -> diag(Σ) (by Mat)'),
# Matrix -> Matrix
('INV', 'Invert', 'L^(-1) (by Mat)'),
('TRA', 'Transpose', 'L^T (by Mat)'),
# Matrix -> Matrices
('QR', 'QR', 'L -> Q·R (by Mat)'),
('CHOL', 'Cholesky', 'L -> L·Lh (by Mat)'),
('SVD', 'SVD', 'L -> U·Σ·Vh (by Mat)'),
]
else:
items += ['EXPR_EL', 'Expr (by el)', 'Expression-defined (by el)']
return items
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
if self.active_socket_set not in {'Expr (Element)'}:
layout.prop(self, 'operation')
####################
# - Compute
####################
@events.computes_output_socket(
'Data',
props={'active_socket_set', 'operation'},
input_sockets={'Data', 'Mapper'},
input_socket_kinds={'Mapper': ct.DataFlowKind.LazyValue},
input_sockets_optional={'Mapper': True},
)
def compute_data(self, props: dict, input_sockets: dict):
mapping_func: typ.Callable[[jax.Array], jax.Array] = {
'By Element': {
'REAL': lambda data: jnp.real(data),
'IMAG': lambda data: jnp.imag(data),
'ABS': lambda data: jnp.abs(data),
'SQ': lambda data: jnp.square(data),
'SQRT': lambda data: jnp.sqrt(data),
'INV_SQRT': lambda data: 1 / jnp.sqrt(data),
'COS': lambda data: jnp.cos(data),
'SIN': lambda data: jnp.sin(data),
'TAN': lambda data: jnp.tan(data),
'ACOS': lambda data: jnp.acos(data),
'ASIN': lambda data: jnp.asin(data),
'ATAN': lambda data: jnp.atan(data),
'SINC': lambda data: jnp.sinc(data),
},
'By Vector': {
'NORM_2': lambda data: jnp.norm(data, ord=2, axis=-1),
},
'By Matrix': {
# Matrix -> Number
'DET': lambda data: jnp.linalg.det(data),
'COND': lambda data: jnp.linalg.cond(data),
'NORM_FRO': lambda data: jnp.linalg.matrix_norm(data, ord='fro'),
'RANK': lambda data: jnp.linalg.matrix_rank(data),
# Matrix -> Vec
'DIAG': lambda data: jnp.diag(data),
'EIG_VALS': lambda data: jnp.eigvals(data),
'SVD_VALS': lambda data: jnp.svdvals(data),
# Matrix -> Matrix
'INV': lambda data: jnp.inv(data),
'TRA': lambda data: jnp.matrix_transpose(data),
# Matrix -> Matrices
'QR': lambda data: jnp.inv(data),
'CHOL': lambda data: jnp.linalg.cholesky(data),
'SVD': lambda data: jnp.linalg.svd(data),
},
'By El (Expr)': {
'EXPR_EL': lambda data: input_sockets['Mapper'](data),
},
}[props['active_socket_set']][props['operation']]
# Compose w/Lazy Root Function Data
return input_sockets['Data'].compose(
function=mapping_func,
)
####################
# - Blender Registration
####################
BL_REGISTER = [
MapMathNode,
]
BL_NODES = {ct.NodeType.MapMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}

View File

@ -0,0 +1,138 @@
import typing as typ
import bpy
import jax.numpy as jnp
from blender_maxwell.utils import logger
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
class OperateMathNode(base.MaxwellSimNode):
node_type = ct.NodeType.OperateMath
bl_label = 'Operate Math'
input_socket_sets: typ.ClassVar = {
'Elementwise': {
'Data L': sockets.AnySocketDef(),
'Data R': sockets.AnySocketDef(),
},
## TODO: Filter-array building operations
'Vec-Vec': {
'Data L': sockets.AnySocketDef(),
'Data R': sockets.AnySocketDef(),
},
'Mat-Vec': {
'Data L': sockets.AnySocketDef(),
'Data R': sockets.AnySocketDef(),
},
}
output_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
}
####################
# - Properties
####################
operation: bpy.props.EnumProperty(
name='Op',
description='Operation to apply to the two inputs',
items=lambda self, _: self.search_operations(),
update=lambda self, context: self.sync_prop('operation', context),
)
def search_operations(self) -> list[tuple[str, str, str]]:
items = []
if self.active_socket_set == 'Elementwise':
items = [
('ADD', 'Add', 'L + R (by el)'),
('SUB', 'Subtract', 'L - R (by el)'),
('MUL', 'Multiply', 'L · R (by el)'),
('DIV', 'Divide', 'L ÷ R (by el)'),
('POW', 'Power', 'L^R (by el)'),
('FMOD', 'Trunc Modulo', 'fmod(L,R) (by el)'),
('ATAN2', 'atan2', 'atan2(L,R) (by el)'),
('HEAVISIDE', 'Heaviside', '{0|L<0 1|L>0 R|L=0} (by el)'),
]
elif self.active_socket_set in 'Vec | Vec':
items = [
('DOT', 'Dot', 'L · R'),
('CROSS', 'Cross', 'L x R (by last-axis'),
]
elif self.active_socket_set == 'Mat | Vec':
items = [
('DOT', 'Dot', 'L · R'),
('LIN_SOLVE', 'Lin Solve', 'Lx = R -> x (by last-axis of R)'),
('LSQ_SOLVE', 'LSq Solve', 'Lx = R ~> x (by last-axis of R)'),
]
return items
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, 'operation')
####################
# - Properties
####################
@events.computes_output_socket(
'Data',
props={'operation'},
input_sockets={'Data L', 'Data R'},
)
def compute_data(self, props: dict, input_sockets: dict):
if self.active_socket_set == 'Elementwise':
# Element-Wise Arithmetic
if props['operation'] == 'ADD':
return input_sockets['Data L'] + input_sockets['Data R']
if props['operation'] == 'SUB':
return input_sockets['Data L'] - input_sockets['Data R']
if props['operation'] == 'MUL':
return input_sockets['Data L'] * input_sockets['Data R']
if props['operation'] == 'DIV':
return input_sockets['Data L'] / input_sockets['Data R']
# Element-Wise Arithmetic
if props['operation'] == 'POW':
return input_sockets['Data L'] ** input_sockets['Data R']
# Binary Trigonometry
if props['operation'] == 'ATAN2':
return jnp.atan2(input_sockets['Data L'], input_sockets['Data R'])
# Special Functions
if props['operation'] == 'HEAVISIDE':
return jnp.heaviside(input_sockets['Data L'], input_sockets['Data R'])
# Linear Algebra
if self.active_socket_set in {'Vec-Vec', 'Mat-Vec'}:
if props['operation'] == 'DOT':
return jnp.dot(input_sockets['Data L'], input_sockets['Data R'])
elif self.active_socket_set == 'Vec-Vec':
if props['operation'] == 'CROSS':
return jnp.cross(input_sockets['Data L'], input_sockets['Data R'])
elif self.active_socket_set == 'Mat-Vec':
if props['operation'] == 'LIN_SOLVE':
return jnp.linalg.lstsq(
input_sockets['Data L'], input_sockets['Data R']
)
if props['operation'] == 'LSQ_SOLVE':
return jnp.linalg.solve(
input_sockets['Data L'], input_sockets['Data R']
)
msg = 'Invalid operation'
raise ValueError(msg)
####################
# - Blender Registration
####################
BL_REGISTER = [
OperateMathNode,
]
BL_NODES = {ct.NodeType.OperateMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}

View File

@ -0,0 +1,135 @@
import typing as typ
import bpy
import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import logger
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
class ReduceMathNode(base.MaxwellSimNode):
node_type = ct.NodeType.ReduceMath
bl_label = 'Reduce Math'
input_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
'Axis': sockets.IntegerNumberSocketDef(),
}
input_socket_sets: typ.ClassVar = {
'By Axis': {
'Axis': sockets.IntegerNumberSocketDef(),
},
'Expr': {
'Reducer': sockets.ExprSocketDef(
symbols=[sp.Symbol('a'), sp.Symbol('b')],
default_expr=sp.Symbol('a') + sp.Symbol('b'),
),
'Axis': sockets.IntegerNumberSocketDef(),
},
}
output_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
}
####################
# - Properties
####################
operation: bpy.props.EnumProperty(
name='Op',
description='Operation to reduce the input axis with',
items=lambda self, _: self.search_operations(),
update=lambda self, context: self.sync_prop('operation', context),
)
def search_operations(self) -> list[tuple[str, str, str]]:
items = []
if self.active_socket_set == 'By Axis':
items += [
# Accumulation
('SUM', 'Sum', 'sum(*, N, *) -> (*, 1, *)'),
('PROD', 'Prod', 'prod(*, N, *) -> (*, 1, *)'),
('MIN', 'Axis-Min', '(*, N, *) -> (*, 1, *)'),
('MAX', 'Axis-Max', '(*, N, *) -> (*, 1, *)'),
('P2P', 'Peak-to-Peak', '(*, N, *) -> (*, 1 *)'),
# Stats
('MEAN', 'Mean', 'mean(*, N, *) -> (*, 1, *)'),
('MEDIAN', 'Median', 'median(*, N, *) -> (*, 1, *)'),
('STDDEV', 'Std Dev', 'stddev(*, N, *) -> (*, 1, *)'),
('VARIANCE', 'Variance', 'var(*, N, *) -> (*, 1, *)'),
# Dimension Reduction
('SQUEEZE', 'Squeeze', '(*, 1, *) -> (*, *)'),
]
else:
items += [('NONE', 'None', 'No operations...')]
return items
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
if self.active_socket_set != 'Axis Expr':
layout.prop(self, 'operation')
####################
# - Compute
####################
@events.computes_output_socket(
'Data',
props={'operation'},
input_sockets={'Data', 'Axis', 'Reducer'},
input_socket_kinds={'Reducer': ct.DataFlowKind.LazyValue},
input_sockets_optional={'Reducer': True},
)
def compute_data(self, props: dict, input_sockets: dict):
if not hasattr(input_sockets['Data'], 'shape'):
msg = 'Input socket "Data" must be an N-D Array (with a "shape" attribute)'
raise ValueError(msg)
if self.active_socket_set == 'Axis Expr':
ufunc = jnp.ufunc(input_sockets['Reducer'], nin=2, nout=1)
return ufunc.reduce(input_sockets['Data'], axis=input_sockets['Axis'])
if self.active_socket_set == 'By Axis':
## Dimension Reduction
# ('SQUEEZE', 'Squeeze', '(*, 1, *) -> (*, *)'),
# Accumulation
if props['operation'] == 'SUM':
return jnp.sum(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'PROD':
return jnp.prod(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'MIN':
return jnp.min(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'MAX':
return jnp.max(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'P2P':
return jnp.p2p(input_sockets['Data'], axis=input_sockets['Axis'])
# Stats
if props['operation'] == 'MEAN':
return jnp.mean(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'MEDIAN':
return jnp.median(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'STDDEV':
return jnp.std(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'VARIANCE':
return jnp.var(input_sockets['Data'], axis=input_sockets['Axis'])
# Dimension Reduction
if props['operation'] == 'SQUEEZE':
return jnp.squeeze(input_sockets['Data'], axis=input_sockets['Axis'])
msg = 'Operation invalid'
raise ValueError(msg)
####################
# - Blender Registration
####################
BL_REGISTER = [
ReduceMathNode,
]
BL_NODES = {ct.NodeType.ReduceMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}

View File

@ -20,7 +20,7 @@ class VizNode(base.MaxwellSimNode):
#################### ####################
# - Sockets # - Sockets
#################### ####################
input_sockets = { input_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(), 'Data': sockets.AnySocketDef(),
'Freq': sockets.PhysicalFreqSocketDef(), 'Freq': sockets.PhysicalFreqSocketDef(),
} }
@ -72,20 +72,12 @@ class VizNode(base.MaxwellSimNode):
props: dict, props: dict,
unit_systems: dict, unit_systems: dict,
): ):
selected_data = jnp.array( managed_objs['plot'].map_2d_to_image(
input_sockets['Data'].sel(f=input_sockets['Freq'], method='nearest') input_sockets['Data'].as_bound_jax_func(),
)
managed_objs['plot'].xyzf_to_image(
selected_data,
colormap=props['colormap'], colormap=props['colormap'],
bl_select=True, bl_select=True,
) )
# @events.on_init()
# def on_init(self):
# self.on_changed_inputs()
#################### ####################
# - Blender Registration # - Blender Registration

View File

@ -42,6 +42,15 @@ class SocketDef(pyd.BaseModel, abc.ABC):
# - SocketDef # - SocketDef
#################### ####################
class MaxwellSimSocket(bpy.types.NodeSocket): class MaxwellSimSocket(bpy.types.NodeSocket):
"""A specialized Blender socket for nodes in a Maxwell simulation.
Attributes:
instance_id: A unique ID attached to a particular socket instance.
Guaranteed to be unchanged so long as the socket lives.
Used as a socket-specific cache index.
locked: The lock-state of a particular socket, which determines the socket's user editability
"""
# Fundamentals # Fundamentals
socket_type: ct.SocketType socket_type: ct.SocketType
bl_label: str bl_label: str
@ -73,21 +82,53 @@ class MaxwellSimSocket(bpy.types.NodeSocket):
#################### ####################
# - Initialization # - Initialization
#################### ####################
def __init_subclass__(cls, **kwargs: typ.Any): @classmethod
super().__init_subclass__(**kwargs) def set_prop(
cls,
prop_name: str,
prop: bpy.types.Property,
no_update: bool = False,
update_with_name: str | None = None,
**kwargs,
) -> None:
"""Adds a Blender property to a class via `__annotations__`, so it initializes with any subclass.
# Setup Blender ID for Node Notes:
if not hasattr(cls, 'socket_type'): - Blender properties can't be set within `__init_subclass__` simply by adding attributes to the class; they must be added as type annotations.
msg = f"Socket class {cls} does not define 'socket_type'" - Must be called **within** `__init_subclass__`.
raise ValueError(msg)
cls.bl_idname = str(cls.socket_type.value)
# Setup Locked Property for Node Parameters:
cls.__annotations__['locked'] = bpy.props.BoolProperty( name: The name of the property to set.
name='Locked State', prop: The `bpy.types.Property` to instantiate and attach..
description="The lock-state of a particular socket, which determines the socket's user editability", no_update: Don't attach a `self.sync_prop()` callback to the property's `update`.
default=False, """
_update_with_name = prop_name if update_with_name is None else update_with_name
extra_kwargs = (
{
'update': lambda self, context: self.sync_prop(
_update_with_name, context
),
}
if not no_update
else {}
) )
cls.__annotations__[prop_name] = prop(
**kwargs,
**extra_kwargs,
)
def __init_subclass__(cls, **kwargs: typ.Any):
log.debug('Initializing Socket: %s', cls.socket_type)
super().__init_subclass__(**kwargs)
# cls._assert_attrs_valid()
# Socket Properties
## Identifiers
cls.bl_idname: str = str(cls.socket_type.value)
cls.set_prop('instance_id', bpy.props.StringProperty, no_update=True)
## Special States
cls.set_prop('locked', bpy.props.BoolProperty, no_update=True, default=False)
# Setup Style # Setup Style
cls.socket_color = ct.SOCKET_COLORS[cls.socket_type] cls.socket_color = ct.SOCKET_COLORS[cls.socket_type]

View File

@ -1,11 +1,12 @@
from . import any as any_socket from . import any as any_socket
from . import bool as bool_socket from . import bool as bool_socket
from . import file_path, string from . import expr, file_path, string
AnySocketDef = any_socket.AnySocketDef AnySocketDef = any_socket.AnySocketDef
BoolSocketDef = bool_socket.BoolSocketDef BoolSocketDef = bool_socket.BoolSocketDef
FilePathSocketDef = file_path.FilePathSocketDef
StringSocketDef = string.StringSocketDef StringSocketDef = string.StringSocketDef
FilePathSocketDef = file_path.FilePathSocketDef
ExprSocketDef = expr.ExprSocketDef
BL_REGISTER = [ BL_REGISTER = [
@ -13,4 +14,5 @@ BL_REGISTER = [
*bool_socket.BL_REGISTER, *bool_socket.BL_REGISTER,
*string.BL_REGISTER, *string.BL_REGISTER,
*file_path.BL_REGISTER, *file_path.BL_REGISTER,
*expr.BL_REGISTER,
] ]

View File

@ -0,0 +1,80 @@
import bpy
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils.pydantic_sympy import SympyExpr
from ... import bl_cache
from ... import contracts as ct
from .. import base
class ExprBLSocket(base.MaxwellSimSocket):
socket_type = ct.SocketType.Expr
bl_label = 'Expr'
####################
# - Properties
####################
raw_value: bpy.props.StringProperty(
name='Expr',
description='Represents a symbolic expression',
default='',
update=(lambda self, context: self.sync_prop('raw_value', context)),
)
symbols: list[sp.Symbol] = bl_cache.BLField([])
## TODO: Way of assigning assumptions to symbols.
## TODO: Dynamic add/remove of symbols
####################
# - Socket UI
####################
def draw_value(self, col: bpy.types.UILayout) -> None:
col.prop(self, 'raw_value', text='')
####################
# - Computation of Default Value
####################
@property
def value(self) -> sp.Expr:
return sp.sympify(
self.raw_value,
strict=False,
convert_xor=True,
).subs(spux.ALL_UNIT_SYMBOLS)
@value.setter
def value(self, value: str) -> None:
self.raw_value = str(value)
@property
def lazy_value(self) -> sp.Expr:
return ct.LazyDataValue.from_function(
sp.lambdify(self.symbols, self.value, 'jax'),
free_args=(tuple(str(sym) for sym in self.symbols), frozenset()),
supports_jax=True,
)
####################
# - Socket Configuration
####################
class ExprSocketDef(base.SocketDef):
socket_type: ct.SocketType = ct.SocketType.Expr
_x = sp.Symbol('x', real=True)
symbols: list[SympyExpr] = [_x]
default_expr: SympyExpr = _x
def init(self, bl_socket: ExprBLSocket) -> None:
bl_socket.value = self.default_expr
bl_socket.symbols = self.symbols
####################
# - Blender Registration
####################
BL_REGISTER = [
ExprBLSocket,
]

View File

@ -57,3 +57,6 @@ class StringSocketDef(base.SocketDef):
BL_REGISTER = [ BL_REGISTER = [
StringBLSocket, StringBLSocket,
] ]

View File

@ -0,0 +1,63 @@
import dataclasses
import typing as typ
from types import MappingProxyType
import jax
import jax.numpy as jnp
import pandas as pd
# import jaxtyping as jtyp
import sympy.physics.units as spu
import xarray
from . import extra_sympy_units as spux
from . import logger
log = logger.get(__name__)
DimName: typ.TypeAlias = str
Number: typ.TypeAlias = int | float | complex
NumberRange: typ.TypeAlias = jax.Array
@dataclasses.dataclass(kw_only=True)
class JArray:
"""Very simple wrapper for JAX arrays, which includes information about the dimension names and bounds."""
array: jax.Array
dims: dict[DimName, NumberRange]
dim_units: dict[DimName, spu.Quantity]
####################
# - Constructor
####################
@classmethod
def from_xarray(
cls,
xarr: xarray.DataArray,
dim_units: dict[DimName, spu.Quantity] = MappingProxyType({}),
sort_axis: int = -1,
) -> typ.Self:
return cls(
array=jnp.sort(jnp.array(xarr.data), axis=sort_axis),
dims={
dim_name: jnp.array(xarr.get_index(dim_name).values)
for dim_name in xarr.dims
},
dim_units={dim_name: dim_units.get(dim_name) for dim_name in xarr.dims},
)
def idx(self, dim_name: DimName, dim_value: Number) -> int:
found_idx = jnp.searchsorted(self.dims[dim_name], dim_value)
if found_idx == 0:
return found_idx
if found_idx == len(self.dims[dim_name]):
return found_idx - 1
left = self.dims[dim_name][found_idx - 1]
right = self.dims[dim_name][found_idx - 1]
return found_idx - 1 if (dim_value - left) <= (right - dim_value) else found_idx
@property
def dtype(self) -> jnp.dtype:
return self.array.dtype