refactor: Factored out flow_kinds.py for clarity.

main
Sofus Albert Høgsbro Rose 2024-05-04 11:25:11 +02:00
parent b221f9ae2b
commit 3a53e4ce46
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
12 changed files with 1047 additions and 1011 deletions

Binary file not shown.

View File

@ -0,0 +1,17 @@
from .array import ArrayFlow
from .capabiltiies import CapabilitiesFlow
from .flow_kinds import FlowKind
from .lazy_array_range import LazyArrayRange
from .lazy_value_func import LazyValueFunc
from .params import Params
from .value import ValueFlow
__all__ = [
'ArrayFlow',
'CapabilitiesFlow',
'FlowKind',
'LazyArrayRange',
'LazyValueFunc',
'Params',
'ValueFlow',
]

View File

@ -0,0 +1,96 @@
import dataclasses
import functools
import typing as typ
import jaxtyping as jtyp
import numpy as np
import sympy.physics.units as spu
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
log = logger.get(__name__)
@dataclasses.dataclass(frozen=True, kw_only=True)
class ArrayFlow:
"""A simple, flat array of values with an optionally-attached unit.
Attributes:
values: An ND array-like object of arbitrary numerical type.
unit: A `sympy` unit.
None if unitless.
"""
values: jtyp.Shaped[jtyp.Array, '...']
unit: spux.Unit | None = None
is_sorted: bool = False
def __len__(self) -> int:
return len(self.values)
@functools.cached_property
def mathtype(self) -> spux.MathType:
return spux.MathType.from_pytype(type(self.values.item(0)))
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
"""Find the index of the value that is closest to the given value.
Units are taken into account; the given value will be scaled to the internal unit before direct use.
Parameters:
require_sorted: Require that `self.values` be sorted, so that use of the faster binary-search algorithm is guaranteed.
Returns:
The index of `self.values` that is closest to the value `value`.
"""
if not require_sorted:
raise NotImplementedError
# Scale Given Value to Internal Unit
scaled_value = spux.sympy_to_python(spux.scale_to_unit(value, self.unit))
# BinSearch for "Right IDX"
## >>> self.values[right_idx] > scaled_value
## >>> self.values[right_idx - 1] < scaled_value
right_idx = np.searchsorted(self.values, scaled_value, side='left')
# Case: Right IDX is Boundary
if right_idx == 0:
return right_idx
if right_idx == len(self.values):
return right_idx - 1
# Find Closest of [Right IDX - 1, Right IDX]
left_val = self.values[right_idx - 1]
right_val = self.values[right_idx]
if (scaled_value - left_val) <= (right_val - scaled_value):
return right_idx - 1
return right_idx
def correct_unit(self, corrected_unit: spu.Quantity) -> typ.Self:
if self.unit is not None:
return ArrayFlow(
values=self.values, unit=corrected_unit, is_sorted=self.is_sorted
)
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
raise ValueError(msg)
def rescale_to_unit(self, unit: spu.Quantity) -> typ.Self:
if self.unit is not None:
return ArrayFlow(
values=float(spux.scaling_factor(self.unit, unit)) * self.values,
unit=unit,
is_sorted=self.is_sorted, ## TODO: Can we really say that?
)
## TODO: Is this scaling numerically stable?
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
raise ValueError(msg)
def rescale_to_unit_system(self, unit: spu.Quantity) -> typ.Self:
raise NotImplementedError

View File

@ -0,0 +1,39 @@
import dataclasses
import typing as typ
from ..socket_types import SocketType
from .flow_kinds import FlowKind
@dataclasses.dataclass(frozen=True, kw_only=True)
class CapabilitiesFlow:
socket_type: SocketType
active_kind: FlowKind
is_universal: bool = False
# == Constraint
must_match: dict[str, typ.Any] = dataclasses.field(default_factory=dict)
# ∀b (b ∈ A) Constraint
## A: allow_any
## b∈B: present_any
allow_any: set[typ.Any] = dataclasses.field(default_factory=set)
present_any: set[typ.Any] = dataclasses.field(default_factory=set)
def is_compatible_with(self, other: typ.Self) -> bool:
return other.is_universal or (
self.socket_type == other.socket_type
and self.active_kind == other.active_kind
# == Constraint
and all(
name in other.must_match
and self.must_match[name] == other.must_match[name]
for name in self.must_match
)
# ∀b (b ∈ A) Constraint
and (
self.present_any & other.allow_any
or (not self.present_any and not self.allow_any)
)
)

