refactor: Factored out flow_kinds.py for clarity.
parent
b221f9ae2b
commit
3a53e4ce46
BIN
src/blender_maxwell/assets/internal/monitor/_monitor_power_flux.blend (Stored with Git LFS)
BIN
src/blender_maxwell/assets/internal/monitor/_monitor_power_flux.blend (Stored with Git LFS)
Binary file not shown.
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,17 @@
|
||||||
|
from .array import ArrayFlow
|
||||||
|
from .capabiltiies import CapabilitiesFlow
|
||||||
|
from .flow_kinds import FlowKind
|
||||||
|
from .lazy_array_range import LazyArrayRange
|
||||||
|
from .lazy_value_func import LazyValueFunc
|
||||||
|
from .params import Params
|
||||||
|
from .value import ValueFlow
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'ArrayFlow',
|
||||||
|
'CapabilitiesFlow',
|
||||||
|
'FlowKind',
|
||||||
|
'LazyArrayRange',
|
||||||
|
'LazyValueFunc',
|
||||||
|
'Params',
|
||||||
|
'ValueFlow',
|
||||||
|
]
|
|
@ -0,0 +1,96 @@
|
||||||
|
import dataclasses
|
||||||
|
import functools
|
||||||
|
import typing as typ
|
||||||
|
|
||||||
|
import jaxtyping as jtyp
|
||||||
|
import numpy as np
|
||||||
|
import sympy.physics.units as spu
|
||||||
|
|
||||||
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
|
from blender_maxwell.utils import logger
|
||||||
|
|
||||||
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||||
|
class ArrayFlow:
|
||||||
|
"""A simple, flat array of values with an optionally-attached unit.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
values: An ND array-like object of arbitrary numerical type.
|
||||||
|
unit: A `sympy` unit.
|
||||||
|
None if unitless.
|
||||||
|
"""
|
||||||
|
|
||||||
|
values: jtyp.Shaped[jtyp.Array, '...']
|
||||||
|
unit: spux.Unit | None = None
|
||||||
|
|
||||||
|
is_sorted: bool = False
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.values)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def mathtype(self) -> spux.MathType:
|
||||||
|
return spux.MathType.from_pytype(type(self.values.item(0)))
|
||||||
|
|
||||||
|
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
|
||||||
|
"""Find the index of the value that is closest to the given value.
|
||||||
|
|
||||||
|
Units are taken into account; the given value will be scaled to the internal unit before direct use.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
require_sorted: Require that `self.values` be sorted, so that use of the faster binary-search algorithm is guaranteed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The index of `self.values` that is closest to the value `value`.
|
||||||
|
"""
|
||||||
|
if not require_sorted:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# Scale Given Value to Internal Unit
|
||||||
|
scaled_value = spux.sympy_to_python(spux.scale_to_unit(value, self.unit))
|
||||||
|
|
||||||
|
# BinSearch for "Right IDX"
|
||||||
|
## >>> self.values[right_idx] > scaled_value
|
||||||
|
## >>> self.values[right_idx - 1] < scaled_value
|
||||||
|
right_idx = np.searchsorted(self.values, scaled_value, side='left')
|
||||||
|
|
||||||
|
# Case: Right IDX is Boundary
|
||||||
|
if right_idx == 0:
|
||||||
|
return right_idx
|
||||||
|
if right_idx == len(self.values):
|
||||||
|
return right_idx - 1
|
||||||
|
|
||||||
|
# Find Closest of [Right IDX - 1, Right IDX]
|
||||||
|
left_val = self.values[right_idx - 1]
|
||||||
|
right_val = self.values[right_idx]
|
||||||
|
|
||||||
|
if (scaled_value - left_val) <= (right_val - scaled_value):
|
||||||
|
return right_idx - 1
|
||||||
|
|
||||||
|
return right_idx
|
||||||
|
|
||||||
|
def correct_unit(self, corrected_unit: spu.Quantity) -> typ.Self:
|
||||||
|
if self.unit is not None:
|
||||||
|
return ArrayFlow(
|
||||||
|
values=self.values, unit=corrected_unit, is_sorted=self.is_sorted
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
def rescale_to_unit(self, unit: spu.Quantity) -> typ.Self:
|
||||||
|
if self.unit is not None:
|
||||||
|
return ArrayFlow(
|
||||||
|
values=float(spux.scaling_factor(self.unit, unit)) * self.values,
|
||||||
|
unit=unit,
|
||||||
|
is_sorted=self.is_sorted, ## TODO: Can we really say that?
|
||||||
|
)
|
||||||
|
## TODO: Is this scaling numerically stable?
|
||||||
|
|
||||||
|
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
def rescale_to_unit_system(self, unit: spu.Quantity) -> typ.Self:
|
||||||
|
raise NotImplementedError
|
|
@ -0,0 +1,39 @@
|
||||||
|
import dataclasses
|
||||||
|
import typing as typ
|
||||||
|
|
||||||
|
from ..socket_types import SocketType
|
||||||
|
from .flow_kinds import FlowKind
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||||
|
class CapabilitiesFlow:
|
||||||
|
socket_type: SocketType
|
||||||
|
active_kind: FlowKind
|
||||||
|
|
||||||
|
is_universal: bool = False
|
||||||
|
|
||||||
|
# == Constraint
|
||||||
|
must_match: dict[str, typ.Any] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
|
# ∀b (b ∈ A) Constraint
|
||||||
|
## A: allow_any
|
||||||
|
## b∈B: present_any
|
||||||
|
allow_any: set[typ.Any] = dataclasses.field(default_factory=set)
|
||||||
|
present_any: set[typ.Any] = dataclasses.field(default_factory=set)
|
||||||
|
|
||||||
|
def is_compatible_with(self, other: typ.Self) -> bool:
|
||||||
|
return other.is_universal or (
|
||||||
|
self.socket_type == other.socket_type
|
||||||
|
and self.active_kind == other.active_kind
|
||||||
|
# == Constraint
|
||||||
|
and all(
|
||||||
|
name in other.must_match
|
||||||
|
and self.must_match[name] == other.must_match[name]
|
||||||
|
for name in self.must_match
|
||||||
|
)
|
||||||
|
# ∀b (b ∈ A) Constraint
|
||||||
|
and (
|
||||||
|
self.present_any & other.allow_any
|
||||||
|
or (not self.present_any and not self.allow_any)
|
||||||
|
)
|
||||||
|
)
|
|
@ -0,0 +1,68 @@
|
||||||
|
import enum
|
||||||
|
import typing as typ
|
||||||
|
|
||||||
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
|
from blender_maxwell.utils import logger
|
||||||
|
|
||||||
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FlowKind(enum.StrEnum):
|
||||||
|
"""Defines a kind of data that can flow between nodes.
|
||||||
|
|
||||||
|
Each node link can be thought to contain **multiple pipelines for data to flow along**.
|
||||||
|
Each pipeline is cached incrementally, and independently, of the others.
|
||||||
|
Thus, the same socket can easily support several kinds of related data flow at the same time.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
Capabilities: Describes a socket's linkeability with other sockets.
|
||||||
|
Links between sockets with incompatible capabilities will be rejected.
|
||||||
|
This doesn't need to be defined normally, as there is a default.
|
||||||
|
However, in some cases, defining it manually to control linkeability more granularly may be desirable.
|
||||||
|
Value: A generic object, which is "directly usable".
|
||||||
|
This should be chosen when a more specific flow kind doesn't apply.
|
||||||
|
Array: An object with dimensions, and possibly a unit.
|
||||||
|
Whenever a `Value` is defined, a single-element `list` will also be generated by default as `Array`
|
||||||
|
However, for any other array-like variants (or sockets that only represent array-like objects), `Array` should be defined manually.
|
||||||
|
LazyValueFunc: A composable function.
|
||||||
|
Can be used to represent computations for which all data is not yet known, or for which just-in-time compilation can drastically increase performance.
|
||||||
|
LazyArrayRange: An object that generates an `Array` from range information (start/stop/step/spacing).
|
||||||
|
This should be used instead of `Array` whenever possible.
|
||||||
|
Param: A dictionary providing particular parameters for a lazy value.
|
||||||
|
Info: An dictionary providing extra context about any aspect of flow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
Capabilities = enum.auto()
|
||||||
|
|
||||||
|
# Values
|
||||||
|
Value = enum.auto()
|
||||||
|
Array = enum.auto()
|
||||||
|
|
||||||
|
# Lazy
|
||||||
|
LazyValueFunc = enum.auto()
|
||||||
|
LazyArrayRange = enum.auto()
|
||||||
|
|
||||||
|
# Auxiliary
|
||||||
|
Params = enum.auto()
|
||||||
|
Info = enum.auto()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def scale_to_unit_system(
|
||||||
|
cls,
|
||||||
|
kind: typ.Self,
|
||||||
|
value,
|
||||||
|
unit_system: spux.UnitSystem,
|
||||||
|
):
|
||||||
|
if kind == cls.Value:
|
||||||
|
return spux.scale_to_unit_system(
|
||||||
|
value,
|
||||||
|
unit_system,
|
||||||
|
)
|
||||||
|
if kind == cls.LazyArrayRange:
|
||||||
|
return value.rescale_to_unit_system(unit_system)
|
||||||
|
|
||||||
|
if kind == cls.Params:
|
||||||
|
return value.rescale_to_unit_system(unit_system)
|
||||||
|
|
||||||
|
msg = 'Tried to scale unknown kind'
|
||||||
|
raise ValueError(msg)
|
|
@ -0,0 +1,182 @@
|
||||||
|
import dataclasses
|
||||||
|
import functools
|
||||||
|
import typing as typ
|
||||||
|
|
||||||
|
import jax
|
||||||
|
|
||||||
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
|
from blender_maxwell.utils import logger
|
||||||
|
|
||||||
|
from .array import ArrayFlow
|
||||||
|
from .lazy_array_range import LazyArrayRangeFlow
|
||||||
|
|
||||||
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||||
|
class InfoFlow:
|
||||||
|
# Dimension Information
|
||||||
|
dim_names: list[str] = dataclasses.field(default_factory=list)
|
||||||
|
dim_idx: dict[str, ArrayFlow | LazyArrayRangeFlow] = dataclasses.field(
|
||||||
|
default_factory=dict
|
||||||
|
) ## TODO: Rename to dim_idxs
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def dim_lens(self) -> dict[str, int]:
|
||||||
|
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def dim_mathtypes(self) -> dict[str, spux.MathType]:
|
||||||
|
return {
|
||||||
|
dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def dim_units(self) -> dict[str, spux.Unit]:
|
||||||
|
return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()}
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def dim_physical_types(self) -> dict[str, spux.PhysicalType]:
|
||||||
|
return {
|
||||||
|
dim_name: spux.PhysicalType.from_unit(dim_idx.unit)
|
||||||
|
for dim_name, dim_idx in self.dim_idx.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def dim_idx_arrays(self) -> list[jax.Array]:
|
||||||
|
return [
|
||||||
|
dim_idx.realize().values
|
||||||
|
if isinstance(dim_idx, LazyArrayRangeFlow)
|
||||||
|
else dim_idx.values
|
||||||
|
for dim_idx in self.dim_idx.values()
|
||||||
|
]
|
||||||
|
|
||||||
|
# Output Information
|
||||||
|
## TODO: Add PhysicalType
|
||||||
|
output_name: str = dataclasses.field(default_factory=list)
|
||||||
|
output_shape: tuple[int, ...] | None = dataclasses.field(default=None)
|
||||||
|
output_mathtype: spux.MathType = dataclasses.field()
|
||||||
|
output_unit: spux.Unit | None = dataclasses.field()
|
||||||
|
|
||||||
|
# Pinned Dimension Information
|
||||||
|
## TODO: Add PhysicalType
|
||||||
|
pinned_dim_names: list[str] = dataclasses.field(default_factory=list)
|
||||||
|
pinned_dim_values: dict[str, float | complex] = dataclasses.field(
|
||||||
|
default_factory=dict
|
||||||
|
)
|
||||||
|
pinned_dim_mathtypes: dict[str, spux.MathType] = dataclasses.field(
|
||||||
|
default_factory=dict
|
||||||
|
)
|
||||||
|
pinned_dim_units: dict[str, spux.Unit] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Methods
|
||||||
|
####################
|
||||||
|
def rescale_dim_idxs(self, new_dim_idxs: dict[str, LazyArrayRangeFlow]) -> typ.Self:
|
||||||
|
return InfoFlow(
|
||||||
|
# Dimensions
|
||||||
|
dim_names=self.dim_names,
|
||||||
|
dim_idx={
|
||||||
|
_dim_name: new_dim_idxs.get(_dim_name, dim_idx)
|
||||||
|
for _dim_name, dim_idx in self.dim_idx.items()
|
||||||
|
},
|
||||||
|
# Outputs
|
||||||
|
output_name=self.output_name,
|
||||||
|
output_shape=self.output_shape,
|
||||||
|
output_mathtype=self.output_mathtype,
|
||||||
|
output_unit=self.output_unit,
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_dimension(self, dim_name: str) -> typ.Self:
|
||||||
|
"""Delete a dimension."""
|
||||||
|
return InfoFlow(
|
||||||
|
# Dimensions
|
||||||
|
dim_names=[
|
||||||
|
_dim_name for _dim_name in self.dim_names if _dim_name != dim_name
|
||||||
|
],
|
||||||
|
dim_idx={
|
||||||
|
_dim_name: dim_idx
|
||||||
|
for _dim_name, dim_idx in self.dim_idx.items()
|
||||||
|
if _dim_name != dim_name
|
||||||
|
},
|
||||||
|
# Outputs
|
||||||
|
output_name=self.output_name,
|
||||||
|
output_shape=self.output_shape,
|
||||||
|
output_mathtype=self.output_mathtype,
|
||||||
|
output_unit=self.output_unit,
|
||||||
|
)
|
||||||
|
|
||||||
|
def swap_dimensions(self, dim_0_name: str, dim_1_name: str) -> typ.Self:
|
||||||
|
"""Delete a dimension."""
|
||||||
|
|
||||||
|
# Compute Swapped Dimension Name List
|
||||||
|
def name_swapper(dim_name):
|
||||||
|
return (
|
||||||
|
dim_name
|
||||||
|
if dim_name not in [dim_0_name, dim_1_name]
|
||||||
|
else {dim_0_name: dim_1_name, dim_1_name: dim_0_name}[dim_name]
|
||||||
|
)
|
||||||
|
|
||||||
|
dim_names = [name_swapper(dim_name) for dim_name in self.dim_names]
|
||||||
|
|
||||||
|
# Compute Info
|
||||||
|
return InfoFlow(
|
||||||
|
# Dimensions
|
||||||
|
dim_names=dim_names,
|
||||||
|
dim_idx={dim_name: self.dim_idx[dim_name] for dim_name in dim_names},
|
||||||
|
# Outputs
|
||||||
|
output_name=self.output_name,
|
||||||
|
output_shape=self.output_shape,
|
||||||
|
output_mathtype=self.output_mathtype,
|
||||||
|
output_unit=self.output_unit,
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_output_mathtype(self, output_mathtype: spux.MathType) -> typ.Self:
|
||||||
|
"""Set the MathType of a particular output name."""
|
||||||
|
return InfoFlow(
|
||||||
|
dim_names=self.dim_names,
|
||||||
|
dim_idx=self.dim_idx,
|
||||||
|
# Outputs
|
||||||
|
output_name=self.output_name,
|
||||||
|
output_shape=self.output_shape,
|
||||||
|
output_mathtype=output_mathtype,
|
||||||
|
output_unit=self.output_unit,
|
||||||
|
)
|
||||||
|
|
||||||
|
def collapse_output(
|
||||||
|
self,
|
||||||
|
collapsed_name: str,
|
||||||
|
collapsed_mathtype: spux.MathType,
|
||||||
|
collapsed_unit: spux.Unit,
|
||||||
|
) -> typ.Self:
|
||||||
|
return InfoFlow(
|
||||||
|
# Dimensions
|
||||||
|
dim_names=self.dim_names,
|
||||||
|
dim_idx=self.dim_idx,
|
||||||
|
output_name=collapsed_name,
|
||||||
|
output_shape=None,
|
||||||
|
output_mathtype=collapsed_mathtype,
|
||||||
|
output_unit=collapsed_unit,
|
||||||
|
)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def shift_last_input(self):
|
||||||
|
"""Shift the last input dimension to the output."""
|
||||||
|
return InfoFlow(
|
||||||
|
# Dimensions
|
||||||
|
dim_names=self.dim_names[:-1],
|
||||||
|
dim_idx={
|
||||||
|
dim_name: dim_idx
|
||||||
|
for dim_name, dim_idx in self.dim_idx.items()
|
||||||
|
if dim_name != self.dim_names[-1]
|
||||||
|
},
|
||||||
|
# Outputs
|
||||||
|
output_name=self.output_name,
|
||||||
|
output_shape=(
|
||||||
|
(self.dim_lens[self.dim_names[-1]],)
|
||||||
|
if self.output_shape is None
|
||||||
|
else (self.dim_lens[self.dim_names[-1]], *self.output_shape)
|
||||||
|
),
|
||||||
|
output_mathtype=self.output_mathtype,
|
||||||
|
output_unit=self.output_unit,
|
||||||
|
)
|
|
@ -0,0 +1,347 @@
|
||||||
|
import dataclasses
|
||||||
|
import functools
|
||||||
|
import typing as typ
|
||||||
|
from types import MappingProxyType
|
||||||
|
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jaxtyping as jtyp
|
||||||
|
import sympy as sp
|
||||||
|
|
||||||
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
|
|
||||||
|
from .array import ArrayFlow
|
||||||
|
from .flow_kinds import FlowKind
|
||||||
|
from .lazy_value_func import LazyValueFuncFlow
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||||
|
class LazyArrayRangeFlow:
|
||||||
|
r"""Represents a linearly/logarithmically spaced array using symbolic boundary expressions, with support for units and lazy evaluation.
|
||||||
|
|
||||||
|
# Advantages
|
||||||
|
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
|
||||||
|
|
||||||
|
## Memory
|
||||||
|
`ArrayFlow` generally has a memory scaling of $O(n)$.
|
||||||
|
Naturally, `LazyArrayRangeFlow` is always constant, since only the boundaries and steps are stored.
|
||||||
|
|
||||||
|
## Symbolic
|
||||||
|
Both boundary points are symbolic expressions, within which pre-defined `sp.Symbol`s can participate in a constrained manner (ex. an integer symbol).
|
||||||
|
|
||||||
|
One need not know the value of the symbols immediately - such decisions can be deferred until later in the computational flow.
|
||||||
|
|
||||||
|
## Performant Unit-Aware Operations
|
||||||
|
While `ArrayFlow`s are also unit-aware, the time-cost of _any_ unit-scaling operation scales with $O(n)$.
|
||||||
|
`LazyArrayRangeFlow`, by contrast, scales as $O(1)$.
|
||||||
|
|
||||||
|
As a result, more complicated operations (like symbolic or unit-based) that might be difficult to perform interactively in real-time on an `ArrayFlow` will work perfectly with this object, even with added complexity
|
||||||
|
|
||||||
|
## High-Performance Composition and Gradiant
|
||||||
|
With `self.as_func`, a `jax` function is produced that generates the array according to the symbolic `start`, `stop` and `steps`.
|
||||||
|
There are two nice things about this:
|
||||||
|
|
||||||
|
- **Gradient**: The gradient of the output array, with respect to any symbols used to define the input bounds, can easily be found using `jax.grad` over `self.as_func`.
|
||||||
|
- **JIT**: When `self.as_func` is composed with other `jax` functions, and `jax.jit` is run to optimize the entire thing, the "cost of array generation" _will often be optimized away significantly or entirely_.
|
||||||
|
|
||||||
|
Thus, as part of larger computations, the performance properties of `LazyArrayRangeFlow` is extremely favorable.
|
||||||
|
|
||||||
|
## Numerical Properties
|
||||||
|
Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
start: An expression generating a scalar, unitless, complex value for the array's lower bound.
|
||||||
|
_Integer, rational, and real values are also supported._
|
||||||
|
stop: An expression generating a scalar, unitless, complex value for the array's upper bound.
|
||||||
|
_Integer, rational, and real values are also supported._
|
||||||
|
steps: The amount of steps (**inclusive**) to generate from `start` to `stop`.
|
||||||
|
scaling: The method of distributing `step` values between the two endpoints.
|
||||||
|
Generally, the linear default is sufficient.
|
||||||
|
|
||||||
|
unit: The unit of the generated array values
|
||||||
|
|
||||||
|
symbols: Set of variables from which `start` and/or `stop` are determined.
|
||||||
|
"""
|
||||||
|
|
||||||
|
start: spux.ScalarUnitlessComplexExpr
|
||||||
|
stop: spux.ScalarUnitlessComplexExpr
|
||||||
|
steps: int
|
||||||
|
scaling: typ.Literal['lin', 'geom', 'log'] = 'lin'
|
||||||
|
|
||||||
|
unit: spux.Unit | None = None
|
||||||
|
|
||||||
|
symbols: frozenset[spux.IntSymbol] = frozenset()
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def sorted_symbols(self) -> list[sp.Symbol]:
|
||||||
|
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||||
|
|
||||||
|
The order is guaranteed to be **deterministic**.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
All symbols valid for use in the expression.
|
||||||
|
"""
|
||||||
|
return sorted(self.symbols, key=lambda sym: sym.name)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def mathtype(self) -> spux.MathType:
|
||||||
|
# Get Start Mathtype
|
||||||
|
if isinstance(self.start, spux.SympyType):
|
||||||
|
start_mathtype = spux.MathType.from_expr(self.start)
|
||||||
|
else:
|
||||||
|
start_mathtype = spux.MathType.from_pytype(type(self.start))
|
||||||
|
|
||||||
|
# Get Stop Mathtype
|
||||||
|
if isinstance(self.stop, spux.SympyType):
|
||||||
|
stop_mathtype = spux.MathType.from_expr(self.stop)
|
||||||
|
else:
|
||||||
|
stop_mathtype = spux.MathType.from_pytype(type(self.stop))
|
||||||
|
|
||||||
|
# Check Equal
|
||||||
|
if start_mathtype != stop_mathtype:
|
||||||
|
return spux.MathType.combine(start_mathtype, stop_mathtype)
|
||||||
|
|
||||||
|
return start_mathtype
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.steps
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Units
|
||||||
|
####################
|
||||||
|
def correct_unit(self, corrected_unit: spux.Unit) -> typ.Self:
|
||||||
|
"""Replaces the unit without rescaling the unitless bounds.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
corrected_unit: The unit to replace the current unit with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new `LazyArrayRangeFlow` with replaced unit.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||||
|
"""
|
||||||
|
if self.unit is not None:
|
||||||
|
return LazyArrayRangeFlow(
|
||||||
|
start=self.start,
|
||||||
|
stop=self.stop,
|
||||||
|
steps=self.steps,
|
||||||
|
scaling=self.scaling,
|
||||||
|
unit=corrected_unit,
|
||||||
|
symbols=self.symbols,
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
def rescale_to_unit(self, unit: spux.Unit) -> typ.Self:
|
||||||
|
"""Replaces the unit, **with** rescaling of the bounds.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
unit: The unit to convert the bounds to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new `LazyArrayRangeFlow` with replaced unit.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||||
|
"""
|
||||||
|
if self.unit is not None:
|
||||||
|
return LazyArrayRangeFlow(
|
||||||
|
start=spux.scale_to_unit(self.start * self.unit, unit),
|
||||||
|
stop=spux.scale_to_unit(self.stop * self.unit, unit),
|
||||||
|
steps=self.steps,
|
||||||
|
scaling=self.scaling,
|
||||||
|
unit=unit,
|
||||||
|
symbols=self.symbols,
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
|
||||||
|
"""Replaces the units, **with** rescaling of the bounds.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
unit: The unit to convert the bounds to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new `LazyArrayRangeFlow` with replaced unit.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||||
|
"""
|
||||||
|
if self.unit is not None:
|
||||||
|
return LazyArrayRangeFlow(
|
||||||
|
start=spux.strip_unit_system(
|
||||||
|
spux.convert_to_unit_system(self.start * self.unit, unit_system),
|
||||||
|
unit_system,
|
||||||
|
),
|
||||||
|
stop=spux.strip_unit_system(
|
||||||
|
spux.convert_to_unit_system(self.start * self.unit, unit_system),
|
||||||
|
unit_system,
|
||||||
|
),
|
||||||
|
steps=self.steps,
|
||||||
|
scaling=self.scaling,
|
||||||
|
unit=unit_system[spux.PhysicalType.from_unit(self.unit)],
|
||||||
|
symbols=self.symbols,
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = (
|
||||||
|
f'Tried to rescale unitless LazyDataValueRange to unit system {unit_system}'
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Bound Operations
|
||||||
|
####################
|
||||||
|
def rescale_bounds(
|
||||||
|
self,
|
||||||
|
rescale_func: typ.Callable[
|
||||||
|
[spux.ScalarUnitlessComplexExpr], spux.ScalarUnitlessComplexExpr
|
||||||
|
],
|
||||||
|
reverse: bool = False,
|
||||||
|
) -> typ.Self:
|
||||||
|
"""Apply a function to the bounds, effectively rescaling the represented array.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
**It is presumed that the bounds are scaled with the same factor**.
|
||||||
|
Breaking this presumption may have unexpected results.
|
||||||
|
|
||||||
|
The scalar, unitless, complex-valuedness of the bounds must also be respected; additionally, new symbols must not be introduced.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
scaler: The function that scales each bound.
|
||||||
|
reverse: Whether to reverse the bounds after running the `scaler`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A rescaled `LazyArrayRangeFlow`.
|
||||||
|
"""
|
||||||
|
return LazyArrayRangeFlow(
|
||||||
|
start=rescale_func(self.start if not reverse else self.stop),
|
||||||
|
stop=rescale_func(self.stop if not reverse else self.start),
|
||||||
|
steps=self.steps,
|
||||||
|
scaling=self.scaling,
|
||||||
|
unit=self.unit,
|
||||||
|
symbols=self.symbols,
|
||||||
|
)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Lazy Representation
|
||||||
|
####################
|
||||||
|
@functools.cached_property
|
||||||
|
def array_generator(
|
||||||
|
self,
|
||||||
|
) -> typ.Callable[
|
||||||
|
[int | float | complex, int | float | complex, int],
|
||||||
|
jtyp.Inexact[jtyp.Array, ' steps'],
|
||||||
|
]:
|
||||||
|
"""Compute the correct `jnp.*space` array generator, where `*` is one of the supported scaling methods.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `jax` function that takes a valid `start`, `stop`, and `steps`, and returns a 1D `jax` array.
|
||||||
|
"""
|
||||||
|
jnp_nspace = {
|
||||||
|
'lin': jnp.linspace,
|
||||||
|
'geom': jnp.geomspace,
|
||||||
|
'log': jnp.logspace,
|
||||||
|
}.get(self.scaling)
|
||||||
|
if jnp_nspace is None:
|
||||||
|
msg = f'ArrayFlow scaling method {self.scaling} is unsupported'
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
|
return jnp_nspace
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def as_func(
|
||||||
|
self,
|
||||||
|
) -> typ.Callable[[int | float | complex, ...], jtyp.Inexact[jtyp.Array, ' steps']]:
|
||||||
|
"""Create a function that can compute the non-lazy output array as a function of the symbols in the expressions for `start` and `stop`.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `LazyValueFuncFlow` that, given the input symbols defined in `self.symbols`,
|
||||||
|
"""
|
||||||
|
# Compile JAX Functions for Start/End Expressions
|
||||||
|
## FYI, JAX-in-JAX works perfectly fine.
|
||||||
|
start_jax = sp.lambdify(self.symbols, self.start, 'jax')
|
||||||
|
stop_jax = sp.lambdify(self.symbols, self.stop, 'jax')
|
||||||
|
|
||||||
|
# Compile ArrayGen Function
|
||||||
|
def gen_array(
|
||||||
|
*args: list[int | float | complex],
|
||||||
|
) -> jtyp.Inexact[jtyp.Array, ' steps']:
|
||||||
|
return self.array_generator(start_jax(*args), stop_jax(*args), self.steps)
|
||||||
|
|
||||||
|
# Return ArrayGen Function
|
||||||
|
return gen_array
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def as_lazy_value_func(self) -> LazyValueFuncFlow:
|
||||||
|
"""Creates a `LazyValueFuncFlow` using the output of `self.as_func`.
|
||||||
|
|
||||||
|
This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
The the function enclosed in the `LazyValueFuncFlow` is identical to the one returned by `self.as_func`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `LazyValueFuncFlow` containing `self.as_func`, as well as appropriate supporting settings.
|
||||||
|
"""
|
||||||
|
return LazyValueFuncFlow(
|
||||||
|
func=self.as_func,
|
||||||
|
func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols],
|
||||||
|
supports_jax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Realization
|
||||||
|
####################
|
||||||
|
def realize(
|
||||||
|
self,
|
||||||
|
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
|
||||||
|
kind: typ.Literal[FlowKind.Array, FlowKind.LazyValueFunc] = FlowKind.Array,
|
||||||
|
) -> ArrayFlow | LazyValueFuncFlow:
|
||||||
|
"""Apply a function to the bounds, effectively rescaling the represented array.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
**It is presumed that the bounds are scaled with the same factor**.
|
||||||
|
Breaking this presumption may have unexpected results.
|
||||||
|
|
||||||
|
The scalar, unitless, complex-valuedness of the bounds must also be respected; additionally, new symbols must not be introduced.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
scaler: The function that scales each bound.
|
||||||
|
reverse: Whether to reverse the bounds after running the `scaler`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A rescaled `LazyArrayRangeFlow`.
|
||||||
|
"""
|
||||||
|
if not set(self.symbols).issubset(set(symbol_values.keys())):
|
||||||
|
msg = f'Provided symbols ({set(symbol_values.keys())}) do not provide values for all expression symbols ({self.symbols}) that may be found in the boundary expressions (start={self.start}, end={self.end})'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Realize Symbols
|
||||||
|
realized_start = spux.sympy_to_python(
|
||||||
|
self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols})
|
||||||
|
)
|
||||||
|
realized_stop = spux.sympy_to_python(
|
||||||
|
self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols})
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return Linspace / Logspace
|
||||||
|
def gen_array() -> jtyp.Inexact[jtyp.Array, ' steps']:
|
||||||
|
return self.array_generator(realized_start, realized_stop, self.steps)
|
||||||
|
|
||||||
|
if kind == FlowKind.Array:
|
||||||
|
return ArrayFlow(values=gen_array(), unit=self.unit, is_sorted=True)
|
||||||
|
if kind == FlowKind.LazyValueFunc:
|
||||||
|
return LazyValueFuncFlow(func=gen_array, supports_jax=True)
|
||||||
|
|
||||||
|
msg = f'Invalid kind: {kind}'
|
||||||
|
raise TypeError(msg)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def realize_array(self) -> ArrayFlow:
|
||||||
|
return self.realize()
|
|
@ -0,0 +1,194 @@
|
||||||
|
import dataclasses
|
||||||
|
import functools
|
||||||
|
import typing as typ
|
||||||
|
from types import MappingProxyType
|
||||||
|
|
||||||
|
import jax
|
||||||
|
|
||||||
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
|
from blender_maxwell.utils import logger
|
||||||
|
|
||||||
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||||
|
class LazyValueFuncFlow:
|
||||||
|
r"""Wraps a composable function, providing useful information and operations.
|
||||||
|
|
||||||
|
# Data Flow as Function Composition
|
||||||
|
When using nodes to do math, it can be a good idea to express a **flow of data as the composition of functions**.
|
||||||
|
|
||||||
|
Each node creates a new function, which uses the still-unknown (aka. **lazy**) output of the previous function to plan some calculations.
|
||||||
|
Some new arguments may also be added, of course.
|
||||||
|
|
||||||
|
## Root Function
|
||||||
|
Of course, one needs to select a "bottom" function, which has no previous function as input.
|
||||||
|
Thus, the first step is to define this **root function**:
|
||||||
|
|
||||||
|
$$
|
||||||
|
f_0:\ \ \ \ \biggl(
|
||||||
|
\underbrace{a_1, a_2, ..., a_p}_{\texttt{args}},\
|
||||||
|
\underbrace{
|
||||||
|
\begin{bmatrix} k_1 \\ v_1\end{bmatrix},
|
||||||
|
\begin{bmatrix} k_2 \\ v_2\end{bmatrix},
|
||||||
|
...,
|
||||||
|
\begin{bmatrix} k_q \\ v_q\end{bmatrix}
|
||||||
|
}_{\texttt{kwargs}}
|
||||||
|
\biggr) \to \text{output}_0
|
||||||
|
$$
|
||||||
|
|
||||||
|
We'll express this simple snippet like so:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Presume 'A0', 'KV0' contain only the args/kwargs for f_0
|
||||||
|
## 'A0', 'KV0' are of length 'p' and 'q'
|
||||||
|
def f_0(*args, **kwargs): ...
|
||||||
|
|
||||||
|
lazy_value_func_0 = LazyValueFuncFlow(
|
||||||
|
func=f_0,
|
||||||
|
func_args=[(a_i, type(a_i)) for a_i in A0],
|
||||||
|
func_kwargs={k: v for k,v in KV0},
|
||||||
|
)
|
||||||
|
output_0 = lazy_value_func.func(*A0_computed, **KV0_computed)
|
||||||
|
```
|
||||||
|
|
||||||
|
So far so good.
|
||||||
|
But of course, nothing interesting has really happened yet.
|
||||||
|
|
||||||
|
## Composing Functions
|
||||||
|
The key thing is the next step: The function that uses the result of $f_0$!
|
||||||
|
|
||||||
|
$$
|
||||||
|
f_1:\ \ \ \ \biggl(
|
||||||
|
f_0(...),\ \
|
||||||
|
\underbrace{\{a_i\}_p^{p+r}}_{\texttt{args[p:]}},\
|
||||||
|
\underbrace{\biggl\{
|
||||||
|
\begin{bmatrix} k_i \\ v_i\end{bmatrix}
|
||||||
|
\biggr\}_q^{q+s}}_{\texttt{kwargs[p:]}}
|
||||||
|
\biggr) \to \text{output}_1
|
||||||
|
$$
|
||||||
|
|
||||||
|
Notice that _$f_1$ needs the arguments of both $f_0$ and $f_1$_.
|
||||||
|
Tracking arguments is already getting out of hand; we already have to use `...` to keep it readeable!
|
||||||
|
|
||||||
|
But doing so with `LazyValueFunc` is not so complex:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Presume 'A1', 'K1' contain only the args/kwarg names for f_1
|
||||||
|
## 'A1', 'KV1' are therefore of length 'r' and 's'
|
||||||
|
def f_1(output_0, *args, **kwargs): ...
|
||||||
|
|
||||||
|
lazy_value_func_1 = lazy_value_func_0.compose_within(
|
||||||
|
enclosing_func=f_1,
|
||||||
|
enclosing_func_args=[(a_i, type(a_i)) for a_i in A1],
|
||||||
|
enclosing_func_kwargs={k: type(v) for k,v in K1},
|
||||||
|
)
|
||||||
|
|
||||||
|
A_computed = A0_computed + A1_computed
|
||||||
|
KW_computed = KV0_computed + KV1_computed
|
||||||
|
output_1 = lazy_value_func_1.func(*A_computed, **KW_computed)
|
||||||
|
```
|
||||||
|
|
||||||
|
We only need the arguments to $f_1$, and `LazyValueFunc` figures out how to make one function with enough arguments to call both.
|
||||||
|
|
||||||
|
## Isn't Laying Functions Slow/Hard?
|
||||||
|
Imagine that each function represents the action of a node, each of which performs expensive calculations on huge `numpy` arrays (**as one does when processing electromagnetic field data**).
|
||||||
|
At the end, a node might run the entire procedure with all arguments:
|
||||||
|
|
||||||
|
```python
|
||||||
|
output_n = lazy_value_func_n.func(*A_all, **KW_all)
|
||||||
|
```
|
||||||
|
|
||||||
|
It's rough: Most non-trivial pipelines drown in the time/memory overhead of incremental `numpy` operations - individually fast, but collectively iffy.
|
||||||
|
|
||||||
|
The killer feature of `LazyValueFuncFlow` is a sprinkle of black magic:
|
||||||
|
|
||||||
|
```python
|
||||||
|
func_n_jax = lazy_value_func_n.func_jax
|
||||||
|
output_n = func_n_jax(*A_all, **KW_all) ## Runs on your GPU
|
||||||
|
```
|
||||||
|
|
||||||
|
What happened was, **the entire pipeline** was compiled and optimized for high performance on not just your CPU, _but also (possibly) your GPU_.
|
||||||
|
All the layered function calls and inefficient incremental processing is **transformed into a high-performance program**.
|
||||||
|
|
||||||
|
Thank `jax` - specifically, `jax.jit` (https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), which internally enables this magic with a single function call.
|
||||||
|
|
||||||
|
## Other Considerations
|
||||||
|
**Auto-Differentiation**: Incredibly, `jax.jit` isn't the killer feature of `jax`. The function that comes out of `LazyValueFuncFlow` can also be differentiated with `jax.grad` (read: high-performance Jacobians for optimizing input parameters).
|
||||||
|
|
||||||
|
Though designed for machine learning, there's no reason other fields can't enjoy their inventions!
|
||||||
|
|
||||||
|
**Impact of Independent Caching**: JIT'ing can be slow.
|
||||||
|
That's why `LazyValueFuncFlow` has its own `FlowKind` "lane", which means that **only changes to the processing procedures will cause recompilation**.
|
||||||
|
|
||||||
|
Generally, adjustable values that affect the output will flow via the `Param` "lane", which has its own incremental caching, and only meets the compiled function when it's "plugged in" for final evaluation.
|
||||||
|
The effect is a feeling of snappiness and interactivity, even as the volume of data grows.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
func: The function that the object encapsulates.
|
||||||
|
bound_args: Arguments that will be packaged into function, which can't be later modifier.
|
||||||
|
func_kwargs: Arguments to be specified by the user at the time of use.
|
||||||
|
supports_jax: Whether the contained `self.function` can be compiled with JAX's JIT compiler.
|
||||||
|
"""
|
||||||
|
|
||||||
|
func: LazyFunction
|
||||||
|
func_args: list[spux.MathType | spux.PhysicalType] = dataclasses.field(
|
||||||
|
default_factory=list
|
||||||
|
)
|
||||||
|
func_kwargs: dict[str, spux.MathType | spux.PhysicalType] = dataclasses.field(
|
||||||
|
default_factory=dict
|
||||||
|
)
|
||||||
|
supports_jax: bool = False
|
||||||
|
|
||||||
|
# Merging
|
||||||
|
def __or__(
|
||||||
|
self,
|
||||||
|
other: typ.Self,
|
||||||
|
):
|
||||||
|
return LazyValueFuncFlow(
|
||||||
|
func=lambda *args, **kwargs: (
|
||||||
|
self.func(
|
||||||
|
*list(args[: len(self.func_args)]),
|
||||||
|
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
|
||||||
|
),
|
||||||
|
other.func(
|
||||||
|
*list(args[len(self.func_args) :]),
|
||||||
|
**{k: v for k, v in kwargs.items() if k in other.func_kwargs},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
func_args=self.func_args + other.func_args,
|
||||||
|
func_kwargs=self.func_kwargs | other.func_kwargs,
|
||||||
|
supports_jax=self.supports_jax and other.supports_jax,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Composition
|
||||||
|
def compose_within(
|
||||||
|
self,
|
||||||
|
enclosing_func: LazyFunction,
|
||||||
|
enclosing_func_args: list[type] = (),
|
||||||
|
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}),
|
||||||
|
supports_jax: bool = False,
|
||||||
|
) -> typ.Self:
|
||||||
|
return LazyValueFuncFlow(
|
||||||
|
func=lambda *args, **kwargs: enclosing_func(
|
||||||
|
self.func(
|
||||||
|
*list(args[: len(self.func_args)]),
|
||||||
|
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
|
||||||
|
),
|
||||||
|
*args[len(self.func_args) :],
|
||||||
|
**{k: v for k, v in kwargs.items() if k not in self.func_kwargs},
|
||||||
|
),
|
||||||
|
func_args=self.func_args + list(enclosing_func_args),
|
||||||
|
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
|
||||||
|
supports_jax=self.supports_jax and supports_jax,
|
||||||
|
)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def func_jax(self) -> LazyFunction:
|
||||||
|
if self.supports_jax:
|
||||||
|
return jax.jit(self.func)
|
||||||
|
|
||||||
|
msg = 'Can\'t express LazyValueFuncFlow as JAX function (using jax.jit), since "self.supports_jax" is False'
|
||||||
|
raise ValueError(msg)
|
|
@ -0,0 +1,95 @@
|
||||||
|
import dataclasses
|
||||||
|
import functools
|
||||||
|
import typing as typ
|
||||||
|
from types import MappingProxyType
|
||||||
|
|
||||||
|
import sympy as sp
|
||||||
|
|
||||||
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
|
from blender_maxwell.utils import logger
|
||||||
|
|
||||||
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||||
|
class ParamsFlow:
|
||||||
|
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
|
||||||
|
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
|
symbols: frozenset[spux.Symbol] = frozenset()
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def sorted_symbols(self) -> list[sp.Symbol]:
|
||||||
|
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
All symbols valid for use in the expression.
|
||||||
|
"""
|
||||||
|
return sorted(self.symbols, key=lambda sym: sym.name)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Scaled Func Args
|
||||||
|
####################
|
||||||
|
def scaled_func_args(
|
||||||
|
self,
|
||||||
|
unit_system: spux.UnitSystem,
|
||||||
|
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
|
||||||
|
):
|
||||||
|
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
|
||||||
|
if not all(sym in self.symbols for sym in symbol_values):
|
||||||
|
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
return [
|
||||||
|
spux.scale_to_unit_system(arg, unit_system, use_jax_array=True)
|
||||||
|
if arg not in symbol_values
|
||||||
|
else symbol_values[arg]
|
||||||
|
for arg in self.func_args
|
||||||
|
]
|
||||||
|
|
||||||
|
def scaled_func_kwargs(
|
||||||
|
self,
|
||||||
|
unit_system: spux.UnitSystem,
|
||||||
|
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
|
||||||
|
):
|
||||||
|
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
|
||||||
|
if not all(sym in self.symbols for sym in symbol_values):
|
||||||
|
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
return {
|
||||||
|
arg_name: spux.convert_to_unit_system(arg, unit_system, use_jax_array=True)
|
||||||
|
if arg not in symbol_values
|
||||||
|
else symbol_values[arg]
|
||||||
|
for arg_name, arg in self.func_kwargs.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Operations
|
||||||
|
####################
|
||||||
|
def __or__(
|
||||||
|
self,
|
||||||
|
other: typ.Self,
|
||||||
|
):
|
||||||
|
"""Combine two function parameter lists, such that the LHS will be concatenated with the RHS.
|
||||||
|
|
||||||
|
Just like its neighbor in `LazyValueFunc`, this effectively combines two functions with unique parameters.
|
||||||
|
The next composed function will receive a tuple of two arrays, instead of just one, allowing binary operations to occur.
|
||||||
|
"""
|
||||||
|
return ParamsFlow(
|
||||||
|
func_args=self.func_args + other.func_args,
|
||||||
|
func_kwargs=self.func_kwargs | other.func_kwargs,
|
||||||
|
symbols=self.symbols | other.symbols,
|
||||||
|
)
|
||||||
|
|
||||||
|
def compose_within(
|
||||||
|
self,
|
||||||
|
enclosing_func_args: list[spux.SympyExpr] = (),
|
||||||
|
enclosing_func_kwargs: dict[str, spux.SympyExpr] = MappingProxyType({}),
|
||||||
|
enclosing_symbols: frozenset[spux.Symbol] = frozenset(),
|
||||||
|
) -> typ.Self:
|
||||||
|
return ParamsFlow(
|
||||||
|
func_args=self.func_args + list(enclosing_func_args),
|
||||||
|
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
|
||||||
|
symbols=self.symbols | enclosing_symbols,
|
||||||
|
)
|
|
@ -0,0 +1,3 @@
|
||||||
|
import typing as typ
|
||||||
|
|
||||||
|
ValueFlow: typ.TypeAlias = typ.Any
|
|
@ -116,7 +116,7 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
|
||||||
size=input_sockets['Size'],
|
size=input_sockets['Size'],
|
||||||
name=props['sim_node_name'],
|
name=props['sim_node_name'],
|
||||||
interval_space=(1, 1, 1),
|
interval_space=(1, 1, 1),
|
||||||
freqs=input_sockets['Freqs'].realize().values,
|
freqs=input_sockets['Freqs'].realize_array,
|
||||||
normal_dir='+' if input_sockets['Direction'] else '-',
|
normal_dir='+' if input_sockets['Direction'] else '-',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -142,11 +142,11 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
|
||||||
|
|
||||||
@events.on_value_changed(
|
@events.on_value_changed(
|
||||||
# Trigger
|
# Trigger
|
||||||
socket_name={'Center', 'Size'},
|
socket_name={'Center', 'Size', 'Direction'},
|
||||||
run_on_init=True,
|
run_on_init=True,
|
||||||
# Loaded
|
# Loaded
|
||||||
managed_objs={'mesh', 'modifier'},
|
managed_objs={'mesh', 'modifier'},
|
||||||
input_sockets={'Center', 'Size'},
|
input_sockets={'Center', 'Size', 'Direction'},
|
||||||
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
|
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
|
||||||
scale_input_sockets={
|
scale_input_sockets={
|
||||||
'Center': 'BlenderUnits',
|
'Center': 'BlenderUnits',
|
||||||
|
@ -167,6 +167,7 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
|
||||||
'unit_system': unit_systems['BlenderUnits'],
|
'unit_system': unit_systems['BlenderUnits'],
|
||||||
'inputs': {
|
'inputs': {
|
||||||
'Size': input_sockets['Size'],
|
'Size': input_sockets['Size'],
|
||||||
|
'Direction': input_sockets['Direction'],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue