feat: Math nodes (non-working)

main
Sofus Albert Høgsbro Rose 2024-04-17 16:03:15 +02:00
parent 568fc449e8
commit dfeb65feec
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
21 changed files with 974 additions and 226 deletions

View File

@ -14,6 +14,10 @@ dependencies = [
"networkx==3.2.*",
"rich==12.5.*",
"rtree==1.2.*",
"jax[cpu]==0.4.26",
"msgspec[toml]==0.18.6",
"numba==0.59.1",
"jaxtyping==0.2.28",
# Pin Blender 4.1.0-Compatible Versions
## The dependency resolver will report if anything is wonky.
"urllib3==1.26.8",
@ -22,8 +26,6 @@ dependencies = [
"idna==3.3",
"charset-normalizer==2.0.10",
"certifi==2021.10.8",
"jax[cpu]>=0.4.26",
"msgspec[toml]>=0.18.6",
]
readme = "README.md"
requires-python = "~= 3.11"

View File

@ -49,11 +49,14 @@ importlib-metadata==6.11.0
jax==0.4.26
jaxlib==0.4.26
# via jax
jaxtyping==0.2.28
jmespath==1.0.1
# via boto3
# via botocore
kiwisolver==1.4.5
# via matplotlib
llvmlite==0.42.0
# via numba
locket==1.0.0
# via partd
matplotlib==3.8.3
@ -65,13 +68,16 @@ mpmath==1.3.0
# via sympy
msgspec==0.18.6
networkx==3.2
numba==0.59.1
numpy==1.24.3
# via contourpy
# via h5py
# via jax
# via jaxlib
# via jaxtyping
# via matplotlib
# via ml-dtypes
# via numba
# via opt-einsum
# via scipy
# via shapely
@ -142,6 +148,8 @@ toolz==0.12.1
# via dask
# via partd
trimesh==4.2.0
typeguard==2.13.3
# via jaxtyping
types-pyyaml==6.0.12.20240311
# via responses
typing-extensions==4.10.0

View File

@ -48,11 +48,14 @@ importlib-metadata==6.11.0
jax==0.4.26
jaxlib==0.4.26
# via jax
jaxtyping==0.2.28
jmespath==1.0.1
# via boto3
# via botocore
kiwisolver==1.4.5
# via matplotlib
llvmlite==0.42.0
# via numba
locket==1.0.0
# via partd
matplotlib==3.8.3
@ -64,13 +67,16 @@ mpmath==1.3.0
# via sympy
msgspec==0.18.6
networkx==3.2
numba==0.59.1
numpy==1.24.3
# via contourpy
# via h5py
# via jax
# via jaxlib
# via jaxtyping
# via matplotlib
# via ml-dtypes
# via numba
# via opt-einsum
# via scipy
# via shapely
@ -140,6 +146,8 @@ toolz==0.12.1
# via dask
# via partd
trimesh==4.2.0
typeguard==2.13.3
# via jaxtyping
types-pyyaml==6.0.12.20240311
# via responses
typing-extensions==4.10.0

View File

@ -4,8 +4,9 @@ import functools
import typing as typ
from types import MappingProxyType
# import colour ## TODO
import numpy as np
import jax
import jax.numpy as jnp
import numba
import sympy as sp
import sympy.physics.units as spu
import typing_extensions as typx
@ -15,66 +16,46 @@ from ....utils import sci_constants as constants
from .socket_types import SocketType
class DataFlowKind(enum.StrEnum):
"""Defines a shape/kind of data that may flow through a node tree.
class FlowKind(enum.StrEnum):
"""Defines a kind of data that can flow between nodes.
Since a node socket may define one of each, we can support several related kinds of data flow through the same node-graph infrastructure.
Each node link can be thought to contain **multiple pipelines for data to flow along**.
Each pipeline is cached incrementally, and independently, of the others.
Thus, the same socket can easily support several kinds of related data flow at the same time.
Attributes:
Value: A value without any unknown symbols.
- Basic types aka. float, int, list, string, etc. .
- Exotic (immutable-ish) types aka. numpy array, KDTree, etc. .
- A usable constructed object, ex. a `tidy3d.Box`.
- Expressions (`sp.Expr`) that don't have unknown variables.
- Lazy sequences aka. generators, with all data bound.
SpectralValue: A value defined along a spectral range.
- {`np.array`
LazyValue: An object which, when given new data, can make many values.
- An `sp.Expr`, which might need `simplify`ing, `jax` JIT'ing, unit cancellations, variable substitutions, etc. before use.
- Lazy objects, for which all parameters aren't yet known.
- A computational graph aka. `aesara`, which may even need to be handled before
Capabilities: A `ValueCapability` object providing compatibility.
# Value Data Flow
Simply passing values is the simplest and easiest use case.
This doesn't mean it's "dumb" - ex. a `sp.Expr` might, before use, have `simplify`, rewriting, unit cancellation, etc. run.
All of this is okay, as long as there is no *introduction of new data* ex. variable substitutions.
# Lazy Value Data Flow
By passing (essentially) functions, one supports:
- **Lightness**: While lazy values can be made expensive to construct, they will generally not be nearly as heavy to handle when trying to work with ex. operations on voxel arrays.
- **Performance**: Parameterizing ex. `sp.Expr` with variables allows one to build very optimized functions, which can make ex. node graph updates very fast if the only operation run is the `jax` JIT'ed function (aka. GPU accelerated) generated from the final full expression.
- **Numerical Stability**: Libraries like `aesara` build a computational graph, which can be automatically rewritten to avoid many obvious conditioning / cancellation errors.
- **Lazy Output**: The goal of a node-graph may not be the definition of a single value, but rather, a parameterized expression for generating *many values* with known properties. This is especially interesting for use cases where one wishes to build an optimization step using nodes.
# Capability Passing
By being able to pass "capabilities" next to other kinds of values, nodes can quickly determine whether a given link is valid without having to actually compute it.
# Lazy Parameter Value
When using parameterized LazyValues, one may wish to independently pass parameter values through the graph, so they can be inserted into the final (cached) high-performance expression without.
The advantage of using a different data flow would be changing this kind of value would ONLY invalidate lazy parameter value caches, which would allow an incredibly fast path of getting the value into the lazy expression for high-performance computation.
Implementation TBD - though, ostensibly, one would have a "parameter" node which both would only provide a LazyValue (aka. a symbolic variable), but would also be able to provide a LazyParamValue, which would be a particular value of some kind (probably via the `value` of some other node socket).
Capabilities: Describes a socket's linkeability with other sockets.
Links between sockets with incompatible capabilities will be rejected.
This doesn't need to be defined normally, as there is a default.
However, in some cases, defining it manually to control linkeability more granularly may be desirable.
Value: A generic object, which is "directly usable".
This should be chosen when a more specific flow kind doesn't apply.
Array: An object with dimensions, and possibly a unit.
Whenever a `Value` is defined, a single-element `list` will also be generated by default as `Array`
However, for any other array-like variants (or sockets that only represent array-like objects), `Array` should be defined manually.
LazyValueFunc: A composable function.
Can be used to represent computations for which all data is not yet known, or for which just-in-time compilation can drastically increase performance.
LazyArrayRange: An object that generates an `Array` from range information (start/stop/step/spacing).
This should be used instead of `Array` whenever possible.
Param: An object providing data to complete `Lazy` data.
For example,
Info: An object providing context about other flows.
For example,
"""
Capabilities = enum.auto()
# Values
Value = enum.auto()
ValueArray = enum.auto()
ValueSpectrum = enum.auto()
Array = enum.auto()
# Lazy
LazyValue = enum.auto()
LazyValueRange = enum.auto()
LazyValueSpectrum = enum.auto()
LazyArrayRange = enum.auto()
# Auxiliary
Param = enum.auto()
Info = enum.auto()
@classmethod
def scale_to_unit_system(cls, kind: typ.Self, value, socket_type, unit_system):
@ -85,7 +66,7 @@ class DataFlowKind(enum.StrEnum):
unit_system[socket_type],
)
)
if kind == cls.LazyValueRange:
if kind == cls.LazyArrayRange:
return value.rescale_to_unit(unit_system[socket_type])
msg = 'Tried to scale unknown kind'
@ -93,12 +74,12 @@ class DataFlowKind(enum.StrEnum):
####################
# - Data Structures: Capabilities
# - Capabilities
####################
@dataclasses.dataclass(frozen=True, kw_only=True)
class DataCapabilities:
class CapabilitiesFlow:
socket_type: SocketType
active_kind: DataFlowKind
active_kind: FlowKind
is_universal: bool = False
@ -110,13 +91,16 @@ class DataCapabilities:
####################
# - Data Structures: Non-Lazy
# - Value
####################
DataValue: typ.TypeAlias = typ.Any
ValueFlow: typ.TypeAlias = typ.Any
####################
# - Value Array
####################
@dataclasses.dataclass(frozen=True, kw_only=True)
class DataValueArray:
class ArrayFlow:
"""A simple, flat array of values with an optionally-attached unit.
Attributes:
@ -125,69 +109,105 @@ class DataValueArray:
None if unitless.
"""
values: typ.Sequence[DataValue]
values: jax.Array
unit: spu.Quantity | None
####################
# - Lazy Value Func
####################
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], ValueFlow]
@dataclasses.dataclass(frozen=True, kw_only=True)
class DataValueSpectrum:
"""A numerical representation of a spectral distribution.
class LazyValueFuncFlow:
r"""Encapsulates a lazily evaluated data value as a composable function with bound and free arguments.
- **Bound Args**: Arguments that are realized when **defining** the lazy value.
Both positional values and keyword values are supported.
- **Free Args**: Arguments that are specified when evaluating the lazy value.
Both positional values and keyword values are supported.
The **root function** is encapsulated using `from_function`, and must accept arguments in the following order:
$$
f_0:\ \ \ \ (\underbrace{b_1, b_2, ...}_{\text{Bound}}\ ,\ \underbrace{r_1, r_2, ...}_{\text{Free}}) \to \text{output}_0
$$
Subsequent **composed functions** are encapsulated from the _root function_, and are created with `root_function.compose`.
They must accept arguments in the following order:
$$
f_k:\ \ \ \ (\underbrace{b_1, b_2, ...}_{\text{Bound}}\ ,\ \text{output}_{k-1} ,\ \underbrace{r_p, r_{p+1}, ...}_{\text{Free}}) \to \text{output}_k
$$
Attributes:
wls: A 1D `numpy` float array of wavelength values.
wls_unit: The unit of wavelengths, as length dimension.
values: A 1D `numpy` float array of values corresponding to wavelength values.
values_unit: The unit of the value, as arbitrary dimension.
freqs_unit: The unit of the value, as arbitrary dimension.
function: The function to be lazily evaluated.
bound_args: Arguments that will be packaged into function, which can't be later modifier.
func_kwargs: Arguments to be specified by the user at the time of use.
supports_jax: Whether the contained `self.function` can be compiled with JAX's JIT compiler.
supports_numba: Whether the contained `self.function` can be compiled with Numba's JIT compiler.
"""
# Wavelength
wls: np.array
wls_unit: spu.Quantity
func: LazyFunction
func_kwargs: dict[str, type]
supports_jax: bool = False
supports_numba: bool = False
# Value
values: np.array
values_unit: spu.Quantity
@staticmethod
def from_func(
func: LazyFunction,
supports_jax: bool = False,
supports_numba: bool = False,
**func_kwargs: dict[str, type],
) -> typ.Self:
return LazyValueFuncFlow(
func=func,
func_kwargs=func_kwargs,
supports_jax=supports_jax,
supports_numba=supports_numba,
)
# Frequency
freqs_unit: spu.Quantity = spu.hertz
# Composition
def compose_within(
self,
enclosing_func: LazyFunction,
supports_jax: bool = False,
supports_numba: bool = False,
**enclosing_func_kwargs: dict[str, type],
) -> typ.Self:
return LazyValueFuncFlow(
function=lambda **kwargs: enclosing_func(
self.func(**{k: v for k, v in kwargs if k in self.func_kwargs}),
**kwargs,
),
func_kwargs=self.func_kwargs | enclosing_func_kwargs,
supports_jax=self.supports_jax and supports_jax,
supports_numba=self.supports_numba and supports_numba,
)
@functools.cached_property
def freqs(self) -> np.array:
"""The spectral frequencies, computed from the wavelengths.
def func_jax(self) -> LazyFunction:
if self.supports_jax:
return jax.jit(self.func)
Frequencies are NOT reversed, so as to preserve the by-index mapping to `DataValueSpectrum.values`.
msg = 'Can\'t express LazyValueFuncFlow as JAX function (using jax.jit), since "self.supports_jax" is False'
raise ValueError(msg)
Returns:
Frequencies, as a unitless `numpy` array.
Use `DataValueSpectrum.wls_unit` to interpret this return value.
"""
unitless_speed_of_light = spux.sympy_to_python(
spux.scale_to_unit(
constants.vac_speed_of_light, (self.wl_unit / self.freq_unit)
)
)
return unitless_speed_of_light / self.wls
@functools.cached_property
def func_numba(self) -> LazyFunction:
if self.supports_numba:
return numba.jit(self.func)
# TODO: Colour Library
# def as_colour_sd(self) -> colour.SpectralDistribution:
# """Returns the `colour` representation of this spectral distribution, ideal for plotting and colorimetric analysis."""
# return colour.SpectralDistribution(data=self.values, domain=self.wls)
msg = 'Can\'t express LazyValueFuncFlow as Numba function (using numba.jit), since "self.supports_numba" is False'
raise ValueError(msg)
####################
# - Data Structures: Lazy
# - Lazy Array Range
####################
@dataclasses.dataclass(frozen=True, kw_only=True)
class LazyDataValue:
callback: typ.Callable[[...], [DataValue]]
def realize(self, *args: list[DataValue]) -> DataValue:
return self.callback(*args)
@dataclasses.dataclass(frozen=True, kw_only=True)
class LazyDataValueRange:
class LazyArrayRangeFlow:
symbols: set[sp.Symbol]
start: sp.Basic
@ -200,7 +220,7 @@ class LazyDataValueRange:
def rescale_to_unit(self, unit: spu.Quantity) -> typ.Self:
if self.has_unit:
return LazyDataValueRange(
return LazyArrayRangeFlow(
symbols=self.symbols,
has_unit=self.has_unit,
unit=unit,
@ -219,7 +239,7 @@ class LazyDataValueRange:
reverse: bool = False,
) -> typ.Self:
"""Call a function on both bounds (start and stop), creating a new `LazyDataValueRange`."""
return LazyDataValueRange(
return LazyArrayRangeFlow(
symbols=self.symbols,
has_unit=self.has_unit,
unit=self.unit,
@ -234,8 +254,8 @@ class LazyDataValueRange:
)
def realize(
self, symbol_values: dict[sp.Symbol, DataValue] = MappingProxyType({})
) -> DataValueArray:
self, symbol_values: dict[sp.Symbol, ValueFlow] = MappingProxyType({})
) -> ArrayFlow:
# Realize Symbols
if not self.has_unit:
start = spux.sympy_to_python(self.start.subs(symbol_values))
@ -250,85 +270,25 @@ class LazyDataValueRange:
# Return Linspace / Logspace
if self.scaling == 'lin':
return DataValueArray(
values=np.linspace(start, stop, self.steps), unit=self.unit
return ArrayFlow(
values=jnp.linspace(start, stop, self.steps), unit=self.unit
)
if self.scaling == 'geom':
return DataValueArray(np.geomspace(start, stop, self.steps), self.unit)
return ArrayFlow(jnp.geomspace(start, stop, self.steps), self.unit)
if self.scaling == 'log':
return DataValueArray(np.logspace(start, stop, self.steps), self.unit)
return ArrayFlow(jnp.logspace(start, stop, self.steps), self.unit)
raise NotImplementedError
msg = f'ArrayFlow scaling method {self.scaling} is unsupported'
raise RuntimeError(msg)
@dataclasses.dataclass(frozen=True, kw_only=True)
class LazyDataValueSpectrum:
wl_unit: spu.Quantity
value_unit: spu.Quantity
value_expr: sp.Expr
symbols: tuple[sp.Symbol, ...] = ()
freq_symbol: sp.Symbol = sp.Symbol('lamda') # noqa: RUF009
def rescale_to_unit(self, unit: spu.Quantity) -> typ.Self:
raise NotImplementedError
@functools.cached_property
def as_func(self) -> typ.Callable[[DataValue, ...], DataValue]:
"""Generates an optimized function for numerical evaluation of the spectral expression."""
return sp.lambdify([self.freq_symbol, *self.symbols], self.value_expr)
def realize(
self, wl_range: DataValueArray, symbol_values: tuple[DataValue, ...]
) -> DataValueSpectrum:
r"""Realizes the parameterized spectral function as a numerical spectral distribution.
Parameters:
wl_range: The lazy wavelength range to build the concrete spectral distribution with.
symbol_values: Numerical values for each symbol, in the same order as defined in `LazyDataValueSpectrum.symbols`.
The wavelength symbol ($\lambda$ by default) always goes first.
_This is used to call the spectral function using the output of `.as_func()`._
Returns:
The concrete, numerical spectral distribution.
"""
return DataValueSpectrum(
wls=wl_range.values,
wls_unit=self.wl_unit,
values=self.as_func(*list(symbol_values.values())),
values_unit=self.value_unit,
)
####################
# - Param
####################
ParamFlow: typ.TypeAlias = dict[str, typ.Any]
#
#
#####################
## - Data Pipeline
#####################
# @dataclasses.dataclass(frozen=True, kw_only=True)
# class DataPipelineDim:
# unit: spu.Quantity | None
#
# class DataPipelineDimType(enum.StrEnum):
# # Map Inputs
# Time = enum.auto()
# Freq = enum.auto()
# Space3D = enum.auto()
# DiffOrder = enum.auto()
#
# # Map Inputs
# Power = enum.auto()
# EVec = enum.auto()
# HVec = enum.auto()
# RelPerm = enum.auto()
#
#
# @dataclasses.dataclass(frozen=True, kw_only=True)
# class LazyDataPipeline:
# dims: list[DataPipelineDim]
#
# def _callable(self):
# """JITs the current pipeline of functions with `jax`."""
#
# def __call__(self):
# pass
####################
# - Lazy Value Func
####################
InfoFlow: typ.TypeAlias = dict[str, typ.Any]

View File

@ -7,6 +7,7 @@ SOCKET_COLORS = {
ST.Bool: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey
ST.String: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey
ST.FilePath: (0.6, 0.6, 0.6, 1.0), # Medium Grey
ST.Expr: (0.5, 0.5, 0.5, 1.0), # Medium Grey
# Number
ST.IntegerNumber: (0.5, 0.5, 1.0, 1.0), # Light Blue
ST.RationalNumber: (0.4, 0.4, 0.9, 1.0), # Medium Light Blue

View File

@ -6,6 +6,7 @@ SOCKET_SHAPES = {
ST.Bool: 'CIRCLE',
ST.String: 'CIRCLE',
ST.FilePath: 'CIRCLE',
ST.Expr: 'CIRCLE',
# Number
ST.IntegerNumber: 'CIRCLE',
ST.RationalNumber: 'CIRCLE',

View File

@ -14,6 +14,7 @@ class SocketType(BlenderTypeEnum):
String = enum.auto()
FilePath = enum.auto()
Color = enum.auto()
Expr = enum.auto()
# Number
IntegerNumber = enum.auto()

View File

@ -38,8 +38,8 @@ def apply_colormap(normalized_data, colormap):
@jax.jit
def rgba_image_from_xyzf__viridis(xyz_freq):
amplitude = jnp.abs(jnp.squeeze(xyz_freq))
def rgba_image_from_2d_map__viridis(map_2d):
amplitude = jnp.abs(map_2d)
amplitude_normalized = (amplitude - amplitude.min()) / (
amplitude.max() - amplitude.min()
)
@ -49,8 +49,8 @@ def rgba_image_from_xyzf__viridis(xyz_freq):
@jax.jit
def rgba_image_from_xyzf__grayscale(xyz_freq):
amplitude = jnp.abs(jnp.squeeze(xyz_freq))
def rgba_image_from_2d_map__grayscale(map_2d):
amplitude = jnp.abs(map_2d)
amplitude_normalized = (amplitude - amplitude.min()) / (
amplitude.max() - amplitude.min()
)
@ -59,21 +59,19 @@ def rgba_image_from_xyzf__grayscale(xyz_freq):
return jnp.dstack((rgb_array, alpha_channel))
def rgba_image_from_xyzf(xyz_freq, colormap: str | None = None):
"""RGBA Image from Squeezable XYZ-Freq w/fixed freq.
def rgba_image_from_2d_map(map_2d, colormap: str | None = None):
"""RGBA Image from a map of 2D coordinates to values.
Parameters:
xyz_freq: Shape (xlen, ylen, zlen), one dimension has length 1.
width_px: Pixel width to resize the image to.
height: Pixel height to resize the image to.
map_2d: Shape (width, height, value).
Returns:
Image as a JAX array of shape (height, width, 3)
Image as a JAX array of shape (height, width, 4)
"""
if colormap == 'VIRIDIS':
return rgba_image_from_xyzf__viridis(xyz_freq)
return rgba_image_from_2d_map__viridis(map_2d)
if colormap == 'GRAYSCALE':
return rgba_image_from_xyzf__grayscale(xyz_freq)
return rgba_image_from_2d_map__grayscale(map_2d)
class ManagedBLImage(base.ManagedObj):
@ -227,11 +225,11 @@ class ManagedBLImage(base.ManagedObj):
####################
# - Special Methods
####################
def xyzf_to_image(
self, xyz_freq, colormap: str | None = 'VIRIDIS', bl_select: bool = False
def map_2d_to_image(
self, map_2d, colormap: str | None = 'VIRIDIS', bl_select: bool = False
):
self.data_to_image(
lambda _: rgba_image_from_xyzf(xyz_freq, colormap=colormap),
lambda _: rgba_image_from_2d_map(map_2d, colormap=colormap),
bl_select=bl_select,
)

View File

@ -1,10 +1,12 @@
from . import extract_data, viz
from . import extract_data, viz, math
BL_REGISTER = [
*extract_data.BL_REGISTER,
*viz.BL_REGISTER,
*math.BL_REGISTER,
]
BL_NODES = {
**extract_data.BL_NODES,
**viz.BL_NODES,
**math.BL_NODES,
}

View File

@ -1,8 +1,11 @@
import typing as typ
import bpy
import jax.numpy as jnp
import sympy.physics.units as spu
from blender_maxwell.utils import jarray, logger
from .....utils import logger
from ... import contracts as ct
from ... import sockets
from .. import base, events
@ -229,8 +232,10 @@ class ExtractDataNode(base.MaxwellSimNode):
@events.computes_output_socket(
'Data',
props={'sim_data__monitor_name', 'field_data__component'},
input_sockets={'Field Data'},
input_sockets_optional={'Field Data': True},
)
def compute_extracted_data(self, props: dict):
def compute_extracted_data(self, props: dict, input_sockets: dict):
if self.active_socket_set == 'Sim Data':
if (
CACHE_SIM_DATA.get(self.instance_id) is None
@ -242,12 +247,21 @@ class ExtractDataNode(base.MaxwellSimNode):
return sim_data.monitor_data[props['sim_data__monitor_name']]
elif self.active_socket_set == 'Field Data': # noqa: RET505
field_data = self._compute_input('Field Data')
return getattr(field_data, props['field_data__component'])
xarr = getattr(input_sockets['Field Data'], props['field_data__component'])
return jarray.JArray.from_xarray(
xarr,
dim_units={
'x': spu.um,
'y': spu.um,
'z': spu.um,
'f': spu.hertz,
},
)
elif self.active_socket_set == 'Flux Data':
flux_data = self._compute_input('Flux Data')
return flux_data.flux
return jnp.array(flux_data.flux)
msg = f'Tried to get data from unknown output socket in "{self.bl_label}"'
raise RuntimeError(msg)

View File

@ -0,0 +1,14 @@
from . import map_math, filter_math, reduce_math, operate_math
BL_REGISTER = [
*map_math.BL_REGISTER,
*filter_math.BL_REGISTER,
*reduce_math.BL_REGISTER,
*operate_math.BL_REGISTER,
]
BL_NODES = {
**map_math.BL_NODES,
**filter_math.BL_NODES,
**reduce_math.BL_NODES,
**operate_math.BL_NODES,
}

View File

@ -0,0 +1,121 @@
import functools
import typing as typ
import bpy
import jax
import jax.numpy as jnp
from blender_maxwell.utils import logger
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
# @functools.partial(jax.jit, static_argnames=('fixed_axis', 'fixed_axis_value'))
# jax.jit
def fix_axis(data, fixed_axis: int, fixed_axis_value: float):
log.critical(data.shape)
# Select Values of Fixed Axis
fixed_axis_values = data[
tuple(slice(None) if i == fixed_axis else 0 for i in range(data.ndim))
]
log.critical(fixed_axis_values)
# Compute Nearest Index on Fixed Axis
idx_of_nearest = jnp.argmin(jnp.abs(fixed_axis_values - fixed_axis_value))
log.critical(idx_of_nearest)
# Select Values along Fixed Axis Value
return jnp.take(data, idx_of_nearest, axis=fixed_axis)
class FilterMathNode(base.MaxwellSimNode):
node_type = ct.NodeType.FilterMath
bl_label = 'Filter Math'
input_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
}
input_socket_sets: typ.ClassVar = {
'By Axis Value': {
'Axis': sockets.IntegerNumberSocketDef(),
'Value': sockets.RealNumberSocketDef(),
},
'By Axis': {
'Axis': sockets.IntegerNumberSocketDef(),
},
## TODO: bool arrays for comparison/switching/sparse 0-setting/etc. .
}
output_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
}
####################
# - Properties
####################
operation: bpy.props.EnumProperty(
name='Op',
description='Operation to reduce the input axis with',
items=lambda self, _: self.search_operations(),
update=lambda self, context: self.sync_prop('operation', context),
)
def search_operations(self) -> list[tuple[str, str, str]]:
items = []
if self.active_socket_set == 'By Axis Value':
items += [
('FIX', 'Fix Coordinate', '(*, N, *) -> (*, *)'),
]
if self.active_socket_set == 'By Axis':
items += [
('SQUEEZE', 'Squeeze', '(*, 1, *) -> (*, *)'),
]
else:
items += [('NONE', 'None', 'No operations...')]
return items
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
if self.active_socket_set != 'Axis Expr':
layout.prop(self, 'operation')
####################
# - Compute
####################
@events.computes_output_socket(
'Data',
props={'operation', 'active_socket_set'},
input_sockets={'Data', 'Axis', 'Value'},
input_sockets_optional={'Axis': True, 'Value': True},
)
def compute_data(self, props: dict, input_sockets: dict):
if not hasattr(input_sockets['Data'], 'shape'):
msg = 'Input socket "Data" must be an N-D Array (with a "shape" attribute)'
raise ValueError(msg)
# By Axis Value
if props['active_socket_set'] == 'By Axis Value':
if props['operation'] == 'FIX':
return fix_axis(
input_sockets['Data'], input_sockets['Axis'], input_sockets['Value']
)
# By Axis
if props['active_socket_set'] == 'By Axis':
if props['operation'] == 'SQUEEZE':
return jnp.squeeze(input_sockets['Data'], axis=input_sockets['Axis'])
msg = 'Operation invalid'
raise ValueError(msg)
####################
# - Blender Registration
####################
BL_REGISTER = [
FilterMathNode,
]
BL_NODES = {ct.NodeType.FilterMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}

View File

@ -0,0 +1,164 @@
import typing as typ
import bpy
import jax
import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import logger
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
class MapMathNode(base.MaxwellSimNode):
node_type = ct.NodeType.MapMath
bl_label = 'Map Math'
input_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
}
input_socket_sets: typ.ClassVar = {
'By Element': {},
'By Vector': {},
'By Matrix': {},
'Expr': {
'Mapper': sockets.ExprSocketDef(
symbols=[sp.Symbol('x')],
default_expr=sp.Symbol('x'),
),
},
}
output_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
}
####################
# - Properties
####################
operation: bpy.props.EnumProperty(
name='Op',
description='Operation to apply to the input',
items=lambda self, _: self.search_operations(),
update=lambda self, context: self.sync_prop('operation', context),
)
def search_operations(self) -> list[tuple[str, str, str]]:
items = []
if self.active_socket_set == 'By Element':
items += [
# General
('REAL', 'real', '(L) (by el)'),
('IMAG', 'imag', 'Im(L) (by el)'),
('ABS', 'abs', '|L| (by el)'),
('SQ', 'square', 'L^2 (by el)'),
('SQRT', 'sqrt', 'sqrt(L) (by el)'),
('INV_SQRT', '1/sqrt', '1/sqrt(L) (by el)'),
# Trigonometry
('COS', 'cos', 'cos(L) (by el)'),
('SIN', 'sin', 'sin(L) (by el)'),
('TAN', 'tan', 'tan(L) (by el)'),
('ACOS', 'acos', 'acos(L) (by el)'),
('ASIN', 'asin', 'asin(L) (by el)'),
('ATAN', 'atan', 'atan(L) (by el)'),
]
elif self.active_socket_set in 'By Vector':
items += [
('NORM_2', '2-Norm', '||L||_2 (by Vec)'),
]
elif self.active_socket_set == 'By Matrix':
items += [
# Matrix -> Number
('DET', 'Determinant', 'det(L) (by Mat)'),
('COND', 'Condition', 'κ(L) (by Mat)'),
('NORM_FRO', 'Frobenius Norm', '||L||_F (by Mat)'),
('RANK', 'Rank', 'rank(L) (by Mat)'),
# Matrix -> Array
('DIAG', 'Diagonal', 'diag(L) (by Mat)'),
('EIG_VALS', 'Eigenvalues', 'eigvals(L) (by Mat)'),
('SVD_VALS', 'SVD', 'svd(L) -> diag(Σ) (by Mat)'),
# Matrix -> Matrix
('INV', 'Invert', 'L^(-1) (by Mat)'),
('TRA', 'Transpose', 'L^T (by Mat)'),
# Matrix -> Matrices
('QR', 'QR', 'L -> Q·R (by Mat)'),
('CHOL', 'Cholesky', 'L -> L·Lh (by Mat)'),
('SVD', 'SVD', 'L -> U·Σ·Vh (by Mat)'),
]
else:
items += ['EXPR_EL', 'Expr (by el)', 'Expression-defined (by el)']
return items
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
if self.active_socket_set not in {'Expr (Element)'}:
layout.prop(self, 'operation')
####################
# - Compute
####################
@events.computes_output_socket(
'Data',
props={'active_socket_set', 'operation'},
input_sockets={'Data', 'Mapper'},
input_socket_kinds={'Mapper': ct.DataFlowKind.LazyValue},
input_sockets_optional={'Mapper': True},
)
def compute_data(self, props: dict, input_sockets: dict):
mapping_func: typ.Callable[[jax.Array], jax.Array] = {
'By Element': {
'REAL': lambda data: jnp.real(data),
'IMAG': lambda data: jnp.imag(data),
'ABS': lambda data: jnp.abs(data),
'SQ': lambda data: jnp.square(data),
'SQRT': lambda data: jnp.sqrt(data),
'INV_SQRT': lambda data: 1 / jnp.sqrt(data),
'COS': lambda data: jnp.cos(data),
'SIN': lambda data: jnp.sin(data),
'TAN': lambda data: jnp.tan(data),
'ACOS': lambda data: jnp.acos(data),
'ASIN': lambda data: jnp.asin(data),
'ATAN': lambda data: jnp.atan(data),
'SINC': lambda data: jnp.sinc(data),
},
'By Vector': {
'NORM_2': lambda data: jnp.norm(data, ord=2, axis=-1),
},
'By Matrix': {
# Matrix -> Number
'DET': lambda data: jnp.linalg.det(data),
'COND': lambda data: jnp.linalg.cond(data),
'NORM_FRO': lambda data: jnp.linalg.matrix_norm(data, ord='fro'),
'RANK': lambda data: jnp.linalg.matrix_rank(data),
# Matrix -> Vec
'DIAG': lambda data: jnp.diag(data),
'EIG_VALS': lambda data: jnp.eigvals(data),
'SVD_VALS': lambda data: jnp.svdvals(data),
# Matrix -> Matrix
'INV': lambda data: jnp.inv(data),
'TRA': lambda data: jnp.matrix_transpose(data),
# Matrix -> Matrices
'QR': lambda data: jnp.inv(data),
'CHOL': lambda data: jnp.linalg.cholesky(data),
'SVD': lambda data: jnp.linalg.svd(data),
},
'By El (Expr)': {
'EXPR_EL': lambda data: input_sockets['Mapper'](data),
},
}[props['active_socket_set']][props['operation']]
# Compose w/Lazy Root Function Data
return input_sockets['Data'].compose(
function=mapping_func,
)
####################
# - Blender Registration
####################
BL_REGISTER = [
MapMathNode,
]
BL_NODES = {ct.NodeType.MapMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}

View File

@ -0,0 +1,138 @@
import typing as typ
import bpy
import jax.numpy as jnp
from blender_maxwell.utils import logger
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
class OperateMathNode(base.MaxwellSimNode):
node_type = ct.NodeType.OperateMath
bl_label = 'Operate Math'
input_socket_sets: typ.ClassVar = {
'Elementwise': {
'Data L': sockets.AnySocketDef(),
'Data R': sockets.AnySocketDef(),
},
## TODO: Filter-array building operations
'Vec-Vec': {
'Data L': sockets.AnySocketDef(),
'Data R': sockets.AnySocketDef(),
},
'Mat-Vec': {
'Data L': sockets.AnySocketDef(),
'Data R': sockets.AnySocketDef(),
},
}
output_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
}
####################
# - Properties
####################
operation: bpy.props.EnumProperty(
name='Op',
description='Operation to apply to the two inputs',
items=lambda self, _: self.search_operations(),
update=lambda self, context: self.sync_prop('operation', context),
)
def search_operations(self) -> list[tuple[str, str, str]]:
items = []
if self.active_socket_set == 'Elementwise':
items = [
('ADD', 'Add', 'L + R (by el)'),
('SUB', 'Subtract', 'L - R (by el)'),
('MUL', 'Multiply', 'L · R (by el)'),
('DIV', 'Divide', 'L ÷ R (by el)'),
('POW', 'Power', 'L^R (by el)'),
('FMOD', 'Trunc Modulo', 'fmod(L,R) (by el)'),
('ATAN2', 'atan2', 'atan2(L,R) (by el)'),
('HEAVISIDE', 'Heaviside', '{0|L<0 1|L>0 R|L=0} (by el)'),
]
elif self.active_socket_set in 'Vec | Vec':
items = [
('DOT', 'Dot', 'L · R'),
('CROSS', 'Cross', 'L x R (by last-axis'),
]
elif self.active_socket_set == 'Mat | Vec':
items = [
('DOT', 'Dot', 'L · R'),
('LIN_SOLVE', 'Lin Solve', 'Lx = R -> x (by last-axis of R)'),
('LSQ_SOLVE', 'LSq Solve', 'Lx = R ~> x (by last-axis of R)'),
]
return items
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, 'operation')
####################
# - Properties
####################
@events.computes_output_socket(
'Data',
props={'operation'},
input_sockets={'Data L', 'Data R'},
)
def compute_data(self, props: dict, input_sockets: dict):
if self.active_socket_set == 'Elementwise':
# Element-Wise Arithmetic
if props['operation'] == 'ADD':
return input_sockets['Data L'] + input_sockets['Data R']
if props['operation'] == 'SUB':
return input_sockets['Data L'] - input_sockets['Data R']
if props['operation'] == 'MUL':
return input_sockets['Data L'] * input_sockets['Data R']
if props['operation'] == 'DIV':
return input_sockets['Data L'] / input_sockets['Data R']
# Element-Wise Arithmetic
if props['operation'] == 'POW':
return input_sockets['Data L'] ** input_sockets['Data R']
# Binary Trigonometry
if props['operation'] == 'ATAN2':
return jnp.atan2(input_sockets['Data L'], input_sockets['Data R'])
# Special Functions
if props['operation'] == 'HEAVISIDE':
return jnp.heaviside(input_sockets['Data L'], input_sockets['Data R'])
# Linear Algebra
if self.active_socket_set in {'Vec-Vec', 'Mat-Vec'}:
if props['operation'] == 'DOT':
return jnp.dot(input_sockets['Data L'], input_sockets['Data R'])
elif self.active_socket_set == 'Vec-Vec':
if props['operation'] == 'CROSS':
return jnp.cross(input_sockets['Data L'], input_sockets['Data R'])
elif self.active_socket_set == 'Mat-Vec':
if props['operation'] == 'LIN_SOLVE':
return jnp.linalg.lstsq(
input_sockets['Data L'], input_sockets['Data R']
)
if props['operation'] == 'LSQ_SOLVE':
return jnp.linalg.solve(
input_sockets['Data L'], input_sockets['Data R']
)
msg = 'Invalid operation'
raise ValueError(msg)
####################
# - Blender Registration
####################
BL_REGISTER = [
OperateMathNode,
]
BL_NODES = {ct.NodeType.OperateMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}

View File

@ -0,0 +1,135 @@
import typing as typ
import bpy
import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import logger
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
class ReduceMathNode(base.MaxwellSimNode):
node_type = ct.NodeType.ReduceMath
bl_label = 'Reduce Math'
input_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
'Axis': sockets.IntegerNumberSocketDef(),
}
input_socket_sets: typ.ClassVar = {
'By Axis': {
'Axis': sockets.IntegerNumberSocketDef(),
},
'Expr': {
'Reducer': sockets.ExprSocketDef(
symbols=[sp.Symbol('a'), sp.Symbol('b')],
default_expr=sp.Symbol('a') + sp.Symbol('b'),
),
'Axis': sockets.IntegerNumberSocketDef(),
},
}
output_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
}
####################
# - Properties
####################
operation: bpy.props.EnumProperty(
name='Op',
description='Operation to reduce the input axis with',
items=lambda self, _: self.search_operations(),
update=lambda self, context: self.sync_prop('operation', context),
)
def search_operations(self) -> list[tuple[str, str, str]]:
items = []
if self.active_socket_set == 'By Axis':
items += [
# Accumulation
('SUM', 'Sum', 'sum(*, N, *) -> (*, 1, *)'),
('PROD', 'Prod', 'prod(*, N, *) -> (*, 1, *)'),
('MIN', 'Axis-Min', '(*, N, *) -> (*, 1, *)'),
('MAX', 'Axis-Max', '(*, N, *) -> (*, 1, *)'),
('P2P', 'Peak-to-Peak', '(*, N, *) -> (*, 1 *)'),
# Stats
('MEAN', 'Mean', 'mean(*, N, *) -> (*, 1, *)'),
('MEDIAN', 'Median', 'median(*, N, *) -> (*, 1, *)'),
('STDDEV', 'Std Dev', 'stddev(*, N, *) -> (*, 1, *)'),
('VARIANCE', 'Variance', 'var(*, N, *) -> (*, 1, *)'),
# Dimension Reduction
('SQUEEZE', 'Squeeze', '(*, 1, *) -> (*, *)'),
]
else:
items += [('NONE', 'None', 'No operations...')]
return items
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
if self.active_socket_set != 'Axis Expr':
layout.prop(self, 'operation')
####################
# - Compute
####################
@events.computes_output_socket(
'Data',
props={'operation'},
input_sockets={'Data', 'Axis', 'Reducer'},
input_socket_kinds={'Reducer': ct.DataFlowKind.LazyValue},
input_sockets_optional={'Reducer': True},
)
def compute_data(self, props: dict, input_sockets: dict):
if not hasattr(input_sockets['Data'], 'shape'):
msg = 'Input socket "Data" must be an N-D Array (with a "shape" attribute)'
raise ValueError(msg)
if self.active_socket_set == 'Axis Expr':
ufunc = jnp.ufunc(input_sockets['Reducer'], nin=2, nout=1)
return ufunc.reduce(input_sockets['Data'], axis=input_sockets['Axis'])
if self.active_socket_set == 'By Axis':
## Dimension Reduction
# ('SQUEEZE', 'Squeeze', '(*, 1, *) -> (*, *)'),
# Accumulation
if props['operation'] == 'SUM':
return jnp.sum(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'PROD':
return jnp.prod(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'MIN':
return jnp.min(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'MAX':
return jnp.max(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'P2P':
return jnp.p2p(input_sockets['Data'], axis=input_sockets['Axis'])
# Stats
if props['operation'] == 'MEAN':
return jnp.mean(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'MEDIAN':
return jnp.median(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'STDDEV':
return jnp.std(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'VARIANCE':
return jnp.var(input_sockets['Data'], axis=input_sockets['Axis'])
# Dimension Reduction
if props['operation'] == 'SQUEEZE':
return jnp.squeeze(input_sockets['Data'], axis=input_sockets['Axis'])
msg = 'Operation invalid'
raise ValueError(msg)
####################
# - Blender Registration
####################
BL_REGISTER = [
ReduceMathNode,
]
BL_NODES = {ct.NodeType.ReduceMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}

View File

@ -20,7 +20,7 @@ class VizNode(base.MaxwellSimNode):
####################
# - Sockets
####################
input_sockets = {
input_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
'Freq': sockets.PhysicalFreqSocketDef(),
}
@ -72,20 +72,12 @@ class VizNode(base.MaxwellSimNode):
props: dict,
unit_systems: dict,
):
selected_data = jnp.array(
input_sockets['Data'].sel(f=input_sockets['Freq'], method='nearest')
)
managed_objs['plot'].xyzf_to_image(
selected_data,
managed_objs['plot'].map_2d_to_image(
input_sockets['Data'].as_bound_jax_func(),
colormap=props['colormap'],
bl_select=True,
)
# @events.on_init()
# def on_init(self):
# self.on_changed_inputs()
####################
# - Blender Registration

View File

@ -42,6 +42,15 @@ class SocketDef(pyd.BaseModel, abc.ABC):
# - SocketDef
####################
class MaxwellSimSocket(bpy.types.NodeSocket):
"""A specialized Blender socket for nodes in a Maxwell simulation.
Attributes:
instance_id: A unique ID attached to a particular socket instance.
Guaranteed to be unchanged so long as the socket lives.
Used as a socket-specific cache index.
locked: The lock-state of a particular socket, which determines the socket's user editability
"""
# Fundamentals
socket_type: ct.SocketType
bl_label: str
@ -73,21 +82,53 @@ class MaxwellSimSocket(bpy.types.NodeSocket):
####################
# - Initialization
####################
def __init_subclass__(cls, **kwargs: typ.Any):
super().__init_subclass__(**kwargs)
@classmethod
def set_prop(
cls,
prop_name: str,
prop: bpy.types.Property,
no_update: bool = False,
update_with_name: str | None = None,
**kwargs,
) -> None:
"""Adds a Blender property to a class via `__annotations__`, so it initializes with any subclass.
# Setup Blender ID for Node
if not hasattr(cls, 'socket_type'):
msg = f"Socket class {cls} does not define 'socket_type'"
raise ValueError(msg)
cls.bl_idname = str(cls.socket_type.value)
Notes:
- Blender properties can't be set within `__init_subclass__` simply by adding attributes to the class; they must be added as type annotations.
- Must be called **within** `__init_subclass__`.
# Setup Locked Property for Node
cls.__annotations__['locked'] = bpy.props.BoolProperty(
name='Locked State',
description="The lock-state of a particular socket, which determines the socket's user editability",
default=False,
Parameters:
name: The name of the property to set.
prop: The `bpy.types.Property` to instantiate and attach..
no_update: Don't attach a `self.sync_prop()` callback to the property's `update`.
"""
_update_with_name = prop_name if update_with_name is None else update_with_name
extra_kwargs = (
{
'update': lambda self, context: self.sync_prop(
_update_with_name, context
),
}
if not no_update
else {}
)
cls.__annotations__[prop_name] = prop(
**kwargs,
**extra_kwargs,
)
def __init_subclass__(cls, **kwargs: typ.Any):
log.debug('Initializing Socket: %s', cls.socket_type)
super().__init_subclass__(**kwargs)
# cls._assert_attrs_valid()
# Socket Properties
## Identifiers
cls.bl_idname: str = str(cls.socket_type.value)
cls.set_prop('instance_id', bpy.props.StringProperty, no_update=True)
## Special States
cls.set_prop('locked', bpy.props.BoolProperty, no_update=True, default=False)
# Setup Style
cls.socket_color = ct.SOCKET_COLORS[cls.socket_type]

View File

@ -1,11 +1,12 @@
from . import any as any_socket
from . import bool as bool_socket
from . import file_path, string
from . import expr, file_path, string
AnySocketDef = any_socket.AnySocketDef
BoolSocketDef = bool_socket.BoolSocketDef
FilePathSocketDef = file_path.FilePathSocketDef
StringSocketDef = string.StringSocketDef
FilePathSocketDef = file_path.FilePathSocketDef
ExprSocketDef = expr.ExprSocketDef
BL_REGISTER = [
@ -13,4 +14,5 @@ BL_REGISTER = [
*bool_socket.BL_REGISTER,
*string.BL_REGISTER,
*file_path.BL_REGISTER,
*expr.BL_REGISTER,
]

View File

@ -0,0 +1,80 @@
import bpy
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils.pydantic_sympy import SympyExpr
from ... import bl_cache
from ... import contracts as ct
from .. import base
class ExprBLSocket(base.MaxwellSimSocket):
socket_type = ct.SocketType.Expr
bl_label = 'Expr'
####################
# - Properties
####################
raw_value: bpy.props.StringProperty(
name='Expr',
description='Represents a symbolic expression',
default='',
update=(lambda self, context: self.sync_prop('raw_value', context)),
)
symbols: list[sp.Symbol] = bl_cache.BLField([])
## TODO: Way of assigning assumptions to symbols.
## TODO: Dynamic add/remove of symbols
####################
# - Socket UI
####################
def draw_value(self, col: bpy.types.UILayout) -> None:
col.prop(self, 'raw_value', text='')
####################
# - Computation of Default Value
####################
@property
def value(self) -> sp.Expr:
return sp.sympify(
self.raw_value,
strict=False,
convert_xor=True,
).subs(spux.ALL_UNIT_SYMBOLS)
@value.setter
def value(self, value: str) -> None:
self.raw_value = str(value)
@property
def lazy_value(self) -> sp.Expr:
return ct.LazyDataValue.from_function(
sp.lambdify(self.symbols, self.value, 'jax'),
free_args=(tuple(str(sym) for sym in self.symbols), frozenset()),
supports_jax=True,
)
####################
# - Socket Configuration
####################
class ExprSocketDef(base.SocketDef):
socket_type: ct.SocketType = ct.SocketType.Expr
_x = sp.Symbol('x', real=True)
symbols: list[SympyExpr] = [_x]
default_expr: SympyExpr = _x
def init(self, bl_socket: ExprBLSocket) -> None:
bl_socket.value = self.default_expr
bl_socket.symbols = self.symbols
####################
# - Blender Registration
####################
BL_REGISTER = [
ExprBLSocket,
]

View File

@ -57,3 +57,6 @@ class StringSocketDef(base.SocketDef):
BL_REGISTER = [
StringBLSocket,
]

View File

@ -0,0 +1,63 @@
import dataclasses
import typing as typ
from types import MappingProxyType
import jax
import jax.numpy as jnp
import pandas as pd
# import jaxtyping as jtyp
import sympy.physics.units as spu
import xarray
from . import extra_sympy_units as spux
from . import logger
log = logger.get(__name__)
DimName: typ.TypeAlias = str
Number: typ.TypeAlias = int | float | complex
NumberRange: typ.TypeAlias = jax.Array
@dataclasses.dataclass(kw_only=True)
class JArray:
"""Very simple wrapper for JAX arrays, which includes information about the dimension names and bounds."""
array: jax.Array
dims: dict[DimName, NumberRange]
dim_units: dict[DimName, spu.Quantity]
####################
# - Constructor
####################
@classmethod
def from_xarray(
cls,
xarr: xarray.DataArray,
dim_units: dict[DimName, spu.Quantity] = MappingProxyType({}),
sort_axis: int = -1,
) -> typ.Self:
return cls(
array=jnp.sort(jnp.array(xarr.data), axis=sort_axis),
dims={
dim_name: jnp.array(xarr.get_index(dim_name).values)
for dim_name in xarr.dims
},
dim_units={dim_name: dim_units.get(dim_name) for dim_name in xarr.dims},
)
def idx(self, dim_name: DimName, dim_value: Number) -> int:
found_idx = jnp.searchsorted(self.dims[dim_name], dim_value)
if found_idx == 0:
return found_idx
if found_idx == len(self.dims[dim_name]):
return found_idx - 1
left = self.dims[dim_name][found_idx - 1]
right = self.dims[dim_name][found_idx - 1]
return found_idx - 1 if (dim_value - left) <= (right - dim_value) else found_idx
@property
def dtype(self) -> jnp.dtype:
return self.array.dtype