View File

@ -0,0 +1,68 @@
import enum
import typing as typ
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
log = logger.get(__name__)
class FlowKind(enum.StrEnum):
"""Defines a kind of data that can flow between nodes.
Each node link can be thought to contain **multiple pipelines for data to flow along**.
Each pipeline is cached incrementally, and independently, of the others.
Thus, the same socket can easily support several kinds of related data flow at the same time.
Attributes:
Capabilities: Describes a socket's linkeability with other sockets.
Links between sockets with incompatible capabilities will be rejected.
This doesn't need to be defined normally, as there is a default.
However, in some cases, defining it manually to control linkeability more granularly may be desirable.
Value: A generic object, which is "directly usable".
This should be chosen when a more specific flow kind doesn't apply.
Array: An object with dimensions, and possibly a unit.
Whenever a `Value` is defined, a single-element `list` will also be generated by default as `Array`
However, for any other array-like variants (or sockets that only represent array-like objects), `Array` should be defined manually.
LazyValueFunc: A composable function.
Can be used to represent computations for which all data is not yet known, or for which just-in-time compilation can drastically increase performance.
LazyArrayRange: An object that generates an `Array` from range information (start/stop/step/spacing).
This should be used instead of `Array` whenever possible.
Param: A dictionary providing particular parameters for a lazy value.
Info: An dictionary providing extra context about any aspect of flow.
"""
Capabilities = enum.auto()
# Values
Value = enum.auto()
Array = enum.auto()
# Lazy
LazyValueFunc = enum.auto()
LazyArrayRange = enum.auto()
# Auxiliary
Params = enum.auto()
Info = enum.auto()
@classmethod
def scale_to_unit_system(
cls,
kind: typ.Self,
value,
unit_system: spux.UnitSystem,
):
if kind == cls.Value:
return spux.scale_to_unit_system(
value,
unit_system,
)
if kind == cls.LazyArrayRange:
return value.rescale_to_unit_system(unit_system)
if kind == cls.Params:
return value.rescale_to_unit_system(unit_system)
msg = 'Tried to scale unknown kind'
raise ValueError(msg)

View File

