feat: Math nodes (non-working)
parent
568fc449e8
commit
dfeb65feec
|
@ -14,6 +14,10 @@ dependencies = [
|
|||
"networkx==3.2.*",
|
||||
"rich==12.5.*",
|
||||
"rtree==1.2.*",
|
||||
"jax[cpu]==0.4.26",
|
||||
"msgspec[toml]==0.18.6",
|
||||
"numba==0.59.1",
|
||||
"jaxtyping==0.2.28",
|
||||
# Pin Blender 4.1.0-Compatible Versions
|
||||
## The dependency resolver will report if anything is wonky.
|
||||
"urllib3==1.26.8",
|
||||
|
@ -22,8 +26,6 @@ dependencies = [
|
|||
"idna==3.3",
|
||||
"charset-normalizer==2.0.10",
|
||||
"certifi==2021.10.8",
|
||||
"jax[cpu]>=0.4.26",
|
||||
"msgspec[toml]>=0.18.6",
|
||||
]
|
||||
readme = "README.md"
|
||||
requires-python = "~= 3.11"
|
||||
|
|
|
@ -49,11 +49,14 @@ importlib-metadata==6.11.0
|
|||
jax==0.4.26
|
||||
jaxlib==0.4.26
|
||||
# via jax
|
||||
jaxtyping==0.2.28
|
||||
jmespath==1.0.1
|
||||
# via boto3
|
||||
# via botocore
|
||||
kiwisolver==1.4.5
|
||||
# via matplotlib
|
||||
llvmlite==0.42.0
|
||||
# via numba
|
||||
locket==1.0.0
|
||||
# via partd
|
||||
matplotlib==3.8.3
|
||||
|
@ -65,13 +68,16 @@ mpmath==1.3.0
|
|||
# via sympy
|
||||
msgspec==0.18.6
|
||||
networkx==3.2
|
||||
numba==0.59.1
|
||||
numpy==1.24.3
|
||||
# via contourpy
|
||||
# via h5py
|
||||
# via jax
|
||||
# via jaxlib
|
||||
# via jaxtyping
|
||||
# via matplotlib
|
||||
# via ml-dtypes
|
||||
# via numba
|
||||
# via opt-einsum
|
||||
# via scipy
|
||||
# via shapely
|
||||
|
@ -142,6 +148,8 @@ toolz==0.12.1
|
|||
# via dask
|
||||
# via partd
|
||||
trimesh==4.2.0
|
||||
typeguard==2.13.3
|
||||
# via jaxtyping
|
||||
types-pyyaml==6.0.12.20240311
|
||||
# via responses
|
||||
typing-extensions==4.10.0
|
||||
|
|
|
@ -48,11 +48,14 @@ importlib-metadata==6.11.0
|
|||
jax==0.4.26
|
||||
jaxlib==0.4.26
|
||||
# via jax
|
||||
jaxtyping==0.2.28
|
||||
jmespath==1.0.1
|
||||
# via boto3
|
||||
# via botocore
|
||||
kiwisolver==1.4.5
|
||||
# via matplotlib
|
||||
llvmlite==0.42.0
|
||||
# via numba
|
||||
locket==1.0.0
|
||||
# via partd
|
||||
matplotlib==3.8.3
|
||||
|
@ -64,13 +67,16 @@ mpmath==1.3.0
|
|||
# via sympy
|
||||
msgspec==0.18.6
|
||||
networkx==3.2
|
||||
numba==0.59.1
|
||||
numpy==1.24.3
|
||||
# via contourpy
|
||||
# via h5py
|
||||
# via jax
|
||||
# via jaxlib
|
||||
# via jaxtyping
|
||||
# via matplotlib
|
||||
# via ml-dtypes
|
||||
# via numba
|
||||
# via opt-einsum
|
||||
# via scipy
|
||||
# via shapely
|
||||
|
@ -140,6 +146,8 @@ toolz==0.12.1
|
|||
# via dask
|
||||
# via partd
|
||||
trimesh==4.2.0
|
||||
typeguard==2.13.3
|
||||
# via jaxtyping
|
||||
types-pyyaml==6.0.12.20240311
|
||||
# via responses
|
||||
typing-extensions==4.10.0
|
||||
|
|
|
@ -4,8 +4,9 @@ import functools
|
|||
import typing as typ
|
||||
from types import MappingProxyType
|
||||
|
||||
# import colour ## TODO
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numba
|
||||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
import typing_extensions as typx
|
||||
|
@ -15,66 +16,46 @@ from ....utils import sci_constants as constants
|
|||
from .socket_types import SocketType
|
||||
|
||||
|
||||
class DataFlowKind(enum.StrEnum):
|
||||
"""Defines a shape/kind of data that may flow through a node tree.
|
||||
class FlowKind(enum.StrEnum):
|
||||
"""Defines a kind of data that can flow between nodes.
|
||||
|
||||
Since a node socket may define one of each, we can support several related kinds of data flow through the same node-graph infrastructure.
|
||||
Each node link can be thought to contain **multiple pipelines for data to flow along**.
|
||||
Each pipeline is cached incrementally, and independently, of the others.
|
||||
Thus, the same socket can easily support several kinds of related data flow at the same time.
|
||||
|
||||
Attributes:
|
||||
Value: A value without any unknown symbols.
|
||||
- Basic types aka. float, int, list, string, etc. .
|
||||
- Exotic (immutable-ish) types aka. numpy array, KDTree, etc. .
|
||||
- A usable constructed object, ex. a `tidy3d.Box`.
|
||||
- Expressions (`sp.Expr`) that don't have unknown variables.
|
||||
- Lazy sequences aka. generators, with all data bound.
|
||||
SpectralValue: A value defined along a spectral range.
|
||||
- {`np.array`
|
||||
|
||||
LazyValue: An object which, when given new data, can make many values.
|
||||
- An `sp.Expr`, which might need `simplify`ing, `jax` JIT'ing, unit cancellations, variable substitutions, etc. before use.
|
||||
- Lazy objects, for which all parameters aren't yet known.
|
||||
- A computational graph aka. `aesara`, which may even need to be handled before
|
||||
|
||||
Capabilities: A `ValueCapability` object providing compatibility.
|
||||
|
||||
# Value Data Flow
|
||||
Simply passing values is the simplest and easiest use case.
|
||||
|
||||
This doesn't mean it's "dumb" - ex. a `sp.Expr` might, before use, have `simplify`, rewriting, unit cancellation, etc. run.
|
||||
All of this is okay, as long as there is no *introduction of new data* ex. variable substitutions.
|
||||
|
||||
|
||||
# Lazy Value Data Flow
|
||||
By passing (essentially) functions, one supports:
|
||||
- **Lightness**: While lazy values can be made expensive to construct, they will generally not be nearly as heavy to handle when trying to work with ex. operations on voxel arrays.
|
||||
- **Performance**: Parameterizing ex. `sp.Expr` with variables allows one to build very optimized functions, which can make ex. node graph updates very fast if the only operation run is the `jax` JIT'ed function (aka. GPU accelerated) generated from the final full expression.
|
||||
- **Numerical Stability**: Libraries like `aesara` build a computational graph, which can be automatically rewritten to avoid many obvious conditioning / cancellation errors.
|
||||
- **Lazy Output**: The goal of a node-graph may not be the definition of a single value, but rather, a parameterized expression for generating *many values* with known properties. This is especially interesting for use cases where one wishes to build an optimization step using nodes.
|
||||
|
||||
|
||||
# Capability Passing
|
||||
By being able to pass "capabilities" next to other kinds of values, nodes can quickly determine whether a given link is valid without having to actually compute it.
|
||||
|
||||
|
||||
# Lazy Parameter Value
|
||||
When using parameterized LazyValues, one may wish to independently pass parameter values through the graph, so they can be inserted into the final (cached) high-performance expression without.
|
||||
|
||||
The advantage of using a different data flow would be changing this kind of value would ONLY invalidate lazy parameter value caches, which would allow an incredibly fast path of getting the value into the lazy expression for high-performance computation.
|
||||
|
||||
Implementation TBD - though, ostensibly, one would have a "parameter" node which both would only provide a LazyValue (aka. a symbolic variable), but would also be able to provide a LazyParamValue, which would be a particular value of some kind (probably via the `value` of some other node socket).
|
||||
Capabilities: Describes a socket's linkeability with other sockets.
|
||||
Links between sockets with incompatible capabilities will be rejected.
|
||||
This doesn't need to be defined normally, as there is a default.
|
||||
However, in some cases, defining it manually to control linkeability more granularly may be desirable.
|
||||
Value: A generic object, which is "directly usable".
|
||||
This should be chosen when a more specific flow kind doesn't apply.
|
||||
Array: An object with dimensions, and possibly a unit.
|
||||
Whenever a `Value` is defined, a single-element `list` will also be generated by default as `Array`
|
||||
However, for any other array-like variants (or sockets that only represent array-like objects), `Array` should be defined manually.
|
||||
LazyValueFunc: A composable function.
|
||||
Can be used to represent computations for which all data is not yet known, or for which just-in-time compilation can drastically increase performance.
|
||||
LazyArrayRange: An object that generates an `Array` from range information (start/stop/step/spacing).
|
||||
This should be used instead of `Array` whenever possible.
|
||||
Param: An object providing data to complete `Lazy` data.
|
||||
For example,
|
||||
Info: An object providing context about other flows.
|
||||
For example,
|
||||
"""
|
||||
|
||||
Capabilities = enum.auto()
|
||||
|
||||
# Values
|
||||
Value = enum.auto()
|
||||
ValueArray = enum.auto()
|
||||
ValueSpectrum = enum.auto()
|
||||
Array = enum.auto()
|
||||
|
||||
# Lazy
|
||||
LazyValue = enum.auto()
|
||||
LazyValueRange = enum.auto()
|
||||
LazyValueSpectrum = enum.auto()
|
||||
LazyArrayRange = enum.auto()
|
||||
|
||||
# Auxiliary
|
||||
Param = enum.auto()
|
||||
Info = enum.auto()
|
||||
|
||||
@classmethod
|
||||
def scale_to_unit_system(cls, kind: typ.Self, value, socket_type, unit_system):
|
||||
|
@ -85,7 +66,7 @@ class DataFlowKind(enum.StrEnum):
|
|||
unit_system[socket_type],
|
||||
)
|
||||
)
|
||||
if kind == cls.LazyValueRange:
|
||||
if kind == cls.LazyArrayRange:
|
||||
return value.rescale_to_unit(unit_system[socket_type])
|
||||
|
||||
msg = 'Tried to scale unknown kind'
|
||||
|
@ -93,12 +74,12 @@ class DataFlowKind(enum.StrEnum):
|
|||
|
||||
|
||||
####################
|
||||
# - Data Structures: Capabilities
|
||||
# - Capabilities
|
||||
####################
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class DataCapabilities:
|
||||
class CapabilitiesFlow:
|
||||
socket_type: SocketType
|
||||
active_kind: DataFlowKind
|
||||
active_kind: FlowKind
|
||||
|
||||
is_universal: bool = False
|
||||
|
||||
|
@ -110,13 +91,16 @@ class DataCapabilities:
|
|||
|
||||
|
||||
####################
|
||||
# - Data Structures: Non-Lazy
|
||||
# - Value
|
||||
####################
|
||||
DataValue: typ.TypeAlias = typ.Any
|
||||
ValueFlow: typ.TypeAlias = typ.Any
|
||||
|
||||
|
||||
####################
|
||||
# - Value Array
|
||||
####################
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class DataValueArray:
|
||||
class ArrayFlow:
|
||||
"""A simple, flat array of values with an optionally-attached unit.
|
||||
|
||||
Attributes:
|
||||
|
@ -125,69 +109,105 @@ class DataValueArray:
|
|||
None if unitless.
|
||||
"""
|
||||
|
||||
values: typ.Sequence[DataValue]
|
||||
values: jax.Array
|
||||
unit: spu.Quantity | None
|
||||
|
||||
|
||||
####################
|
||||
# - Lazy Value Func
|
||||
####################
|
||||
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], ValueFlow]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class DataValueSpectrum:
|
||||
"""A numerical representation of a spectral distribution.
|
||||
class LazyValueFuncFlow:
|
||||
r"""Encapsulates a lazily evaluated data value as a composable function with bound and free arguments.
|
||||
|
||||
- **Bound Args**: Arguments that are realized when **defining** the lazy value.
|
||||
Both positional values and keyword values are supported.
|
||||
- **Free Args**: Arguments that are specified when evaluating the lazy value.
|
||||
Both positional values and keyword values are supported.
|
||||
|
||||
The **root function** is encapsulated using `from_function`, and must accept arguments in the following order:
|
||||
|
||||
$$
|
||||
f_0:\ \ \ \ (\underbrace{b_1, b_2, ...}_{\text{Bound}}\ ,\ \underbrace{r_1, r_2, ...}_{\text{Free}}) \to \text{output}_0
|
||||
$$
|
||||
|
||||
Subsequent **composed functions** are encapsulated from the _root function_, and are created with `root_function.compose`.
|
||||
They must accept arguments in the following order:
|
||||
|
||||
$$
|
||||
f_k:\ \ \ \ (\underbrace{b_1, b_2, ...}_{\text{Bound}}\ ,\ \text{output}_{k-1} ,\ \underbrace{r_p, r_{p+1}, ...}_{\text{Free}}) \to \text{output}_k
|
||||
$$
|
||||
|
||||
Attributes:
|
||||
wls: A 1D `numpy` float array of wavelength values.
|
||||
wls_unit: The unit of wavelengths, as length dimension.
|
||||
values: A 1D `numpy` float array of values corresponding to wavelength values.
|
||||
values_unit: The unit of the value, as arbitrary dimension.
|
||||
freqs_unit: The unit of the value, as arbitrary dimension.
|
||||
function: The function to be lazily evaluated.
|
||||
bound_args: Arguments that will be packaged into function, which can't be later modifier.
|
||||
func_kwargs: Arguments to be specified by the user at the time of use.
|
||||
supports_jax: Whether the contained `self.function` can be compiled with JAX's JIT compiler.
|
||||
supports_numba: Whether the contained `self.function` can be compiled with Numba's JIT compiler.
|
||||
"""
|
||||
|
||||
# Wavelength
|
||||
wls: np.array
|
||||
wls_unit: spu.Quantity
|
||||
func: LazyFunction
|
||||
func_kwargs: dict[str, type]
|
||||
supports_jax: bool = False
|
||||
supports_numba: bool = False
|
||||
|
||||
# Value
|
||||
values: np.array
|
||||
values_unit: spu.Quantity
|
||||
@staticmethod
|
||||
def from_func(
|
||||
func: LazyFunction,
|
||||
supports_jax: bool = False,
|
||||
supports_numba: bool = False,
|
||||
**func_kwargs: dict[str, type],
|
||||
) -> typ.Self:
|
||||
return LazyValueFuncFlow(
|
||||
func=func,
|
||||
func_kwargs=func_kwargs,
|
||||
supports_jax=supports_jax,
|
||||
supports_numba=supports_numba,
|
||||
)
|
||||
|
||||
# Frequency
|
||||
freqs_unit: spu.Quantity = spu.hertz
|
||||
# Composition
|
||||
def compose_within(
|
||||
self,
|
||||
enclosing_func: LazyFunction,
|
||||
supports_jax: bool = False,
|
||||
supports_numba: bool = False,
|
||||
**enclosing_func_kwargs: dict[str, type],
|
||||
) -> typ.Self:
|
||||
return LazyValueFuncFlow(
|
||||
function=lambda **kwargs: enclosing_func(
|
||||
self.func(**{k: v for k, v in kwargs if k in self.func_kwargs}),
|
||||
**kwargs,
|
||||
),
|
||||
func_kwargs=self.func_kwargs | enclosing_func_kwargs,
|
||||
supports_jax=self.supports_jax and supports_jax,
|
||||
supports_numba=self.supports_numba and supports_numba,
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def freqs(self) -> np.array:
|
||||
"""The spectral frequencies, computed from the wavelengths.
|
||||
def func_jax(self) -> LazyFunction:
|
||||
if self.supports_jax:
|
||||
return jax.jit(self.func)
|
||||
|
||||
Frequencies are NOT reversed, so as to preserve the by-index mapping to `DataValueSpectrum.values`.
|
||||
msg = 'Can\'t express LazyValueFuncFlow as JAX function (using jax.jit), since "self.supports_jax" is False'
|
||||
raise ValueError(msg)
|
||||
|
||||
Returns:
|
||||
Frequencies, as a unitless `numpy` array.
|
||||
Use `DataValueSpectrum.wls_unit` to interpret this return value.
|
||||
"""
|
||||
unitless_speed_of_light = spux.sympy_to_python(
|
||||
spux.scale_to_unit(
|
||||
constants.vac_speed_of_light, (self.wl_unit / self.freq_unit)
|
||||
)
|
||||
)
|
||||
return unitless_speed_of_light / self.wls
|
||||
@functools.cached_property
|
||||
def func_numba(self) -> LazyFunction:
|
||||
if self.supports_numba:
|
||||
return numba.jit(self.func)
|
||||
|
||||
# TODO: Colour Library
|
||||
# def as_colour_sd(self) -> colour.SpectralDistribution:
|
||||
# """Returns the `colour` representation of this spectral distribution, ideal for plotting and colorimetric analysis."""
|
||||
# return colour.SpectralDistribution(data=self.values, domain=self.wls)
|
||||
msg = 'Can\'t express LazyValueFuncFlow as Numba function (using numba.jit), since "self.supports_numba" is False'
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
####################
|
||||
# - Data Structures: Lazy
|
||||
# - Lazy Array Range
|
||||
####################
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class LazyDataValue:
|
||||
callback: typ.Callable[[...], [DataValue]]
|
||||
|
||||
def realize(self, *args: list[DataValue]) -> DataValue:
|
||||
return self.callback(*args)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class LazyDataValueRange:
|
||||
class LazyArrayRangeFlow:
|
||||
symbols: set[sp.Symbol]
|
||||
|
||||
start: sp.Basic
|
||||
|
@ -200,7 +220,7 @@ class LazyDataValueRange:
|
|||
|
||||
def rescale_to_unit(self, unit: spu.Quantity) -> typ.Self:
|
||||
if self.has_unit:
|
||||
return LazyDataValueRange(
|
||||
return LazyArrayRangeFlow(
|
||||
symbols=self.symbols,
|
||||
has_unit=self.has_unit,
|
||||
unit=unit,
|
||||
|
@ -219,7 +239,7 @@ class LazyDataValueRange:
|
|||
reverse: bool = False,
|
||||
) -> typ.Self:
|
||||
"""Call a function on both bounds (start and stop), creating a new `LazyDataValueRange`."""
|
||||
return LazyDataValueRange(
|
||||
return LazyArrayRangeFlow(
|
||||
symbols=self.symbols,
|
||||
has_unit=self.has_unit,
|
||||
unit=self.unit,
|
||||
|
@ -234,8 +254,8 @@ class LazyDataValueRange:
|
|||
)
|
||||
|
||||
def realize(
|
||||
self, symbol_values: dict[sp.Symbol, DataValue] = MappingProxyType({})
|
||||
) -> DataValueArray:
|
||||
self, symbol_values: dict[sp.Symbol, ValueFlow] = MappingProxyType({})
|
||||
) -> ArrayFlow:
|
||||
# Realize Symbols
|
||||
if not self.has_unit:
|
||||
start = spux.sympy_to_python(self.start.subs(symbol_values))
|
||||
|
@ -250,85 +270,25 @@ class LazyDataValueRange:
|
|||
|
||||
# Return Linspace / Logspace
|
||||
if self.scaling == 'lin':
|
||||
return DataValueArray(
|
||||
values=np.linspace(start, stop, self.steps), unit=self.unit
|
||||
return ArrayFlow(
|
||||
values=jnp.linspace(start, stop, self.steps), unit=self.unit
|
||||
)
|
||||
if self.scaling == 'geom':
|
||||
return DataValueArray(np.geomspace(start, stop, self.steps), self.unit)
|
||||
return ArrayFlow(jnp.geomspace(start, stop, self.steps), self.unit)
|
||||
if self.scaling == 'log':
|
||||
return DataValueArray(np.logspace(start, stop, self.steps), self.unit)
|
||||
return ArrayFlow(jnp.logspace(start, stop, self.steps), self.unit)
|
||||
|
||||
raise NotImplementedError
|
||||
msg = f'ArrayFlow scaling method {self.scaling} is unsupported'
|
||||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class LazyDataValueSpectrum:
|
||||
wl_unit: spu.Quantity
|
||||
value_unit: spu.Quantity
|
||||
value_expr: sp.Expr
|
||||
|
||||
symbols: tuple[sp.Symbol, ...] = ()
|
||||
freq_symbol: sp.Symbol = sp.Symbol('lamda') # noqa: RUF009
|
||||
|
||||
def rescale_to_unit(self, unit: spu.Quantity) -> typ.Self:
|
||||
raise NotImplementedError
|
||||
|
||||
@functools.cached_property
|
||||
def as_func(self) -> typ.Callable[[DataValue, ...], DataValue]:
|
||||
"""Generates an optimized function for numerical evaluation of the spectral expression."""
|
||||
return sp.lambdify([self.freq_symbol, *self.symbols], self.value_expr)
|
||||
|
||||
def realize(
|
||||
self, wl_range: DataValueArray, symbol_values: tuple[DataValue, ...]
|
||||
) -> DataValueSpectrum:
|
||||
r"""Realizes the parameterized spectral function as a numerical spectral distribution.
|
||||
|
||||
Parameters:
|
||||
wl_range: The lazy wavelength range to build the concrete spectral distribution with.
|
||||
symbol_values: Numerical values for each symbol, in the same order as defined in `LazyDataValueSpectrum.symbols`.
|
||||
The wavelength symbol ($\lambda$ by default) always goes first.
|
||||
_This is used to call the spectral function using the output of `.as_func()`._
|
||||
|
||||
Returns:
|
||||
The concrete, numerical spectral distribution.
|
||||
"""
|
||||
return DataValueSpectrum(
|
||||
wls=wl_range.values,
|
||||
wls_unit=self.wl_unit,
|
||||
values=self.as_func(*list(symbol_values.values())),
|
||||
values_unit=self.value_unit,
|
||||
)
|
||||
####################
|
||||
# - Param
|
||||
####################
|
||||
ParamFlow: typ.TypeAlias = dict[str, typ.Any]
|
||||
|
||||
|
||||
#
|
||||
#
|
||||
#####################
|
||||
## - Data Pipeline
|
||||
#####################
|
||||
# @dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
# class DataPipelineDim:
|
||||
# unit: spu.Quantity | None
|
||||
#
|
||||
# class DataPipelineDimType(enum.StrEnum):
|
||||
# # Map Inputs
|
||||
# Time = enum.auto()
|
||||
# Freq = enum.auto()
|
||||
# Space3D = enum.auto()
|
||||
# DiffOrder = enum.auto()
|
||||
#
|
||||
# # Map Inputs
|
||||
# Power = enum.auto()
|
||||
# EVec = enum.auto()
|
||||
# HVec = enum.auto()
|
||||
# RelPerm = enum.auto()
|
||||
#
|
||||
#
|
||||
# @dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
# class LazyDataPipeline:
|
||||
# dims: list[DataPipelineDim]
|
||||
#
|
||||
# def _callable(self):
|
||||
# """JITs the current pipeline of functions with `jax`."""
|
||||
#
|
||||
# def __call__(self):
|
||||
# pass
|
||||
####################
|
||||
# - Lazy Value Func
|
||||
####################
|
||||
InfoFlow: typ.TypeAlias = dict[str, typ.Any]
|
||||
|
|
|
@ -7,6 +7,7 @@ SOCKET_COLORS = {
|
|||
ST.Bool: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey
|
||||
ST.String: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey
|
||||
ST.FilePath: (0.6, 0.6, 0.6, 1.0), # Medium Grey
|
||||
ST.Expr: (0.5, 0.5, 0.5, 1.0), # Medium Grey
|
||||
# Number
|
||||
ST.IntegerNumber: (0.5, 0.5, 1.0, 1.0), # Light Blue
|
||||
ST.RationalNumber: (0.4, 0.4, 0.9, 1.0), # Medium Light Blue
|
||||
|
|
|
@ -6,6 +6,7 @@ SOCKET_SHAPES = {
|
|||
ST.Bool: 'CIRCLE',
|
||||
ST.String: 'CIRCLE',
|
||||
ST.FilePath: 'CIRCLE',
|
||||
ST.Expr: 'CIRCLE',
|
||||
# Number
|
||||
ST.IntegerNumber: 'CIRCLE',
|
||||
ST.RationalNumber: 'CIRCLE',
|
||||
|
|
|
@ -14,6 +14,7 @@ class SocketType(BlenderTypeEnum):
|
|||
String = enum.auto()
|
||||
FilePath = enum.auto()
|
||||
Color = enum.auto()
|
||||
Expr = enum.auto()
|
||||
|
||||
# Number
|
||||
IntegerNumber = enum.auto()
|
||||
|
|
|
@ -38,8 +38,8 @@ def apply_colormap(normalized_data, colormap):
|
|||
|
||||
|
||||
@jax.jit
|
||||
def rgba_image_from_xyzf__viridis(xyz_freq):
|
||||
amplitude = jnp.abs(jnp.squeeze(xyz_freq))
|
||||
def rgba_image_from_2d_map__viridis(map_2d):
|
||||
amplitude = jnp.abs(map_2d)
|
||||
amplitude_normalized = (amplitude - amplitude.min()) / (
|
||||
amplitude.max() - amplitude.min()
|
||||
)
|
||||
|
@ -49,8 +49,8 @@ def rgba_image_from_xyzf__viridis(xyz_freq):
|
|||
|
||||
|
||||
@jax.jit
|
||||
def rgba_image_from_xyzf__grayscale(xyz_freq):
|
||||
amplitude = jnp.abs(jnp.squeeze(xyz_freq))
|
||||
def rgba_image_from_2d_map__grayscale(map_2d):
|
||||
amplitude = jnp.abs(map_2d)
|
||||
amplitude_normalized = (amplitude - amplitude.min()) / (
|
||||
amplitude.max() - amplitude.min()
|
||||
)
|
||||
|
@ -59,21 +59,19 @@ def rgba_image_from_xyzf__grayscale(xyz_freq):
|
|||
return jnp.dstack((rgb_array, alpha_channel))
|
||||
|
||||
|
||||
def rgba_image_from_xyzf(xyz_freq, colormap: str | None = None):
|
||||
"""RGBA Image from Squeezable XYZ-Freq w/fixed freq.
|
||||
def rgba_image_from_2d_map(map_2d, colormap: str | None = None):
|
||||
"""RGBA Image from a map of 2D coordinates to values.
|
||||
|
||||
Parameters:
|
||||
xyz_freq: Shape (xlen, ylen, zlen), one dimension has length 1.
|
||||
width_px: Pixel width to resize the image to.
|
||||
height: Pixel height to resize the image to.
|
||||
map_2d: Shape (width, height, value).
|
||||
|
||||
Returns:
|
||||
Image as a JAX array of shape (height, width, 3)
|
||||
Image as a JAX array of shape (height, width, 4)
|
||||
"""
|
||||
if colormap == 'VIRIDIS':
|
||||
return rgba_image_from_xyzf__viridis(xyz_freq)
|
||||
return rgba_image_from_2d_map__viridis(map_2d)
|
||||
if colormap == 'GRAYSCALE':
|
||||
return rgba_image_from_xyzf__grayscale(xyz_freq)
|
||||
return rgba_image_from_2d_map__grayscale(map_2d)
|
||||
|
||||
|
||||
class ManagedBLImage(base.ManagedObj):
|
||||
|
@ -227,11 +225,11 @@ class ManagedBLImage(base.ManagedObj):
|
|||
####################
|
||||
# - Special Methods
|
||||
####################
|
||||
def xyzf_to_image(
|
||||
self, xyz_freq, colormap: str | None = 'VIRIDIS', bl_select: bool = False
|
||||
def map_2d_to_image(
|
||||
self, map_2d, colormap: str | None = 'VIRIDIS', bl_select: bool = False
|
||||
):
|
||||
self.data_to_image(
|
||||
lambda _: rgba_image_from_xyzf(xyz_freq, colormap=colormap),
|
||||
lambda _: rgba_image_from_2d_map(map_2d, colormap=colormap),
|
||||
bl_select=bl_select,
|
||||
)
|
||||
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
from . import extract_data, viz
|
||||
from . import extract_data, viz, math
|
||||
|
||||
BL_REGISTER = [
|
||||
*extract_data.BL_REGISTER,
|
||||
*viz.BL_REGISTER,
|
||||
*math.BL_REGISTER,
|
||||
]
|
||||
BL_NODES = {
|
||||
**extract_data.BL_NODES,
|
||||
**viz.BL_NODES,
|
||||
**math.BL_NODES,
|
||||
}
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import jax.numpy as jnp
|
||||
import sympy.physics.units as spu
|
||||
|
||||
from blender_maxwell.utils import jarray, logger
|
||||
|
||||
from .....utils import logger
|
||||
from ... import contracts as ct
|
||||
from ... import sockets
|
||||
from .. import base, events
|
||||
|
@ -229,8 +232,10 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
@events.computes_output_socket(
|
||||
'Data',
|
||||
props={'sim_data__monitor_name', 'field_data__component'},
|
||||
input_sockets={'Field Data'},
|
||||
input_sockets_optional={'Field Data': True},
|
||||
)
|
||||
def compute_extracted_data(self, props: dict):
|
||||
def compute_extracted_data(self, props: dict, input_sockets: dict):
|
||||
if self.active_socket_set == 'Sim Data':
|
||||
if (
|
||||
CACHE_SIM_DATA.get(self.instance_id) is None
|
||||
|
@ -242,12 +247,21 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
return sim_data.monitor_data[props['sim_data__monitor_name']]
|
||||
|
||||
elif self.active_socket_set == 'Field Data': # noqa: RET505
|
||||
field_data = self._compute_input('Field Data')
|
||||
return getattr(field_data, props['field_data__component'])
|
||||
xarr = getattr(input_sockets['Field Data'], props['field_data__component'])
|
||||
|
||||
return jarray.JArray.from_xarray(
|
||||
xarr,
|
||||
dim_units={
|
||||
'x': spu.um,
|
||||
'y': spu.um,
|
||||
'z': spu.um,
|
||||
'f': spu.hertz,
|
||||
},
|
||||
)
|
||||
|
||||
elif self.active_socket_set == 'Flux Data':
|
||||
flux_data = self._compute_input('Flux Data')
|
||||
return flux_data.flux
|
||||
return jnp.array(flux_data.flux)
|
||||
|
||||
msg = f'Tried to get data from unknown output socket in "{self.bl_label}"'
|
||||
raise RuntimeError(msg)
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
from . import map_math, filter_math, reduce_math, operate_math
|
||||
|
||||
BL_REGISTER = [
|
||||
*map_math.BL_REGISTER,
|
||||
*filter_math.BL_REGISTER,
|
||||
*reduce_math.BL_REGISTER,
|
||||
*operate_math.BL_REGISTER,
|
||||
]
|
||||
BL_NODES = {
|
||||
**map_math.BL_NODES,
|
||||
**filter_math.BL_NODES,
|
||||
**reduce_math.BL_NODES,
|
||||
**operate_math.BL_NODES,
|
||||
}
|
|
@ -0,0 +1,121 @@
|
|||
import functools
|
||||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from blender_maxwell.utils import logger
|
||||
|
||||
from .... import contracts as ct
|
||||
from .... import sockets
|
||||
from ... import base, events
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
||||
# @functools.partial(jax.jit, static_argnames=('fixed_axis', 'fixed_axis_value'))
|
||||
# jax.jit
|
||||
def fix_axis(data, fixed_axis: int, fixed_axis_value: float):
|
||||
log.critical(data.shape)
|
||||
# Select Values of Fixed Axis
|
||||
fixed_axis_values = data[
|
||||
tuple(slice(None) if i == fixed_axis else 0 for i in range(data.ndim))
|
||||
]
|
||||
log.critical(fixed_axis_values)
|
||||
|
||||
# Compute Nearest Index on Fixed Axis
|
||||
idx_of_nearest = jnp.argmin(jnp.abs(fixed_axis_values - fixed_axis_value))
|
||||
log.critical(idx_of_nearest)
|
||||
|
||||
# Select Values along Fixed Axis Value
|
||||
return jnp.take(data, idx_of_nearest, axis=fixed_axis)
|
||||
|
||||
|
||||
class FilterMathNode(base.MaxwellSimNode):
|
||||
node_type = ct.NodeType.FilterMath
|
||||
bl_label = 'Filter Math'
|
||||
|
||||
input_sockets: typ.ClassVar = {
|
||||
'Data': sockets.AnySocketDef(),
|
||||
}
|
||||
input_socket_sets: typ.ClassVar = {
|
||||
'By Axis Value': {
|
||||
'Axis': sockets.IntegerNumberSocketDef(),
|
||||
'Value': sockets.RealNumberSocketDef(),
|
||||
},
|
||||
'By Axis': {
|
||||
'Axis': sockets.IntegerNumberSocketDef(),
|
||||
},
|
||||
## TODO: bool arrays for comparison/switching/sparse 0-setting/etc. .
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Data': sockets.AnySocketDef(),
|
||||
}
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
operation: bpy.props.EnumProperty(
|
||||
name='Op',
|
||||
description='Operation to reduce the input axis with',
|
||||
items=lambda self, _: self.search_operations(),
|
||||
update=lambda self, context: self.sync_prop('operation', context),
|
||||
)
|
||||
|
||||
def search_operations(self) -> list[tuple[str, str, str]]:
|
||||
items = []
|
||||
if self.active_socket_set == 'By Axis Value':
|
||||
items += [
|
||||
('FIX', 'Fix Coordinate', '(*, N, *) -> (*, *)'),
|
||||
]
|
||||
if self.active_socket_set == 'By Axis':
|
||||
items += [
|
||||
('SQUEEZE', 'Squeeze', '(*, 1, *) -> (*, *)'),
|
||||
]
|
||||
else:
|
||||
items += [('NONE', 'None', 'No operations...')]
|
||||
|
||||
return items
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
if self.active_socket_set != 'Axis Expr':
|
||||
layout.prop(self, 'operation')
|
||||
|
||||
####################
|
||||
# - Compute
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
props={'operation', 'active_socket_set'},
|
||||
input_sockets={'Data', 'Axis', 'Value'},
|
||||
input_sockets_optional={'Axis': True, 'Value': True},
|
||||
)
|
||||
def compute_data(self, props: dict, input_sockets: dict):
|
||||
if not hasattr(input_sockets['Data'], 'shape'):
|
||||
msg = 'Input socket "Data" must be an N-D Array (with a "shape" attribute)'
|
||||
raise ValueError(msg)
|
||||
|
||||
# By Axis Value
|
||||
if props['active_socket_set'] == 'By Axis Value':
|
||||
if props['operation'] == 'FIX':
|
||||
return fix_axis(
|
||||
input_sockets['Data'], input_sockets['Axis'], input_sockets['Value']
|
||||
)
|
||||
|
||||
# By Axis
|
||||
if props['active_socket_set'] == 'By Axis':
|
||||
if props['operation'] == 'SQUEEZE':
|
||||
return jnp.squeeze(input_sockets['Data'], axis=input_sockets['Axis'])
|
||||
|
||||
msg = 'Operation invalid'
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
####################
|
||||
# - Blender Registration
|
||||
####################
|
||||
BL_REGISTER = [
|
||||
FilterMathNode,
|
||||
]
|
||||
BL_NODES = {ct.NodeType.FilterMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}
|
|
@ -0,0 +1,164 @@
|
|||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import sympy as sp
|
||||
|
||||
from blender_maxwell.utils import logger
|
||||
|
||||
from .... import contracts as ct
|
||||
from .... import sockets
|
||||
from ... import base, events
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
||||
class MapMathNode(base.MaxwellSimNode):
|
||||
node_type = ct.NodeType.MapMath
|
||||
bl_label = 'Map Math'
|
||||
|
||||
input_sockets: typ.ClassVar = {
|
||||
'Data': sockets.AnySocketDef(),
|
||||
}
|
||||
input_socket_sets: typ.ClassVar = {
|
||||
'By Element': {},
|
||||
'By Vector': {},
|
||||
'By Matrix': {},
|
||||
'Expr': {
|
||||
'Mapper': sockets.ExprSocketDef(
|
||||
symbols=[sp.Symbol('x')],
|
||||
default_expr=sp.Symbol('x'),
|
||||
),
|
||||
},
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Data': sockets.AnySocketDef(),
|
||||
}
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
operation: bpy.props.EnumProperty(
|
||||
name='Op',
|
||||
description='Operation to apply to the input',
|
||||
items=lambda self, _: self.search_operations(),
|
||||
update=lambda self, context: self.sync_prop('operation', context),
|
||||
)
|
||||
|
||||
def search_operations(self) -> list[tuple[str, str, str]]:
|
||||
items = []
|
||||
if self.active_socket_set == 'By Element':
|
||||
items += [
|
||||
# General
|
||||
('REAL', 'real', 'ℝ(L) (by el)'),
|
||||
('IMAG', 'imag', 'Im(L) (by el)'),
|
||||
('ABS', 'abs', '|L| (by el)'),
|
||||
('SQ', 'square', 'L^2 (by el)'),
|
||||
('SQRT', 'sqrt', 'sqrt(L) (by el)'),
|
||||
('INV_SQRT', '1/sqrt', '1/sqrt(L) (by el)'),
|
||||
# Trigonometry
|
||||
('COS', 'cos', 'cos(L) (by el)'),
|
||||
('SIN', 'sin', 'sin(L) (by el)'),
|
||||
('TAN', 'tan', 'tan(L) (by el)'),
|
||||
('ACOS', 'acos', 'acos(L) (by el)'),
|
||||
('ASIN', 'asin', 'asin(L) (by el)'),
|
||||
('ATAN', 'atan', 'atan(L) (by el)'),
|
||||
]
|
||||
elif self.active_socket_set in 'By Vector':
|
||||
items += [
|
||||
('NORM_2', '2-Norm', '||L||_2 (by Vec)'),
|
||||
]
|
||||
elif self.active_socket_set == 'By Matrix':
|
||||
items += [
|
||||
# Matrix -> Number
|
||||
('DET', 'Determinant', 'det(L) (by Mat)'),
|
||||
('COND', 'Condition', 'κ(L) (by Mat)'),
|
||||
('NORM_FRO', 'Frobenius Norm', '||L||_F (by Mat)'),
|
||||
('RANK', 'Rank', 'rank(L) (by Mat)'),
|
||||
# Matrix -> Array
|
||||
('DIAG', 'Diagonal', 'diag(L) (by Mat)'),
|
||||
('EIG_VALS', 'Eigenvalues', 'eigvals(L) (by Mat)'),
|
||||
('SVD_VALS', 'SVD', 'svd(L) -> diag(Σ) (by Mat)'),
|
||||
# Matrix -> Matrix
|
||||
('INV', 'Invert', 'L^(-1) (by Mat)'),
|
||||
('TRA', 'Transpose', 'L^T (by Mat)'),
|
||||
# Matrix -> Matrices
|
||||
('QR', 'QR', 'L -> Q·R (by Mat)'),
|
||||
('CHOL', 'Cholesky', 'L -> L·Lh (by Mat)'),
|
||||
('SVD', 'SVD', 'L -> U·Σ·Vh (by Mat)'),
|
||||
]
|
||||
else:
|
||||
items += ['EXPR_EL', 'Expr (by el)', 'Expression-defined (by el)']
|
||||
return items
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
if self.active_socket_set not in {'Expr (Element)'}:
|
||||
layout.prop(self, 'operation')
|
||||
|
||||
####################
|
||||
# - Compute
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
props={'active_socket_set', 'operation'},
|
||||
input_sockets={'Data', 'Mapper'},
|
||||
input_socket_kinds={'Mapper': ct.DataFlowKind.LazyValue},
|
||||
input_sockets_optional={'Mapper': True},
|
||||
)
|
||||
def compute_data(self, props: dict, input_sockets: dict):
|
||||
mapping_func: typ.Callable[[jax.Array], jax.Array] = {
|
||||
'By Element': {
|
||||
'REAL': lambda data: jnp.real(data),
|
||||
'IMAG': lambda data: jnp.imag(data),
|
||||
'ABS': lambda data: jnp.abs(data),
|
||||
'SQ': lambda data: jnp.square(data),
|
||||
'SQRT': lambda data: jnp.sqrt(data),
|
||||
'INV_SQRT': lambda data: 1 / jnp.sqrt(data),
|
||||
'COS': lambda data: jnp.cos(data),
|
||||
'SIN': lambda data: jnp.sin(data),
|
||||
'TAN': lambda data: jnp.tan(data),
|
||||
'ACOS': lambda data: jnp.acos(data),
|
||||
'ASIN': lambda data: jnp.asin(data),
|
||||
'ATAN': lambda data: jnp.atan(data),
|
||||
'SINC': lambda data: jnp.sinc(data),
|
||||
},
|
||||
'By Vector': {
|
||||
'NORM_2': lambda data: jnp.norm(data, ord=2, axis=-1),
|
||||
},
|
||||
'By Matrix': {
|
||||
# Matrix -> Number
|
||||
'DET': lambda data: jnp.linalg.det(data),
|
||||
'COND': lambda data: jnp.linalg.cond(data),
|
||||
'NORM_FRO': lambda data: jnp.linalg.matrix_norm(data, ord='fro'),
|
||||
'RANK': lambda data: jnp.linalg.matrix_rank(data),
|
||||
# Matrix -> Vec
|
||||
'DIAG': lambda data: jnp.diag(data),
|
||||
'EIG_VALS': lambda data: jnp.eigvals(data),
|
||||
'SVD_VALS': lambda data: jnp.svdvals(data),
|
||||
# Matrix -> Matrix
|
||||
'INV': lambda data: jnp.inv(data),
|
||||
'TRA': lambda data: jnp.matrix_transpose(data),
|
||||
# Matrix -> Matrices
|
||||
'QR': lambda data: jnp.inv(data),
|
||||
'CHOL': lambda data: jnp.linalg.cholesky(data),
|
||||
'SVD': lambda data: jnp.linalg.svd(data),
|
||||
},
|
||||
'By El (Expr)': {
|
||||
'EXPR_EL': lambda data: input_sockets['Mapper'](data),
|
||||
},
|
||||
}[props['active_socket_set']][props['operation']]
|
||||
|
||||
# Compose w/Lazy Root Function Data
|
||||
return input_sockets['Data'].compose(
|
||||
function=mapping_func,
|
||||
)
|
||||
|
||||
|
||||
####################
|
||||
# - Blender Registration
|
||||
####################
|
||||
BL_REGISTER = [
|
||||
MapMathNode,
|
||||
]
|
||||
BL_NODES = {ct.NodeType.MapMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}
|
|
@ -0,0 +1,138 @@
|
|||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import jax.numpy as jnp
|
||||
|
||||
from blender_maxwell.utils import logger
|
||||
|
||||
from .... import contracts as ct
|
||||
from .... import sockets
|
||||
from ... import base, events
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
||||
class OperateMathNode(base.MaxwellSimNode):
|
||||
node_type = ct.NodeType.OperateMath
|
||||
bl_label = 'Operate Math'
|
||||
|
||||
input_socket_sets: typ.ClassVar = {
|
||||
'Elementwise': {
|
||||
'Data L': sockets.AnySocketDef(),
|
||||
'Data R': sockets.AnySocketDef(),
|
||||
},
|
||||
## TODO: Filter-array building operations
|
||||
'Vec-Vec': {
|
||||
'Data L': sockets.AnySocketDef(),
|
||||
'Data R': sockets.AnySocketDef(),
|
||||
},
|
||||
'Mat-Vec': {
|
||||
'Data L': sockets.AnySocketDef(),
|
||||
'Data R': sockets.AnySocketDef(),
|
||||
},
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Data': sockets.AnySocketDef(),
|
||||
}
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
operation: bpy.props.EnumProperty(
|
||||
name='Op',
|
||||
description='Operation to apply to the two inputs',
|
||||
items=lambda self, _: self.search_operations(),
|
||||
update=lambda self, context: self.sync_prop('operation', context),
|
||||
)
|
||||
|
||||
def search_operations(self) -> list[tuple[str, str, str]]:
|
||||
items = []
|
||||
if self.active_socket_set == 'Elementwise':
|
||||
items = [
|
||||
('ADD', 'Add', 'L + R (by el)'),
|
||||
('SUB', 'Subtract', 'L - R (by el)'),
|
||||
('MUL', 'Multiply', 'L · R (by el)'),
|
||||
('DIV', 'Divide', 'L ÷ R (by el)'),
|
||||
('POW', 'Power', 'L^R (by el)'),
|
||||
('FMOD', 'Trunc Modulo', 'fmod(L,R) (by el)'),
|
||||
('ATAN2', 'atan2', 'atan2(L,R) (by el)'),
|
||||
('HEAVISIDE', 'Heaviside', '{0|L<0 1|L>0 R|L=0} (by el)'),
|
||||
]
|
||||
elif self.active_socket_set in 'Vec | Vec':
|
||||
items = [
|
||||
('DOT', 'Dot', 'L · R'),
|
||||
('CROSS', 'Cross', 'L x R (by last-axis'),
|
||||
]
|
||||
elif self.active_socket_set == 'Mat | Vec':
|
||||
items = [
|
||||
('DOT', 'Dot', 'L · R'),
|
||||
('LIN_SOLVE', 'Lin Solve', 'Lx = R -> x (by last-axis of R)'),
|
||||
('LSQ_SOLVE', 'LSq Solve', 'Lx = R ~> x (by last-axis of R)'),
|
||||
]
|
||||
return items
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
layout.prop(self, 'operation')
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
props={'operation'},
|
||||
input_sockets={'Data L', 'Data R'},
|
||||
)
|
||||
def compute_data(self, props: dict, input_sockets: dict):
|
||||
if self.active_socket_set == 'Elementwise':
|
||||
# Element-Wise Arithmetic
|
||||
if props['operation'] == 'ADD':
|
||||
return input_sockets['Data L'] + input_sockets['Data R']
|
||||
if props['operation'] == 'SUB':
|
||||
return input_sockets['Data L'] - input_sockets['Data R']
|
||||
if props['operation'] == 'MUL':
|
||||
return input_sockets['Data L'] * input_sockets['Data R']
|
||||
if props['operation'] == 'DIV':
|
||||
return input_sockets['Data L'] / input_sockets['Data R']
|
||||
|
||||
# Element-Wise Arithmetic
|
||||
if props['operation'] == 'POW':
|
||||
return input_sockets['Data L'] ** input_sockets['Data R']
|
||||
|
||||
# Binary Trigonometry
|
||||
if props['operation'] == 'ATAN2':
|
||||
return jnp.atan2(input_sockets['Data L'], input_sockets['Data R'])
|
||||
|
||||
# Special Functions
|
||||
if props['operation'] == 'HEAVISIDE':
|
||||
return jnp.heaviside(input_sockets['Data L'], input_sockets['Data R'])
|
||||
|
||||
# Linear Algebra
|
||||
if self.active_socket_set in {'Vec-Vec', 'Mat-Vec'}:
|
||||
if props['operation'] == 'DOT':
|
||||
return jnp.dot(input_sockets['Data L'], input_sockets['Data R'])
|
||||
|
||||
elif self.active_socket_set == 'Vec-Vec':
|
||||
if props['operation'] == 'CROSS':
|
||||
return jnp.cross(input_sockets['Data L'], input_sockets['Data R'])
|
||||
|
||||
elif self.active_socket_set == 'Mat-Vec':
|
||||
if props['operation'] == 'LIN_SOLVE':
|
||||
return jnp.linalg.lstsq(
|
||||
input_sockets['Data L'], input_sockets['Data R']
|
||||
)
|
||||
if props['operation'] == 'LSQ_SOLVE':
|
||||
return jnp.linalg.solve(
|
||||
input_sockets['Data L'], input_sockets['Data R']
|
||||
)
|
||||
|
||||
msg = 'Invalid operation'
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
####################
|
||||
# - Blender Registration
|
||||
####################
|
||||
BL_REGISTER = [
|
||||
OperateMathNode,
|
||||
]
|
||||
BL_NODES = {ct.NodeType.OperateMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}
|
|
@ -0,0 +1,135 @@
|
|||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import jax.numpy as jnp
|
||||
import sympy as sp
|
||||
|
||||
from blender_maxwell.utils import logger
|
||||
|
||||
from .... import contracts as ct
|
||||
from .... import sockets
|
||||
from ... import base, events
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
||||
class ReduceMathNode(base.MaxwellSimNode):
|
||||
node_type = ct.NodeType.ReduceMath
|
||||
bl_label = 'Reduce Math'
|
||||
|
||||
input_sockets: typ.ClassVar = {
|
||||
'Data': sockets.AnySocketDef(),
|
||||
'Axis': sockets.IntegerNumberSocketDef(),
|
||||
}
|
||||
input_socket_sets: typ.ClassVar = {
|
||||
'By Axis': {
|
||||
'Axis': sockets.IntegerNumberSocketDef(),
|
||||
},
|
||||
'Expr': {
|
||||
'Reducer': sockets.ExprSocketDef(
|
||||
symbols=[sp.Symbol('a'), sp.Symbol('b')],
|
||||
default_expr=sp.Symbol('a') + sp.Symbol('b'),
|
||||
),
|
||||
'Axis': sockets.IntegerNumberSocketDef(),
|
||||
},
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Data': sockets.AnySocketDef(),
|
||||
}
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
operation: bpy.props.EnumProperty(
|
||||
name='Op',
|
||||
description='Operation to reduce the input axis with',
|
||||
items=lambda self, _: self.search_operations(),
|
||||
update=lambda self, context: self.sync_prop('operation', context),
|
||||
)
|
||||
|
||||
def search_operations(self) -> list[tuple[str, str, str]]:
|
||||
items = []
|
||||
if self.active_socket_set == 'By Axis':
|
||||
items += [
|
||||
# Accumulation
|
||||
('SUM', 'Sum', 'sum(*, N, *) -> (*, 1, *)'),
|
||||
('PROD', 'Prod', 'prod(*, N, *) -> (*, 1, *)'),
|
||||
('MIN', 'Axis-Min', '(*, N, *) -> (*, 1, *)'),
|
||||
('MAX', 'Axis-Max', '(*, N, *) -> (*, 1, *)'),
|
||||
('P2P', 'Peak-to-Peak', '(*, N, *) -> (*, 1 *)'),
|
||||
# Stats
|
||||
('MEAN', 'Mean', 'mean(*, N, *) -> (*, 1, *)'),
|
||||
('MEDIAN', 'Median', 'median(*, N, *) -> (*, 1, *)'),
|
||||
('STDDEV', 'Std Dev', 'stddev(*, N, *) -> (*, 1, *)'),
|
||||
('VARIANCE', 'Variance', 'var(*, N, *) -> (*, 1, *)'),
|
||||
# Dimension Reduction
|
||||
('SQUEEZE', 'Squeeze', '(*, 1, *) -> (*, *)'),
|
||||
]
|
||||
else:
|
||||
items += [('NONE', 'None', 'No operations...')]
|
||||
|
||||
return items
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
if self.active_socket_set != 'Axis Expr':
|
||||
layout.prop(self, 'operation')
|
||||
|
||||
####################
|
||||
# - Compute
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
props={'operation'},
|
||||
input_sockets={'Data', 'Axis', 'Reducer'},
|
||||
input_socket_kinds={'Reducer': ct.DataFlowKind.LazyValue},
|
||||
input_sockets_optional={'Reducer': True},
|
||||
)
|
||||
def compute_data(self, props: dict, input_sockets: dict):
|
||||
if not hasattr(input_sockets['Data'], 'shape'):
|
||||
msg = 'Input socket "Data" must be an N-D Array (with a "shape" attribute)'
|
||||
raise ValueError(msg)
|
||||
|
||||
if self.active_socket_set == 'Axis Expr':
|
||||
ufunc = jnp.ufunc(input_sockets['Reducer'], nin=2, nout=1)
|
||||
return ufunc.reduce(input_sockets['Data'], axis=input_sockets['Axis'])
|
||||
|
||||
if self.active_socket_set == 'By Axis':
|
||||
## Dimension Reduction
|
||||
# ('SQUEEZE', 'Squeeze', '(*, 1, *) -> (*, *)'),
|
||||
# Accumulation
|
||||
if props['operation'] == 'SUM':
|
||||
return jnp.sum(input_sockets['Data'], axis=input_sockets['Axis'])
|
||||
if props['operation'] == 'PROD':
|
||||
return jnp.prod(input_sockets['Data'], axis=input_sockets['Axis'])
|
||||
if props['operation'] == 'MIN':
|
||||
return jnp.min(input_sockets['Data'], axis=input_sockets['Axis'])
|
||||
if props['operation'] == 'MAX':
|
||||
return jnp.max(input_sockets['Data'], axis=input_sockets['Axis'])
|
||||
if props['operation'] == 'P2P':
|
||||
return jnp.p2p(input_sockets['Data'], axis=input_sockets['Axis'])
|
||||
|
||||
# Stats
|
||||
if props['operation'] == 'MEAN':
|
||||
return jnp.mean(input_sockets['Data'], axis=input_sockets['Axis'])
|
||||
if props['operation'] == 'MEDIAN':
|
||||
return jnp.median(input_sockets['Data'], axis=input_sockets['Axis'])
|
||||
if props['operation'] == 'STDDEV':
|
||||
return jnp.std(input_sockets['Data'], axis=input_sockets['Axis'])
|
||||
if props['operation'] == 'VARIANCE':
|
||||
return jnp.var(input_sockets['Data'], axis=input_sockets['Axis'])
|
||||
|
||||
# Dimension Reduction
|
||||
if props['operation'] == 'SQUEEZE':
|
||||
return jnp.squeeze(input_sockets['Data'], axis=input_sockets['Axis'])
|
||||
|
||||
msg = 'Operation invalid'
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
####################
|
||||
# - Blender Registration
|
||||
####################
|
||||
BL_REGISTER = [
|
||||
ReduceMathNode,
|
||||
]
|
||||
BL_NODES = {ct.NodeType.ReduceMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}
|
|
@ -20,7 +20,7 @@ class VizNode(base.MaxwellSimNode):
|
|||
####################
|
||||
# - Sockets
|
||||
####################
|
||||
input_sockets = {
|
||||
input_sockets: typ.ClassVar = {
|
||||
'Data': sockets.AnySocketDef(),
|
||||
'Freq': sockets.PhysicalFreqSocketDef(),
|
||||
}
|
||||
|
@ -72,20 +72,12 @@ class VizNode(base.MaxwellSimNode):
|
|||
props: dict,
|
||||
unit_systems: dict,
|
||||
):
|
||||
selected_data = jnp.array(
|
||||
input_sockets['Data'].sel(f=input_sockets['Freq'], method='nearest')
|
||||
)
|
||||
|
||||
managed_objs['plot'].xyzf_to_image(
|
||||
selected_data,
|
||||
managed_objs['plot'].map_2d_to_image(
|
||||
input_sockets['Data'].as_bound_jax_func(),
|
||||
colormap=props['colormap'],
|
||||
bl_select=True,
|
||||
)
|
||||
|
||||
# @events.on_init()
|
||||
# def on_init(self):
|
||||
# self.on_changed_inputs()
|
||||
|
||||
|
||||
####################
|
||||
# - Blender Registration
|
||||
|
|
|
@ -42,6 +42,15 @@ class SocketDef(pyd.BaseModel, abc.ABC):
|
|||
# - SocketDef
|
||||
####################
|
||||
class MaxwellSimSocket(bpy.types.NodeSocket):
|
||||
"""A specialized Blender socket for nodes in a Maxwell simulation.
|
||||
|
||||
Attributes:
|
||||
instance_id: A unique ID attached to a particular socket instance.
|
||||
Guaranteed to be unchanged so long as the socket lives.
|
||||
Used as a socket-specific cache index.
|
||||
locked: The lock-state of a particular socket, which determines the socket's user editability
|
||||
"""
|
||||
|
||||
# Fundamentals
|
||||
socket_type: ct.SocketType
|
||||
bl_label: str
|
||||
|
@ -73,21 +82,53 @@ class MaxwellSimSocket(bpy.types.NodeSocket):
|
|||
####################
|
||||
# - Initialization
|
||||
####################
|
||||
def __init_subclass__(cls, **kwargs: typ.Any):
|
||||
super().__init_subclass__(**kwargs)
|
||||
@classmethod
|
||||
def set_prop(
|
||||
cls,
|
||||
prop_name: str,
|
||||
prop: bpy.types.Property,
|
||||
no_update: bool = False,
|
||||
update_with_name: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Adds a Blender property to a class via `__annotations__`, so it initializes with any subclass.
|
||||
|
||||
# Setup Blender ID for Node
|
||||
if not hasattr(cls, 'socket_type'):
|
||||
msg = f"Socket class {cls} does not define 'socket_type'"
|
||||
raise ValueError(msg)
|
||||
cls.bl_idname = str(cls.socket_type.value)
|
||||
Notes:
|
||||
- Blender properties can't be set within `__init_subclass__` simply by adding attributes to the class; they must be added as type annotations.
|
||||
- Must be called **within** `__init_subclass__`.
|
||||
|
||||
# Setup Locked Property for Node
|
||||
cls.__annotations__['locked'] = bpy.props.BoolProperty(
|
||||
name='Locked State',
|
||||
description="The lock-state of a particular socket, which determines the socket's user editability",
|
||||
default=False,
|
||||
Parameters:
|
||||
name: The name of the property to set.
|
||||
prop: The `bpy.types.Property` to instantiate and attach..
|
||||
no_update: Don't attach a `self.sync_prop()` callback to the property's `update`.
|
||||
"""
|
||||
_update_with_name = prop_name if update_with_name is None else update_with_name
|
||||
extra_kwargs = (
|
||||
{
|
||||
'update': lambda self, context: self.sync_prop(
|
||||
_update_with_name, context
|
||||
),
|
||||
}
|
||||
if not no_update
|
||||
else {}
|
||||
)
|
||||
cls.__annotations__[prop_name] = prop(
|
||||
**kwargs,
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
def __init_subclass__(cls, **kwargs: typ.Any):
|
||||
log.debug('Initializing Socket: %s', cls.socket_type)
|
||||
super().__init_subclass__(**kwargs)
|
||||
# cls._assert_attrs_valid()
|
||||
|
||||
# Socket Properties
|
||||
## Identifiers
|
||||
cls.bl_idname: str = str(cls.socket_type.value)
|
||||
cls.set_prop('instance_id', bpy.props.StringProperty, no_update=True)
|
||||
|
||||
## Special States
|
||||
cls.set_prop('locked', bpy.props.BoolProperty, no_update=True, default=False)
|
||||
|
||||
# Setup Style
|
||||
cls.socket_color = ct.SOCKET_COLORS[cls.socket_type]
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
from . import any as any_socket
|
||||
from . import bool as bool_socket
|
||||
from . import file_path, string
|
||||
from . import expr, file_path, string
|
||||
|
||||
AnySocketDef = any_socket.AnySocketDef
|
||||
BoolSocketDef = bool_socket.BoolSocketDef
|
||||
FilePathSocketDef = file_path.FilePathSocketDef
|
||||
StringSocketDef = string.StringSocketDef
|
||||
FilePathSocketDef = file_path.FilePathSocketDef
|
||||
ExprSocketDef = expr.ExprSocketDef
|
||||
|
||||
|
||||
BL_REGISTER = [
|
||||
|
@ -13,4 +14,5 @@ BL_REGISTER = [
|
|||
*bool_socket.BL_REGISTER,
|
||||
*string.BL_REGISTER,
|
||||
*file_path.BL_REGISTER,
|
||||
*expr.BL_REGISTER,
|
||||
]
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
import bpy
|
||||
import sympy as sp
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils.pydantic_sympy import SympyExpr
|
||||
|
||||
from ... import bl_cache
|
||||
from ... import contracts as ct
|
||||
from .. import base
|
||||
|
||||
|
||||
class ExprBLSocket(base.MaxwellSimSocket):
|
||||
socket_type = ct.SocketType.Expr
|
||||
bl_label = 'Expr'
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
raw_value: bpy.props.StringProperty(
|
||||
name='Expr',
|
||||
description='Represents a symbolic expression',
|
||||
default='',
|
||||
update=(lambda self, context: self.sync_prop('raw_value', context)),
|
||||
)
|
||||
|
||||
symbols: list[sp.Symbol] = bl_cache.BLField([])
|
||||
## TODO: Way of assigning assumptions to symbols.
|
||||
## TODO: Dynamic add/remove of symbols
|
||||
|
||||
####################
|
||||
# - Socket UI
|
||||
####################
|
||||
def draw_value(self, col: bpy.types.UILayout) -> None:
|
||||
col.prop(self, 'raw_value', text='')
|
||||
|
||||
####################
|
||||
# - Computation of Default Value
|
||||
####################
|
||||
@property
|
||||
def value(self) -> sp.Expr:
|
||||
return sp.sympify(
|
||||
self.raw_value,
|
||||
strict=False,
|
||||
convert_xor=True,
|
||||
).subs(spux.ALL_UNIT_SYMBOLS)
|
||||
|
||||
@value.setter
|
||||
def value(self, value: str) -> None:
|
||||
self.raw_value = str(value)
|
||||
|
||||
@property
|
||||
def lazy_value(self) -> sp.Expr:
|
||||
return ct.LazyDataValue.from_function(
|
||||
sp.lambdify(self.symbols, self.value, 'jax'),
|
||||
free_args=(tuple(str(sym) for sym in self.symbols), frozenset()),
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
|
||||
####################
|
||||
# - Socket Configuration
|
||||
####################
|
||||
class ExprSocketDef(base.SocketDef):
|
||||
socket_type: ct.SocketType = ct.SocketType.Expr
|
||||
|
||||
_x = sp.Symbol('x', real=True)
|
||||
symbols: list[SympyExpr] = [_x]
|
||||
default_expr: SympyExpr = _x
|
||||
|
||||
def init(self, bl_socket: ExprBLSocket) -> None:
|
||||
bl_socket.value = self.default_expr
|
||||
bl_socket.symbols = self.symbols
|
||||
|
||||
|
||||
####################
|
||||
# - Blender Registration
|
||||
####################
|
||||
BL_REGISTER = [
|
||||
ExprBLSocket,
|
||||
]
|
|
@ -57,3 +57,6 @@ class StringSocketDef(base.SocketDef):
|
|||
BL_REGISTER = [
|
||||
StringBLSocket,
|
||||
]
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
import dataclasses
|
||||
import typing as typ
|
||||
from types import MappingProxyType
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import pandas as pd
|
||||
|
||||
# import jaxtyping as jtyp
|
||||
import sympy.physics.units as spu
|
||||
import xarray
|
||||
|
||||
from . import extra_sympy_units as spux
|
||||
from . import logger
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
DimName: typ.TypeAlias = str
|
||||
Number: typ.TypeAlias = int | float | complex
|
||||
NumberRange: typ.TypeAlias = jax.Array
|
||||
|
||||
|
||||
@dataclasses.dataclass(kw_only=True)
|
||||
class JArray:
|
||||
"""Very simple wrapper for JAX arrays, which includes information about the dimension names and bounds."""
|
||||
|
||||
array: jax.Array
|
||||
dims: dict[DimName, NumberRange]
|
||||
dim_units: dict[DimName, spu.Quantity]
|
||||
|
||||
####################
|
||||
# - Constructor
|
||||
####################
|
||||
@classmethod
|
||||
def from_xarray(
|
||||
cls,
|
||||
xarr: xarray.DataArray,
|
||||
dim_units: dict[DimName, spu.Quantity] = MappingProxyType({}),
|
||||
sort_axis: int = -1,
|
||||
) -> typ.Self:
|
||||
return cls(
|
||||
array=jnp.sort(jnp.array(xarr.data), axis=sort_axis),
|
||||
dims={
|
||||
dim_name: jnp.array(xarr.get_index(dim_name).values)
|
||||
for dim_name in xarr.dims
|
||||
},
|
||||
dim_units={dim_name: dim_units.get(dim_name) for dim_name in xarr.dims},
|
||||
)
|
||||
|
||||
def idx(self, dim_name: DimName, dim_value: Number) -> int:
|
||||
found_idx = jnp.searchsorted(self.dims[dim_name], dim_value)
|
||||
if found_idx == 0:
|
||||
return found_idx
|
||||
if found_idx == len(self.dims[dim_name]):
|
||||
return found_idx - 1
|
||||
|
||||
left = self.dims[dim_name][found_idx - 1]
|
||||
right = self.dims[dim_name][found_idx - 1]
|
||||
return found_idx - 1 if (dim_value - left) <= (right - dim_value) else found_idx
|
||||
|
||||
@property
|
||||
def dtype(self) -> jnp.dtype:
|
||||
return self.array.dtype
|
Loading…
Reference in New Issue