@ -0,0 +1,182 @@
import dataclasses
import functools
import typing as typ
import jax
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
from .array import ArrayFlow
from .lazy_array_range import LazyArrayRangeFlow
log = logger.get(__name__)
@dataclasses.dataclass(frozen=True, kw_only=True)
class InfoFlow:
# Dimension Information
dim_names: list[str] = dataclasses.field(default_factory=list)
dim_idx: dict[str, ArrayFlow | LazyArrayRangeFlow] = dataclasses.field(
default_factory=dict
) ## TODO: Rename to dim_idxs
@functools.cached_property
def dim_lens(self) -> dict[str, int]:
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
@functools.cached_property
def dim_mathtypes(self) -> dict[str, spux.MathType]:
return {
dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items()
}
@functools.cached_property
def dim_units(self) -> dict[str, spux.Unit]:
return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()}
@functools.cached_property
def dim_physical_types(self) -> dict[str, spux.PhysicalType]:
return {
dim_name: spux.PhysicalType.from_unit(dim_idx.unit)
for dim_name, dim_idx in self.dim_idx.items()
}
@functools.cached_property
def dim_idx_arrays(self) -> list[jax.Array]:
return [
dim_idx.realize().values
if isinstance(dim_idx, LazyArrayRangeFlow)
else dim_idx.values
for dim_idx in self.dim_idx.values()
]
# Output Information
## TODO: Add PhysicalType
output_name: str = dataclasses.field(default_factory=list)
output_shape: tuple[int, ...] | None = dataclasses.field(default=None)
output_mathtype: spux.MathType = dataclasses.field()
output_unit: spux.Unit | None = dataclasses.field()
# Pinned Dimension Information
## TODO: Add PhysicalType
pinned_dim_names: list[str] = dataclasses.field(default_factory=list)
pinned_dim_values: dict[str, float | complex] = dataclasses.field(
default_factory=dict
)
pinned_dim_mathtypes: dict[str, spux.MathType] = dataclasses.field(
default_factory=dict
)
pinned_dim_units: dict[str, spux.Unit] = dataclasses.field(default_factory=dict)
####################
# - Methods
####################
def rescale_dim_idxs(self, new_dim_idxs: dict[str, LazyArrayRangeFlow]) -> typ.Self:
return InfoFlow(
# Dimensions
dim_names=self.dim_names,
dim_idx={
_dim_name: new_dim_idxs.get(_dim_name, dim_idx)
for _dim_name, dim_idx in self.dim_idx.items()
},
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
)
def delete_dimension(self, dim_name: str) -> typ.Self:
"""Delete a dimension."""
return InfoFlow(
# Dimensions
dim_names=[
_dim_name for _dim_name in self.dim_names if _dim_name != dim_name
],
dim_idx={
_dim_name: dim_idx
for _dim_name, dim_idx in self.dim_idx.items()
if _dim_name != dim_name
},
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
)
def swap_dimensions(self, dim_0_name: str, dim_1_name: str) -> typ.Self:
"""Delete a dimension."""
# Compute Swapped Dimension Name List
def name_swapper(dim_name):
return (
dim_name
if dim_name not in [dim_0_name, dim_1_name]
else {dim_0_name: dim_1_name, dim_1_name: dim_0_name}[dim_name]
)
dim_names = [name_swapper(dim_name) for dim_name in self.dim_names]
# Compute Info
return InfoFlow(
# Dimensions
dim_names=dim_names,
dim_idx={dim_name: self.dim_idx[dim_name] for dim_name in dim_names},
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
)
def set_output_mathtype(self, output_mathtype: spux.MathType) -> typ.Self:
"""Set the MathType of a particular output name."""
return InfoFlow(
dim_names=self.dim_names,
dim_idx=self.dim_idx,
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=output_mathtype,
output_unit=self.output_unit,
)
def collapse_output(
self,
collapsed_name: str,
collapsed_mathtype: spux.MathType,
collapsed_unit: spux.Unit,
) -> typ.Self:
return InfoFlow(
# Dimensions
dim_names=self.dim_names,
dim_idx=self.dim_idx,
output_name=collapsed_name,
output_shape=None,
output_mathtype=collapsed_mathtype,
output_unit=collapsed_unit,
)
@functools.cached_property
def shift_last_input(self):
"""Shift the last input dimension to the output."""
return InfoFlow(
# Dimensions
dim_names=self.dim_names[:-1],
dim_idx={
dim_name: dim_idx
for dim_name, dim_idx in self.dim_idx.items()
if dim_name != self.dim_names[-1]
},
# Outputs
output_name=self.output_name,
output_shape=(
(self.dim_lens[self.dim_names[-1]],)
if self.output_shape is None
else (self.dim_lens[self.dim_names[-1]], *self.output_shape)
),
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
)

View File

@ -0,0 +1,347 @@
import dataclasses
import functools
import typing as typ
from types import MappingProxyType
import jax.numpy as jnp
import jaxtyping as jtyp
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
from .array import ArrayFlow
from .flow_kinds import FlowKind
from .lazy_value_func import LazyValueFuncFlow
@dataclasses.dataclass(frozen=True, kw_only=True)
class LazyArrayRangeFlow:
r"""Represents a linearly/logarithmically spaced array using symbolic boundary expressions, with support for units and lazy evaluation.
# Advantages
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
## Memory
`ArrayFlow` generally has a memory scaling of $O(n)$.
Naturally, `LazyArrayRangeFlow` is always constant, since only the boundaries and steps are stored.
## Symbolic
Both boundary points are symbolic expressions, within which pre-defined `sp.Symbol`s can participate in a constrained manner (ex. an integer symbol).
One need not know the value of the symbols immediately - such decisions can be deferred until later in the computational flow.
## Performant Unit-Aware Operations
While `ArrayFlow`s are also unit-aware, the time-cost of _any_ unit-scaling operation scales with $O(n)$.
`LazyArrayRangeFlow`, by contrast, scales as $O(1)$.
As a result, more complicated operations (like symbolic or unit-based) that might be difficult to perform interactively in real-time on an `ArrayFlow` will work perfectly with this object, even with added complexity
## High-Performance Composition and Gradiant
With `self.as_func`, a `jax` function is produced that generates the array according to the symbolic `start`, `stop` and `steps`.
There are two nice things about this:
- **Gradient**: The gradient of the output array, with respect to any symbols used to define the input bounds, can easily be found using `jax.grad` over `self.as_func`.
- **JIT**: When `self.as_func` is composed with other `jax` functions, and `jax.jit` is run to optimize the entire thing, the "cost of array generation" _will often be optimized away significantly or entirely_.
Thus, as part of larger computations, the performance properties of `LazyArrayRangeFlow` is extremely favorable.
## Numerical Properties
Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated.
Attributes:
start: An expression generating a scalar, unitless, complex value for the array's lower bound.
_Integer, rational, and real values are also supported._
stop: An expression generating a scalar, unitless, complex value for the array's upper bound.
_Integer, rational, and real values are also supported._
steps: The amount of steps (**inclusive**) to generate from `start` to `stop`.
scaling: The method of distributing `step` values between the two endpoints.
Generally, the linear default is sufficient.
unit: The unit of the generated array values
symbols: Set of variables from which `start` and/or `stop` are determined.
"""
start: spux.ScalarUnitlessComplexExpr
stop: spux.ScalarUnitlessComplexExpr
steps: int
scaling: typ.Literal['lin', 'geom', 'log'] = 'lin'
unit: spux.Unit | None = None
symbols: frozenset[spux.IntSymbol] = frozenset()
@functools.cached_property
def sorted_symbols(self) -> list[sp.Symbol]:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
The order is guaranteed to be **deterministic**.
Returns:
All symbols valid for use in the expression.
"""
return sorted(self.symbols, key=lambda sym: sym.name)
@functools.cached_property
def mathtype(self) -> spux.MathType:
# Get Start Mathtype
if isinstance(self.start, spux.SympyType):
start_mathtype = spux.MathType.from_expr(self.start)
else:
start_mathtype = spux.MathType.from_pytype(type(self.start))
# Get Stop Mathtype
if isinstance(self.stop, spux.SympyType):
stop_mathtype = spux.MathType.from_expr(self.stop)
else:
stop_mathtype = spux.MathType.from_pytype(type(self.stop))
# Check Equal
if start_mathtype != stop_mathtype:
return spux.MathType.combine(start_mathtype, stop_mathtype)
return start_mathtype
def __len__(self):
return self.steps
####################
# - Units
####################
def correct_unit(self, corrected_unit: spux.Unit) -> typ.Self:
"""Replaces the unit without rescaling the unitless bounds.
Parameters:
corrected_unit: The unit to replace the current unit with.
Returns:
A new `LazyArrayRangeFlow` with replaced unit.
Raises:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
"""
if self.unit is not None:
return LazyArrayRangeFlow(
start=self.start,
stop=self.stop,
steps=self.steps,
scaling=self.scaling,
unit=corrected_unit,
symbols=self.symbols,
)
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
raise ValueError(msg)
def rescale_to_unit(self, unit: spux.Unit) -> typ.Self:
"""Replaces the unit, **with** rescaling of the bounds.
Parameters:
unit: The unit to convert the bounds to.
Returns:
A new `LazyArrayRangeFlow` with replaced unit.
Raises:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
"""
if self.unit is not None:
return LazyArrayRangeFlow(
start=spux.scale_to_unit(self.start * self.unit, unit),
stop=spux.scale_to_unit(self.stop * self.unit, unit),
steps=self.steps,
scaling=self.scaling,
unit=unit,
symbols=self.symbols,
)
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
raise ValueError(msg)
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
"""Replaces the units, **with** rescaling of the bounds.
Parameters:
unit: The unit to convert the bounds to.
Returns:
A new `LazyArrayRangeFlow` with replaced unit.
Raises:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
"""
if self.unit is not None:
return LazyArrayRangeFlow(
start=spux.strip_unit_system(
spux.convert_to_unit_system(self.start * self.unit, unit_system),
unit_system,
),
stop=spux.strip_unit_system(
spux.convert_to_unit_system(self.start * self.unit, unit_system),
unit_system,
),
steps=self.steps,
scaling=self.scaling,
unit=unit_system[spux.PhysicalType.from_unit(self.unit)],
symbols=self.symbols,
)
msg = (
f'Tried to rescale unitless LazyDataValueRange to unit system {unit_system}'
)
raise ValueError(msg)
####################
# - Bound Operations
####################
def rescale_bounds(
self,
rescale_func: typ.Callable[
[spux.ScalarUnitlessComplexExpr], spux.ScalarUnitlessComplexExpr
],
reverse: bool = False,
) -> typ.Self:
"""Apply a function to the bounds, effectively rescaling the represented array.
Notes:
**It is presumed that the bounds are scaled with the same factor**.
Breaking this presumption may have unexpected results.
The scalar, unitless, complex-valuedness of the bounds must also be respected; additionally, new symbols must not be introduced.
Parameters:
scaler: The function that scales each bound.
reverse: Whether to reverse the bounds after running the `scaler`.
Returns:
A rescaled `LazyArrayRangeFlow`.
"""
return LazyArrayRangeFlow(
start=rescale_func(self.start if not reverse else self.stop),
stop=rescale_func(self.stop if not reverse else self.start),
steps=self.steps,
scaling=self.scaling,
unit=self.unit,
symbols=self.symbols,
)
####################
# - Lazy Representation
####################
@functools.cached_property
def array_generator(
self,
) -> typ.Callable[
[int | float | complex, int | float | complex, int],
jtyp.Inexact[jtyp.Array, ' steps'],
]:
"""Compute the correct `jnp.*space` array generator, where `*` is one of the supported scaling methods.
Returns:
A `jax` function that takes a valid `start`, `stop`, and `steps`, and returns a 1D `jax` array.
"""
jnp_nspace = {
'lin': jnp.linspace,
'geom': jnp.geomspace,
'log': jnp.logspace,
}.get(self.scaling)
if jnp_nspace is None:
msg = f'ArrayFlow scaling method {self.scaling} is unsupported'
raise RuntimeError(msg)
return jnp_nspace
@functools.cached_property
def as_func(
self,
) -> typ.Callable[[int | float | complex, ...], jtyp.Inexact[jtyp.Array, ' steps']]:
"""Create a function that can compute the non-lazy output array as a function of the symbols in the expressions for `start` and `stop`.
Notes:
The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols.
Returns:
A `LazyValueFuncFlow` that, given the input symbols defined in `self.symbols`,
"""
# Compile JAX Functions for Start/End Expressions
## FYI, JAX-in-JAX works perfectly fine.
start_jax = sp.lambdify(self.symbols, self.start, 'jax')
stop_jax = sp.lambdify(self.symbols, self.stop, 'jax')
# Compile ArrayGen Function
def gen_array(
*args: list[int | float | complex],
) -> jtyp.Inexact[jtyp.Array, ' steps']:
return self.array_generator(start_jax(*args), stop_jax(*args), self.steps)
# Return ArrayGen Function
return gen_array
@functools.cached_property
def as_lazy_value_func(self) -> LazyValueFuncFlow:
"""Creates a `LazyValueFuncFlow` using the output of `self.as_func`.
This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array.
Notes:
The the function enclosed in the `LazyValueFuncFlow` is identical to the one returned by `self.as_func`.
Returns:
A `LazyValueFuncFlow` containing `self.as_func`, as well as appropriate supporting settings.
"""
return LazyValueFuncFlow(
func=self.as_func,
func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols],
supports_jax=True,
)
####################
# - Realization
####################
def realize(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
kind: typ.Literal[FlowKind.Array, FlowKind.LazyValueFunc] = FlowKind.Array,
) -> ArrayFlow | LazyValueFuncFlow:
"""Apply a function to the bounds, effectively rescaling the represented array.
Notes:
**It is presumed that the bounds are scaled with the same factor**.
Breaking this presumption may have unexpected results.
The scalar, unitless, complex-valuedness of the bounds must also be respected; additionally, new symbols must not be introduced.
Parameters:
scaler: The function that scales each bound.
reverse: Whether to reverse the bounds after running the `scaler`.
Returns:
A rescaled `LazyArrayRangeFlow`.
"""
if not set(self.symbols).issubset(set(symbol_values.keys())):
msg = f'Provided symbols ({set(symbol_values.keys())}) do not provide values for all expression symbols ({self.symbols}) that may be found in the boundary expressions (start={self.start}, end={self.end})'
raise ValueError(msg)
# Realize Symbols
realized_start = spux.sympy_to_python(
self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols})
)
realized_stop = spux.sympy_to_python(
self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols})
)
# Return Linspace / Logspace
def gen_array() -> jtyp.Inexact[jtyp.Array, ' steps']:
return self.array_generator(realized_start, realized_stop, self.steps)
if kind == FlowKind.Array:
return ArrayFlow(values=gen_array(), unit=self.unit, is_sorted=True)
if kind == FlowKind.LazyValueFunc:
return LazyValueFuncFlow(func=gen_array, supports_jax=True)
msg = f'Invalid kind: {kind}'
raise TypeError(msg)
@functools.cached_property
def realize_array(self) -> ArrayFlow:
return self.realize()

View File

@ -0,0 +1,194 @@
import dataclasses
import functools
import typing as typ
from types import MappingProxyType
import jax
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
log = logger.get(__name__)
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
@dataclasses.dataclass(frozen=True, kw_only=True)
class LazyValueFuncFlow:
r"""Wraps a composable function, providing useful information and operations.
# Data Flow as Function Composition
When using nodes to do math, it can be a good idea to express a **flow of data as the composition of functions**.
Each node creates a new function, which uses the still-unknown (aka. **lazy**) output of the previous function to plan some calculations.
Some new arguments may also be added, of course.
## Root Function
Of course, one needs to select a "bottom" function, which has no previous function as input.
Thus, the first step is to define this **root function**:
$$
f_0:\ \ \ \ \biggl(
\underbrace{a_1, a_2, ..., a_p}_{\texttt{args}},\
\underbrace{
\begin{bmatrix} k_1 \\ v_1\end{bmatrix},
\begin{bmatrix} k_2 \\ v_2\end{bmatrix},
...,
\begin{bmatrix} k_q \\ v_q\end{bmatrix}
}_{\texttt{kwargs}}
\biggr) \to \text{output}_0
$$
We'll express this simple snippet like so:
```python
# Presume 'A0', 'KV0' contain only the args/kwargs for f_0
## 'A0', 'KV0' are of length 'p' and 'q'
def f_0(*args, **kwargs): ...
lazy_value_func_0 = LazyValueFuncFlow(
func=f_0,
func_args=[(a_i, type(a_i)) for a_i in A0],
func_kwargs={k: v for k,v in KV0},
)
output_0 = lazy_value_func.func(*A0_computed, **KV0_computed)
```
So far so good.
But of course, nothing interesting has really happened yet.
## Composing Functions
The key thing is the next step: The function that uses the result of $f_0$!
$$
f_1:\ \ \ \ \biggl(
f_0(...),\ \
\underbrace{\{a_i\}_p^{p+r}}_{\texttt{args[p:]}},\
\underbrace{\biggl\{
\begin{bmatrix} k_i \\ v_i\end{bmatrix}
\biggr\}_q^{q+s}}_{\texttt{kwargs[p:]}}
\biggr) \to \text{output}_1
$$
Notice that _$f_1$ needs the arguments of both $f_0$ and $f_1$_.
Tracking arguments is already getting out of hand; we already have to use `...` to keep it readeable!
But doing so with `LazyValueFunc` is not so complex:
```python
# Presume 'A1', 'K1' contain only the args/kwarg names for f_1
## 'A1', 'KV1' are therefore of length 'r' and 's'
def f_1(output_0, *args, **kwargs): ...
lazy_value_func_1 = lazy_value_func_0.compose_within(
enclosing_func=f_1,
enclosing_func_args=[(a_i, type(a_i)) for a_i in A1],
enclosing_func_kwargs={k: type(v) for k,v in K1},
)
A_computed = A0_computed + A1_computed
KW_computed = KV0_computed + KV1_computed
output_1 = lazy_value_func_1.func(*A_computed, **KW_computed)
```
We only need the arguments to $f_1$, and `LazyValueFunc` figures out how to make one function with enough arguments to call both.
## Isn't Laying Functions Slow/Hard?
Imagine that each function represents the action of a node, each of which performs expensive calculations on huge `numpy` arrays (**as one does when processing electromagnetic field data**).
At the end, a node might run the entire procedure with all arguments:
```python
output_n = lazy_value_func_n.func(*A_all, **KW_all)
```
It's rough: Most non-trivial pipelines drown in the time/memory overhead of incremental `numpy` operations - individually fast, but collectively iffy.
The killer feature of `LazyValueFuncFlow` is a sprinkle of black magic:
```python
func_n_jax = lazy_value_func_n.func_jax
output_n = func_n_jax(*A_all, **KW_all) ## Runs on your GPU
```
What happened was, **the entire pipeline** was compiled and optimized for high performance on not just your CPU, _but also (possibly) your GPU_.
All the layered function calls and inefficient incremental processing is **transformed into a high-performance program**.
Thank `jax` - specifically, `jax.jit` (https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), which internally enables this magic with a single function call.
## Other Considerations
**Auto-Differentiation**: Incredibly, `jax.jit` isn't the killer feature of `jax`. The function that comes out of `LazyValueFuncFlow` can also be differentiated with `jax.grad` (read: high-performance Jacobians for optimizing input parameters).
Though designed for machine learning, there's no reason other fields can't enjoy their inventions!
**Impact of Independent Caching**: JIT'ing can be slow.
That's why `LazyValueFuncFlow` has its own `FlowKind` "lane", which means that **only changes to the processing procedures will cause recompilation**.
Generally, adjustable values that affect the output will flow via the `Param` "lane", which has its own incremental caching, and only meets the compiled function when it's "plugged in" for final evaluation.
The effect is a feeling of snappiness and interactivity, even as the volume of data grows.
Attributes:
func: The function that the object encapsulates.
bound_args: Arguments that will be packaged into function, which can't be later modifier.
func_kwargs: Arguments to be specified by the user at the time of use.
supports_jax: Whether the contained `self.function` can be compiled with JAX's JIT compiler.
"""
func: LazyFunction
func_args: list[spux.MathType | spux.PhysicalType] = dataclasses.field(
default_factory=list
)
func_kwargs: dict[str, spux.MathType | spux.PhysicalType] = dataclasses.field(
default_factory=dict
)
supports_jax: bool = False
# Merging
def __or__(
self,
other: typ.Self,
):
return LazyValueFuncFlow(
func=lambda *args, **kwargs: (
self.func(
*list(args[: len(self.func_args)]),
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
),
other.func(
*list(args[len(self.func_args) :]),
**{k: v for k, v in kwargs.items() if k in other.func_kwargs},
),
),
func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs,
supports_jax=self.supports_jax and other.supports_jax,
)
# Composition
def compose_within(
self,
enclosing_func: LazyFunction,
enclosing_func_args: list[type] = (),
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}),
supports_jax: bool = False,
) -> typ.Self:
return LazyValueFuncFlow(
func=lambda *args, **kwargs: enclosing_func(
self.func(
*list(args[: len(self.func_args)]),
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
),
*args[len(self.func_args) :],
**{k: v for k, v in kwargs.items() if k not in self.func_kwargs},
),
func_args=self.func_args + list(enclosing_func_args),
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
supports_jax=self.supports_jax and supports_jax,
)
@functools.cached_property
def func_jax(self) -> LazyFunction:
if self.supports_jax:
return jax.jit(self.func)
msg = 'Can\'t express LazyValueFuncFlow as JAX function (using jax.jit), since "self.supports_jax" is False'
raise ValueError(msg)

View File

@ -0,0 +1,95 @@
import dataclasses
import functools
import typing as typ
from types import MappingProxyType
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
log = logger.get(__name__)
@dataclasses.dataclass(frozen=True, kw_only=True)
class ParamsFlow:
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
symbols: frozenset[spux.Symbol] = frozenset()
@functools.cached_property
def sorted_symbols(self) -> list[sp.Symbol]:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
Returns:
All symbols valid for use in the expression.
"""
return sorted(self.symbols, key=lambda sym: sym.name)
####################
# - Scaled Func Args
####################
def scaled_func_args(
self,
unit_system: spux.UnitSystem,
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
):
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
if not all(sym in self.symbols for sym in symbol_values):
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
raise ValueError(msg)
return [
spux.scale_to_unit_system(arg, unit_system, use_jax_array=True)
if arg not in symbol_values
else symbol_values[arg]
for arg in self.func_args
]
def scaled_func_kwargs(
self,
unit_system: spux.UnitSystem,
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
):
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
if not all(sym in self.symbols for sym in symbol_values):
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
raise ValueError(msg)
return {
arg_name: spux.convert_to_unit_system(arg, unit_system, use_jax_array=True)
if arg not in symbol_values
else symbol_values[arg]
for arg_name, arg in self.func_kwargs.items()
}
####################
# - Operations
####################
def __or__(
self,
other: typ.Self,
):
"""Combine two function parameter lists, such that the LHS will be concatenated with the RHS.
Just like its neighbor in `LazyValueFunc`, this effectively combines two functions with unique parameters.
The next composed function will receive a tuple of two arrays, instead of just one, allowing binary operations to occur.
"""
return ParamsFlow(
func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs,
symbols=self.symbols | other.symbols,
)
def compose_within(
self,
enclosing_func_args: list[spux.SympyExpr] = (),
enclosing_func_kwargs: dict[str, spux.SympyExpr] = MappingProxyType({}),
enclosing_symbols: frozenset[spux.Symbol] = frozenset(),
) -> typ.Self:
return ParamsFlow(
func_args=self.func_args + list(enclosing_func_args),
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
symbols=self.symbols | enclosing_symbols,
)

View File

@ -0,0 +1,3 @@
import typing as typ
ValueFlow: typ.TypeAlias = typ.Any

View File

@ -116,7 +116,7 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
size=input_sockets['Size'], size=input_sockets['Size'],
name=props['sim_node_name'], name=props['sim_node_name'],
interval_space=(1, 1, 1), interval_space=(1, 1, 1),
freqs=input_sockets['Freqs'].realize().values, freqs=input_sockets['Freqs'].realize_array,
normal_dir='+' if input_sockets['Direction'] else '-', normal_dir='+' if input_sockets['Direction'] else '-',
) )
@ -142,11 +142,11 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
@events.on_value_changed( @events.on_value_changed(
# Trigger # Trigger
socket_name={'Center', 'Size'}, socket_name={'Center', 'Size', 'Direction'},
run_on_init=True, run_on_init=True,
# Loaded # Loaded
managed_objs={'mesh', 'modifier'}, managed_objs={'mesh', 'modifier'},
input_sockets={'Center', 'Size'}, input_sockets={'Center', 'Size', 'Direction'},
unit_systems={'BlenderUnits': ct.UNITS_BLENDER}, unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
scale_input_sockets={ scale_input_sockets={
'Center': 'BlenderUnits', 'Center': 'BlenderUnits',
@ -167,6 +167,7 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
'unit_system': unit_systems['BlenderUnits'], 'unit_system': unit_systems['BlenderUnits'],
'inputs': { 'inputs': {
'Size': input_sockets['Size'], 'Size': input_sockets['Size'],
'Direction': input_sockets['Direction'],
}, },
}, },
) )