feat: various sym-flow modifications

main
Sofus Albert Høgsbro Rose 2024-05-30 18:41:06 +02:00
parent 830b316e01
commit 38e70a60d3
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
40 changed files with 2226 additions and 835 deletions

View File

@ -14,12 +14,12 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import dataclasses
import functools
import typing as typ
import jaxtyping as jtyp
import numpy as np
import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
@ -29,8 +29,7 @@ log = logger.get(__name__)
# TODO: Our handling of 'is_sorted' is sloppy and probably wrong.
@dataclasses.dataclass(frozen=True, kw_only=True)
class ArrayFlow:
class ArrayFlow(pyd.BaseModel):
"""A homogeneous, realized array of numerical values with an optionally-attached unit and sort-tracking.
While the principle is simple, arrays-with-units ends up being a powerful basis for derived and computed features/methods/processing.
@ -41,7 +40,9 @@ class ArrayFlow:
None if unitless.
"""
values: jtyp.Shaped[jtyp.Array, '...']
model_config = pyd.ConfigDict(frozen=True, arbitrary_types_allowed=True)
values: jtyp.Inexact[jtyp.Array, '...'] ## TODO: Custom field type
unit: spux.Unit | None = None
is_sorted: bool = False

View File

@ -18,6 +18,7 @@ import enum
import functools
import typing as typ
from blender_maxwell.contracts import BLEnumElement
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils.staticproperty import staticproperty
@ -99,6 +100,17 @@ class FlowKind(enum.StrEnum):
def to_icon(_: typ.Self) -> str:
return ''
@property
def name(self) -> str:
return FlowKind.to_name(self)
@property
def icon(self) -> str:
return FlowKind.to_icon(self)
def bl_enum_element(self, i) -> BLEnumElement:
return (str(self), self.name, self.name, self.icon, i)
####################
# - Static Properties
####################
@ -162,7 +174,7 @@ class FlowKind(enum.StrEnum):
def socket_shape(self) -> str:
"""Return the socket shape associated with this `FlowKind`.
**ONLY** valid for `FlowKind`s that can be considered "active".
Should generally only be used with `active_kinds`.
Raises:
ValueError: If this `FlowKind` cannot ever be considered "active".
@ -172,7 +184,7 @@ class FlowKind(enum.StrEnum):
FlowKind.Array: 'SQUARE',
FlowKind.Range: 'SQUARE',
FlowKind.Func: 'DIAMOND',
}[self]
}.get(self, 'CIRCLE')
####################
# - Class Methods

View File

@ -14,40 +14,17 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import dataclasses
import functools
import typing as typ
from types import MappingProxyType
import jax
import jaxtyping as jtyp
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow
from .info import InfoFlow
from .lazy_range import RangeFlow
from .params import ParamsFlow
log = logger.get(__name__)
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
@dataclasses.dataclass(frozen=True, kw_only=True)
class FuncFlow:
r"""Defines a flow of data as incremental function composition.
For specific math system usage instructions, please consult the documentation of relevant nodes.
r"""Implements the core of the math system via `FuncFlow`, which allows high-performance, fully-expressive workflows with data that can be "very large", and/or whose input parameters are not yet fully known.
# Introduction
When using nodes to do math, it becomes immediately obvious to express **flows of data as composed function chains**.
Doing so has several advantages:
- **Interactive**: Since no large-array math is being done, the UI can be designed to feel fast and snappy.
- **Symbolic**: Since no numerical math is being done yet, we can choose to keep our input parameters as symbolic variables with no performance impact.
- **Performant**: Since no operations are happening, the UI feels fast and snappy.
- **Interactive**: Since no large-array math is being done, the UI can be designed to feel fast and snappy, greatly boosting the will to experiment and ultimate productivity.
- **Symbolic**: Since no numerical math is being done yet, we can inject symbolic variables at-will, enabling effortless ex. parameter sweeping, band-structure generation, differentiable can choose to keep our input parameters as symbolic variables with no performance impact.
- **Performant**: Since the data pipeline is built as a single function w/o side effects, that function can be often be JIT-optimized for highly-optimized execution on the instruction sets used by modern massively-parallel devices, like modern CPUs (SSE/AVX), GPUs (ex. PTX), and HPC clusters (network-sharding to ARM/x86).
The result is a math system optimized for the analysis typically needed in electrodynamic contexts, prioritizing clarity and flexibility at soft-real-time, even with gigabytes of data on relatively weak hardware.
## Strongly Related FlowKinds
For doing math, `Func` relies on two other `FlowKind`s, which must run in parallel:
@ -238,6 +215,35 @@ class FuncFlow:
This enables our humble interface to implement fully functional **inverse design** of parameterized structures, using only nodes.
But above all, we hope that this math system is fun, practical, and maybe even interesting.
"""
import functools
import typing as typ
from types import MappingProxyType
import jax
import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow
from .info import InfoFlow
from .lazy_range import RangeFlow
from .params import ParamsFlow
log = logger.get(__name__)
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
class FuncFlow(pyd.BaseModel):
r"""Defines a flow of data as incremental function composition.
For theoretical information, please see the documentation of this module.
For specific math system usage instructions, please consult the documentation of relevant nodes.
Attributes:
func: The function that generates the represented value.
@ -247,14 +253,16 @@ class FuncFlow:
See the documentation of `self.func_jax()`.
"""
model_config = pyd.ConfigDict(frozen=True)
func: LazyFunction
func_args: list[sim_symbols.SimSymbol] = dataclasses.field(default_factory=list)
func_kwargs: dict[str, sim_symbols.SimSymbol] = dataclasses.field(
default_factory=dict
)
func_args: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
func_kwargs: dict[str, sim_symbols.SimSymbol] = pyd.Field(default_factory=dict)
func_output: sim_symbols.SimSymbol | None = None
supports_jax: bool = False
concatenated: bool = False
is_concatenated: bool = False
####################
# - Functions
@ -318,6 +326,7 @@ class FuncFlow:
{}
),
) -> typ.Self:
"""Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols."""
if self.supports_jax:
return self.func_jax(
*params.scaled_func_args(symbol_values),
@ -371,14 +380,55 @@ class FuncFlow:
return data | {info.output: self.realize(params, symbol_values=symbol_values)}
def realize_partial(
self, params: ParamsFlow
) -> typ.Callable[
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
jtyp.Inexact[jtyp.Array, '...'],
]:
"""Create a purely-numerical function, which takes only numerical.
The units/types/shape/etc. of the returned numerical type conforms to the `SimSymbol` specification of relevant `self.func_args` entries and `self.func_output`.
This function should be used whenever the unrealized result of a `FuncFlow` needs to be used as the argument to another `FuncFlow`.
By using `realize_partial()`, two things are ensured:
- Since the function defined in `.compose_within()` must be purely numerical, the usual `.realize()` mechanism can't be used to sweep away the pre-realized symbol values.
- Since this `FuncFlow` is completely consumed, with no symbols / arguments / etc. explicitly surviving, its impact on the data flow can be considered to have been effectively terminated after using this function.
Notes:
Be **very careful about units**.
Ideally, the bottom function should use `.scale_to_unit()` before invoking `.compose_within()` with the output of this function.
"""
pre_realized_syms = list(
params.realize_symbols(params.realized_symbols, allow_partial=True).values()
)
def realizer(
*sym_args: int | float | complex | jtyp.Inexact[jtyp.Array, '...'],
) -> jtyp.Inexact[jtyp.Array, '...']:
return self.func(
*[
func_arg_n(*sym_args, *pre_realized_syms)
for func_arg_n in params.func_args_n
],
**{
func_arg_name: func_kwarg_n(*sym_args, *pre_realized_syms)
for func_arg_name, func_kwarg_n in params.func_kwargs_n.items()
},
)
return realizer
####################
# - Composition Operations
# - Operations
####################
def compose_within(
self,
enclosing_func: LazyFunction,
enclosing_func_args: list[type] = (),
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}),
enclosing_func_args: list[sim_symbols.SimSymbol] = (),
enclosing_func_kwargs: dict[str, sim_symbols.SimSymbol] = MappingProxyType({}),
enclosing_func_output: sim_symbols.SimSymbol | None = None,
supports_jax: bool = False,
) -> typ.Self:
"""Compose `self.func` within the given enclosing function, which itself takes arguments, and create a new `FuncFlow` to contain it.
@ -415,6 +465,10 @@ class FuncFlow:
Returns:
A lazy function that takes both the enclosed and enclosing arguments, and returns the value of the enclosing function (whose first argument is the output value of the enclosed function).
"""
## TODO: Support unit system conversion at the point of composition.
## -- This may require us to track the units of the function output.
## TODO: Support JAX-evaluation when jax support changes from True to False.
## -- This would allow big data flows to compose performantly as arguments into non-JAX functions.
return FuncFlow(
func=lambda *args, **kwargs: enclosing_func(
self.func(
@ -426,6 +480,7 @@ class FuncFlow:
),
func_args=self.func_args + list(enclosing_func_args),
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
func_output=enclosing_func_output,
supports_jax=self.supports_jax and supports_jax,
)
@ -472,7 +527,7 @@ class FuncFlow:
*list(args[: len(self.func_args)]),
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
)
if not self.concatenated:
if not self.is_concatenated:
return (ret,)
return ret
@ -487,5 +542,83 @@ class FuncFlow:
func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs,
supports_jax=self.supports_jax and other.supports_jax,
concatenated=True,
is_concatenated=True,
)
def scale_to_unit(self, unit: spux.Unit | None = None) -> typ.Self:
"""Encloses this function in a unit-converting function, whose output is a converted, unitless scalar.
`unit` must be manually guaranteed to be compatible with `self.unit`.
"""
if self.func_output is not None:
# Retrieve Output Unit
output_unit = self.func_output.unit
# Compile Efficient Unit-Conversion Function
a = self.func_output.mathtype.sp_symbol_a
unit_convert_expr = (
spux.scale_to_unit(a * output_unit, unit)
if self.func_output.unit is not None
else a
)
unit_convert_func = sp.lambdify(a, unit_convert_expr.n(), 'jax')
# Compose Unit-Converted FuncFlow
return self.compose_within(
enclosing_func=unit_convert_func,
supports_jax=True,
enclosing_func_output=self.func_output.update(unit=unit),
)
msg = f'Tried to scale a FuncFlow to a unit system, but it has no tracked output SimSymbol. ({self})'
raise ValueError(msg)
def scale_to_unit_system(
self, unit_system: spux.UnitSystem | None = None
) -> typ.Self:
"""Encloses this function in a unit-converting function, whose output is a converted, unitless scalar.
Using `self.output_symbol`, which tracks the units of the output, we can determine a scaling factor to multiply the (numerical) function output by in order to conform it to the given unit system.
In general, **don't use this**.
Any superfluous numerical operations in a data pipeline can enhance instabilities and interfere with JIT-optimization (floating-point arithmetic isn't commutative, for example).
However, occasionally, we need to "intercept" a lazy data flow, for example when realizing a `FlowKind.Value` that doesn't understand symbols or units - but which only accepts a float/complex scalar/array with pre-determined unit convention.
For this purpose alone, this method is provided to pre-scale a `FuncFlow`, just before using `realize()` / `__or__` and then `realize()`.
**To encourage proper usage** (and ease implementation), the output unit in `self.func_output` of the output will be reset to `None` - indicating that the output can only be handled as a unitless scalar w/semantic meaning tracked elsewhere.
Notes:
**ONLY** use with output types that support meaningful arbitrary multiplication.
A scale-only sympy expression will be used to produce an optimized JAX function of a single variable, which will then be composed onto the existing `FuncFlow`.
Parameters:
unit_system: The unit system to conform the function output to.
Returns:
A new `FuncFlow` that conforms to the new unit, but is itself now considered unitless.
"""
if self.func_output is not None:
# Retrieve Output Unit
output_unit = self.func_output.unit
# Compile Efficient Unit-Conversion Function
a = self.func_output.mathtype.sp_symbol_a
unit_convert_expr = (
spux.strip_unit_system(
spux.convert_to_unit_system(a * output_unit, unit_system)
)
if self.func_output.unit is not None
else a
)
unit_convert_func = sp.lambdify(a, unit_convert_expr.n(), 'jax')
# Compose Unit-Converted FuncFlow
return self.compose_within(
enclosing_func=unit_convert_func,
supports_jax=True,
enclosing_func_output=self.func_output.update(unit=None),
)
msg = f'Tried to scale a FuncFlow to a unit system, but it has no tracked output SimSymbol. ({self})'
raise ValueError(msg)

View File

@ -14,17 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import dataclasses
import enum
import functools
import typing as typ
from fractions import Fraction
from types import MappingProxyType
import jax.numpy as jnp
import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger, sim_symbols
@ -61,8 +59,7 @@ class ScalingMode(enum.StrEnum):
return ''
@dataclasses.dataclass(frozen=True, kw_only=True)
class RangeFlow:
class RangeFlow(pyd.BaseModel):
r"""Represents a finite spaced array using symbolic boundary expressions.
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
@ -92,8 +89,10 @@ class RangeFlow:
symbols: Set of variables from which `start` and/or `stop` are determined.
"""
start: spux.ScalarUnitlessComplexExpr
stop: spux.ScalarUnitlessComplexExpr
model_config = pyd.ConfigDict(frozen=True)
start: spux.ScalarUnitlessRealExpr
stop: spux.ScalarUnitlessRealExpr
steps: int = 0
scaling: ScalingMode = ScalingMode.Lin
@ -102,7 +101,7 @@ class RangeFlow:
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
# Helper Attributes
pre_fourier_ideal_midpoint: spux.ScalarUnitlessComplexExpr | None = None
pre_fourier_ideal_midpoint: spux.ScalarUnitlessRealExpr | None = None
####################
# - SimSymbol Interop
@ -218,14 +217,26 @@ class RangeFlow:
)
return combined_mathtype
@property
@functools.cached_property
def ideal_midpoint(self) -> spux.SympyExpr:
return (self.stop + self.start) / 2
@property
@functools.cached_property
def ideal_range(self) -> spux.SympyExpr:
return self.stop - self.start
@functools.cached_property
def ideal_step_size(self) -> spux.SympyExpr:
return self.ideal_range / (self.steps - 1)
@functools.cached_property
def is_always_nonzero(self) -> spux.SympyExpr:
if self.start > 0 or self.stop < 0:
return True
is_zero = (self.start % self.ideal_step_size).is_zero
return is_zero if is_zero is not None else False
####################
# - Methods
####################
@ -452,7 +463,7 @@ class RangeFlow:
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
) -> dict[sp.Symbol, spux.ScalarUnitlessComplexExpr]:
) -> dict[sp.Symbol, spux.ScalarUnitlessRealExpr]:
"""Realize **all** input symbols to the `RangeFlow`.
Parameters:
@ -480,7 +491,7 @@ class RangeFlow:
raise NotImplementedError(msg)
realized_syms |= {sym: v}
return realized_syms
msg = f'RangeFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
raise ValueError(msg)

View File

@ -14,13 +14,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import dataclasses
import functools
import typing as typ
from fractions import Fraction
from types import MappingProxyType
import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
@ -28,31 +28,35 @@ from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow
from .expr_info import ExprInfo
from .flow_kinds import FlowKind
from .lazy_range import RangeFlow
log = logger.get(__name__)
@dataclasses.dataclass(frozen=True, kw_only=True)
class ParamsFlow:
class ParamsFlow(pyd.BaseModel):
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
Returns:
All symbols valid for use in the expression.
"""
arg_targets: list[sim_symbols.SimSymbol] = dataclasses.field(default_factory=list)
kwarg_targets: list[str, sim_symbols.SimSymbol] = dataclasses.field(
default_factory=dict
)
model_config = pyd.ConfigDict(frozen=True)
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
arg_targets: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
kwarg_targets: dict[str, sim_symbols.SimSymbol] = pyd.Field(default_factory=dict)
func_args: list[spux.SympyExpr] = pyd.Field(default_factory=list)
func_kwargs: dict[str, spux.SympyExpr] = pyd.Field(default_factory=dict)
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
realized_symbols: dict[
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
] = pyd.Field(default_factory=dict)
is_differentiable: bool = False
@functools.cached_property
def diff_symbols(self) -> set[sim_symbols.SimSymbol]:
"""Set of all unrealized `SimSymbol`s that can act as inputs when differentiating the function for which this `ParamsFlow` tracks arguments."""
return {sym for sym in self.symbols if sym.can_diff}
####################
# - Symbols
@ -78,6 +82,27 @@ class ParamsFlow:
"""
return [sym.sp_symbol_matsym for sym in self.sorted_symbols]
@functools.cached_property
def all_sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
Returns:
All symbols valid for use in the expression.
"""
key_func = lambda sym: sym.name # noqa: E731
return sorted(self.symbols, key=key_func) + sorted(
self.realized_symbols.keys(), key=key_func
)
@functools.cached_property
def all_sorted_sp_symbols(self) -> list[sim_symbols.SimSymbol]:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
Returns:
All symbols valid for use in the expression.
"""
return [sym.sp_symbol_matsym for sym in self.all_sorted_symbols]
####################
# - JIT'ed Callables for Numerical Function Arguments
####################
@ -101,7 +126,7 @@ class ParamsFlow:
"""
return [
sp.lambdify(
self.sorted_sp_symbols,
self.all_sorted_sp_symbols,
target_sym.conform(func_arg, strip_unit=True),
'jax',
)
@ -127,7 +152,7 @@ class ParamsFlow:
"""
return {
key: sp.lambdify(
self.sorted_sp_symbols,
self.all_sorted_sp_symbols,
self.kwarg_targets[key].conform(func_arg, strip_unit=True),
'jax',
)
@ -142,8 +167,9 @@ class ParamsFlow:
symbol_values: dict[
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
] = MappingProxyType({}),
allow_partial: bool = False,
) -> dict[
sp.Symbol,
sim_symbols.SimSymbol,
int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...'] :,
]:
"""Fully realize all symbols by assigning them a value.
@ -160,10 +186,12 @@ class ParamsFlow:
Returns:
A dictionary almost with `.subs()`, other than `jax` arrays.
"""
if set(self.symbols) == set(symbol_values.keys()):
if allow_partial or set(self.all_sorted_symbols) == set(symbol_values.keys()):
realized_syms = {}
for sym in self.sorted_symbols:
sym_value = symbol_values[sym]
for sym in self.all_sorted_symbols:
sym_value = symbol_values.get(sym)
if sym_value is None and allow_partial:
continue
if isinstance(sym_value, spux.SympyType):
v = sym.scale(sym_value)
@ -214,7 +242,9 @@ class ParamsFlow:
Parameters:
symbol_values: Particular values for all symbols in `self.symbols`, which will be conformed and used to compute the function arguments (before they are conformed to `self.target_syms`).
"""
realized_symbols = list(self.realize_symbols(symbol_values).values())
realized_symbols = list(
self.realize_symbols(symbol_values | self.realized_symbols).values()
)
return [func_arg_n(*realized_symbols) for func_arg_n in self.func_args_n]
def scaled_func_kwargs(
@ -227,10 +257,11 @@ class ParamsFlow:
Other than the `dict[str, ...]` key, the semantics are identical to `self.scaled_func_args()`.
"""
realized_symbols = self.realize_symbols(symbol_values)
realized_symbols = self.realize_symbols(symbol_values | self.realized_symbols)
return {
func_arg_name: func_arg_n(**realized_symbols)
for func_arg_name, func_arg_n in self.func_kwargs_n.items()
func_arg_name: func_kwarg_n(**realized_symbols)
for func_arg_name, func_kwarg_n in self.func_kwargs_n.items()
}
####################
@ -251,7 +282,7 @@ class ParamsFlow:
func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs,
symbols=self.symbols | other.symbols,
is_differentiable=self.is_differentiable and other.is_differentiable,
realized_symbols=self.realized_symbols | other.realized_symbols,
)
def compose_within(
@ -261,7 +292,6 @@ class ParamsFlow:
enclosing_func_args: list[spux.SympyExpr] = (),
enclosing_func_kwargs: dict[str, spux.SympyExpr] = MappingProxyType({}),
enclosing_symbols: frozenset[sim_symbols.SimSymbol] = frozenset(),
enclosing_is_differentiable: bool = False,
) -> typ.Self:
return ParamsFlow(
arg_targets=self.arg_targets + list(enclosing_arg_targets),
@ -269,13 +299,41 @@ class ParamsFlow:
func_args=self.func_args + list(enclosing_func_args),
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
symbols=self.symbols | enclosing_symbols,
is_differentiable=(
self.is_differentiable
if not enclosing_symbols
else (self.is_differentiable & enclosing_is_differentiable)
),
realized_symbols=self.realized_symbols,
)
def realize_partial(
self,
symbol_values: dict[
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
],
) -> typ.Self:
"""Provide a particular expression/range/array to realize some symbols.
Essentially removes symbols from `self.symbols`, and adds the symbol w/value to `self.realized_symbols`.
As a result, only the still-unrealized symbols need to be passed at the time of realization (using ex. `self.scaled_func_args()`).
Parameters:
symbol_values: The value to realize for each `SimSymbol`.
**All keys** must be identically matched to a single element of `self.symbol`.
Can be empty, in which case an identical new `ParamsFlow` will be returned.
Raises:
ValueError: If any symbol in `symbol_values`
"""
syms = set(symbol_values.keys())
if syms.issubset(self.symbols) or not syms:
return ParamsFlow(
arg_targets=self.arg_targets,
kwarg_targets=self.kwarg_targets,
func_args=self.func_args,
func_kwargs=self.func_kwargs,
symbols=self.symbols - syms,
realized_symbols=self.realized_symbols | symbol_values,
)
msg = f'ParamsFlow: Not all partially realized symbols are defined on the ParamsFlow (symbols={self.symbols}, symbol_values={symbol_values})'
raise ValueError(msg)
####################
# - Generate ExprSocketDef
####################

View File

@ -16,7 +16,7 @@
"""Declares `ManagedBLImage`."""
# import time
import time
import typing as typ
import bpy
@ -261,7 +261,7 @@ class ManagedBLImage(base.ManagedObj):
dpi: int | None = None,
bl_select: bool = False,
):
# times = [time.perf_counter()]
times = ['START', time.perf_counter()]
# Compute Plot Dimensions
# aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = (
@ -277,22 +277,22 @@ class ManagedBLImage(base.ManagedObj):
# _width_inches, _height_inches, _dpi
# )
fig, canvas, ax = image_ops.mpl_fig_canvas_ax(width_inches, height_inches, dpi)
# times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]])
times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]])
# fig.clear()
ax.clear()
# times.append(['Clear Axis', time.perf_counter() - times[0]])
times.append(['Clear Axis', time.perf_counter() - times[0]])
# Plot w/User Parameter
func_plotter(ax)
# times.append(['Plot!', time.perf_counter() - times[0]])
times.append(['Plot!', time.perf_counter() - times[0]])
# Save Figure to BytesIO
canvas.draw()
# times.append(['Draw Pixels', time.perf_counter() - times[0]])
times.append(['Draw Pixels', time.perf_counter() - times[0]])
canvas_width_px, canvas_height_px = fig.canvas.get_width_height()
# times.append(['Get Canvas Dims', time.perf_counter() - times[0]])
times.append(['Get Canvas Dims', time.perf_counter() - times[0]])
image_data = (
np.float32(
np.flipud(
@ -303,7 +303,7 @@ class ManagedBLImage(base.ManagedObj):
)
/ 255
)
# times.append(['Load Data from Canvas', time.perf_counter() - times[0]])
times.append(['Load Data from Canvas', time.perf_counter() - times[0]])
# Optimized Write to Blender Image
bl_image = self.bl_image(canvas_width_px, canvas_height_px, 'RGBA', 'uint8')

View File

@ -15,6 +15,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import contextlib
import functools
import queue
import typing as typ
import bpy
@ -26,6 +28,14 @@ from .managed_objs.managed_bl_image import ManagedBLImage
log = logger.get(__name__)
link_action_queue = queue.Queue()
def set_link_validity(link: bpy.types.NodeLink, validity: bool) -> None:
log.critical('Set %s validity to %s', str(link), str(validity))
link.is_valid = validity
####################
# - Cache Management
####################
@ -45,58 +55,93 @@ class DeltaNodeLinkCache(typ.TypedDict):
class NodeLinkCache:
"""A pointer-based cache of node links in a node tree.
"""A volatile pointer-based cache of node links in a node tree.
Warnings:
Everything here is **extremely** unsafe.
Even a single mistake **will** cause a use-after-free crash of Blender.
Used perfectly, it allows for powerful features; anything less, and it's an epic liability.
Attributes:
_node_tree: Reference to the owning node tree.
link_ptrs_as_links:
link_ptrs: Pointers (as in integer memory adresses) to `NodeLink`s.
link_ptrs_as_links: Map from pointers to actual `NodeLink`s.
link_ptrs_from_sockets: Map from pointers to `NodeSocket`s, representing the source of each `NodeLink`.
link_ptrs_from_sockets: Map from pointers to `NodeSocket`s, representing the destination of each `NodeLink`.
_node_tree: Reference to the node tree for which this cache is valid.
link_ptrs: Memory-address identifiers for all node links that currently exist in `_node_tree`.
link_ptrs_as_links: Mapping from pointers (integers) to actual `NodeLink` objects.
**WARNING**: If the pointer-referenced object no longer exists, then Blender **will crash immediately** upon attempting to use it. There is no way to mitigate this.
socket_ptrs: Memory-address identifiers for all sockets that currently exist in `_node_tree`.
socket_ptrs_as_sockets: Mapping from pointers (integers) to actual `NodeSocket` objects.
**WARNING**: If the pointer-referenced object no longer exists, then Blender **will crash immediately** upon attempting to use it. There is no way to mitigate this.
socket_ptr_refcount: The amount of links currently connected to a given socket pointer.
Used to drive the deletion of socket pointers using only knowledge about `link_ptr` removal.
link_ptrs_as_from_socket_ptrs: The pointer of the source socket, defined for every node link pointer.
link_ptrs_as_to_socket_ptrs: The pointer of the destination socket, defined for every node link pointer.
"""
def __init__(self, node_tree: bpy.types.NodeTree):
"""Initialize the cache from a node tree.
Parameters:
node_tree: The Blender node tree whose `NodeLink`s will be cached.
"""
"""Defines and fills the cache from a live node tree."""
self._node_tree = node_tree
# Link PTR and PTR->REF
self.link_ptrs: set[MemAddr] = set()
self.link_ptrs_as_links: dict[MemAddr, bpy.types.NodeLink] = {}
# Socket PTR and PTR->REF
self.socket_ptrs: set[MemAddr] = set()
self.socket_ptrs_as_sockets: dict[MemAddr, bpy.types.NodeSocket] = {}
self.socket_ptr_refcount: dict[MemAddr, int] = {}
# Link PTR -> Socket PTR
self.link_ptrs_as_from_socket_ptrs: dict[MemAddr, MemAddr] = {}
self.link_ptrs_as_to_socket_ptrs: dict[MemAddr, MemAddr] = {}
self.link_ptrs_invalid: set[MemAddr] = set()
# Fill Cache
self.regenerate()
def remove_link(self, link_ptr: MemAddr) -> None:
"""Removes a link pointer from the cache, indicating that the link doesn't exist anymore.
"""Reports a link as removed, causing it to be removed from the cache.
This **must** be run whenever a node link is deleted.
**Failure to to so WILL result in segmentation fault** at an unknown future time.
In particular, the following actions are taken:
- The entry in `self.link_ptrs_as_links` is deleted.
- Any entry in `self.link_ptrs_invalid` is deleted (if exists).
Notes:
- **DOES NOT** remove PTR->REF dictionary entries
- Invoking this method directly causes the removed node links to not be reported as "removed" by `NodeLinkCache.regenerate()`.
- This **must** be done whenever a node link is deleted.
- Failure to do so may result in a segmentation fault at arbitrary future time.
Invoking this method directly causes the removed node links to not be reported as "removed" by `NodeLinkCache.regenerate()`.
In some cases, this may be desirable, ex. for internal methods that shouldn't trip a `DataChanged` flow event.
Parameters:
link_ptr: Pointer to remove from the cache.
link_ptr: The pointer (integer) to remove from the cache.
Raises:
KeyError: If `link_ptr` is not a member of either `self.link_ptrs`, or of `self.link_ptrs_as_links`.
"""
self.link_ptrs.remove(link_ptr)
self.link_ptrs_as_links.pop(link_ptr)
if link_ptr in self.link_ptrs_invalid:
self.link_ptrs_invalid.remove(link_ptr)
def remove_sockets_by_link_ptr(self, link_ptr: MemAddr) -> None:
"""Removes a single pointer's reference to its from/to sockets."""
"""Deassociate from all sockets referenced by a link, respecting the socket pointer reference-count.
The `NodeLinkCache` stores references to all socket pointers referenced by any link.
Since several links can be associated with each socket, we must keep a "reference count" per-socket.
When the "reference count" drops to zero, then there are no longer any `NodeLink`s that refer to it, and therefore it should be removed from the `NodeLinkCache`.
This method facilitates that process by:
- Extracting (with removal) the from / to socket pointers associated with `link_ptr`.
- If the socket pointer has a reference count of `1`, then it is **completely removed**.
- If the socket pointer has a reference count of `>1`, then the reference count is decremented by `1`.
Notes:
In general, this should be called together with `remove_link`.
However, in certain cases, this process also needs to happen by itself.
Parameters:
link_ptr: The pointer (integer) to remove from the cache.
"""
# Remove Socket Pointers
from_socket_ptr = self.link_ptrs_as_from_socket_ptrs.pop(link_ptr, None)
to_socket_ptr = self.link_ptrs_as_to_socket_ptrs.pop(link_ptr, None)
@ -113,31 +158,40 @@ class NodeLinkCache:
self.socket_ptr_refcount[socket_ptr] -= 1
def regenerate(self) -> DeltaNodeLinkCache:
"""Regenerates the cache from the internally-linked node tree.
"""Efficiently scans the internally referenced node tree to thoroughly update all attributes of this `NodeLinkCache`.
Notes:
- This is designed to run within the `update()` invocation of the node tree.
- This should be a very fast function, since it is called so much.
This runs in a **very** hot loop, within the `update()` function of the node tree.
Anytime anything happens in the node tree, `update()` (and therefore this method) is called.
Thus, performance is of the utmost importance.
Just a few microseconds too much may be amplified dozens of times over in practice, causing big stutters.
"""
# Compute All NodeLink Pointers
## -> It can be very inefficient to do any full-scan of the node tree.
## -> However, simply extracting the pointer: link ends up being fast.
## -> This pattern seems to be the best we can do, efficiency-wise.
all_link_ptrs_as_links = {
link.as_pointer(): link for link in self._node_tree.links
}
all_link_ptrs = set(all_link_ptrs_as_links.keys())
# Compute Added/Removed Links
## -> In essence, we've created a 'diff' here.
## -> Set operations are fast, and expressive!
added_link_ptrs = all_link_ptrs - self.link_ptrs
removed_link_ptrs = self.link_ptrs - all_link_ptrs
# Edge Case: 'from_socket' Reassignment
## (Reverse engineered) When all:
## - Created a new link between the same two nodes.
## - Matching 'to_socket'.
## - Non-matching 'from_socket' on the same node.
## -> THEN the link_ptr will not change, but the from_socket ptr should.
if len(added_link_ptrs) == 0 and len(removed_link_ptrs) == 0:
## (Reverse Engineered) When all are true:
## - Created a new link between the same nodes as previous link.
## - Matching 'to_socket' as the previous link.
## - Non-matching 'from_socket', but on the same node.
## -> THEN the link_ptr will not change, but the from_socket ptr does.
if not added_link_ptrs and not removed_link_ptrs:
# Find the Link w/Reassigned 'from_socket' PTR
## A bit of a performance hit from the search, but it's an edge case.
## -> This isn't very fast, but the edge case isn't so common.
## -> Comprehensions are still quite optimized.
_link_ptr_as_from_socket_ptrs = {
link_ptr: (
from_socket_ptr,
@ -149,9 +203,9 @@ class NodeLinkCache:
}
# Completely Remove the Old Link (w/Reassigned 'from_socket')
## This effectively reclassifies the edge case as a normal 're-add'.
## -> Casts the edge case to look like a typical 're-add'.
for link_ptr in _link_ptr_as_from_socket_ptrs:
log.info(
log.debug(
'Edge-Case - "from_socket" Reassigned in NodeLink w/o New NodeLink Pointer: %s',
link_ptr,
)
@ -159,21 +213,25 @@ class NodeLinkCache:
self.remove_sockets_by_link_ptr(link_ptr)
# Recompute Added/Removed Links
## The algorithm will now detect an "added link".
## -> Guide the usual algorithm to detect an "added link".
added_link_ptrs = all_link_ptrs - self.link_ptrs
removed_link_ptrs = self.link_ptrs - all_link_ptrs
# Shuffle Cache based on Change in Links
## Remove Entries for Removed Pointers
# Delete Removed Links
## -> NOTE: We leave dangling socket information on purpose.
## -> This information will be used to ask for 'removal consent'.
## -> To truly remove, must call 'remove_socket_by_link_ptr' later.
for removed_link_ptr in removed_link_ptrs:
self.remove_link(removed_link_ptr)
## User must manually call 'remove_socket_by_link_ptr' later.
## For now, leave dangling socket information by-link.
# Add New Link Pointers
# Create Added Links
## -> First, simply concatenate the added link pointers.
self.link_ptrs |= added_link_ptrs
for link_ptr in added_link_ptrs:
# Add Link PTR->REF
# Create Pointer -> Reference Entry
## -> This allows us to efficiently access the link by-pointer.
## -> Doing so otherwise requires a full search.
## -> **If link is deleted w/o report, access will cause crash**.
new_link = all_link_ptrs_as_links[link_ptr]
self.link_ptrs_as_links[link_ptr] = new_link
@ -183,34 +241,69 @@ class NodeLinkCache:
to_socket = new_link.to_socket
to_socket_ptr = to_socket.as_pointer()
# Add Socket PTR, PTR -> REF
# Add Socket Information
for socket_ptr, bl_socket in zip( # noqa: B905
[from_socket_ptr, to_socket_ptr],
[from_socket, to_socket],
):
# Increment RefCount of Socket PTR
# RefCount > 0: Increment RefCount of Socket PTR
## This happens if another link also uses the same socket.
## 1. An output socket links to several inputs.
## 2. A multi-input socket links from several inputs.
if socket_ptr in self.socket_ptr_refcount:
self.socket_ptr_refcount[socket_ptr] += 1
# RefCount == 0: Create Socket Pointer w/Reference
## -> Also initialize the refcount for the socket pointer.
else:
## RefCount == 0: Add PTR, PTR -> REF
self.socket_ptrs.add(socket_ptr)
self.socket_ptrs_as_sockets[socket_ptr] = bl_socket
self.socket_ptr_refcount[socket_ptr] = 1
# Add Link PTR -> Socket PTR
# Add Entry from Link Pointer -> Socket Pointer
self.link_ptrs_as_from_socket_ptrs[link_ptr] = from_socket_ptr
self.link_ptrs_as_to_socket_ptrs[link_ptr] = to_socket_ptr
return {'added': added_link_ptrs, 'removed': removed_link_ptrs}
def update_validity(self) -> DeltaNodeLinkCache:
"""Query all cached links to determine whether they are valid."""
self.link_ptrs_invalid = {
link_ptr for link_ptr, link in self.link_ptrs_as_links if not link.is_valid
}
def report_validity(self, link_ptr: MemAddr, validity: bool) -> None:
"""Report a link as invalid."""
if validity and link_ptr in self.link_ptrs_invalid:
self.link_ptrs_invalid.remove(link_ptr)
elif not validity and link_ptr not in self.link_ptrs_invalid:
self.link_ptrs_invalid.add(link_ptr)
def set_validities(self) -> None:
"""Set the validity of links in the node tree according to the internal cache.
Validity doesn't need to be removed, as update() automatically cleans up by default.
"""
for link in [
link
for link_ptr, link in self.link_ptrs_as_links.items()
if link_ptr in self.link_ptrs_invalid
]:
if link.is_valid:
link.is_valid = False
####################
# - Node Tree Definition
####################
class MaxwellSimTree(bpy.types.NodeTree):
"""Node tree containing a node-based program for design and analysis of Maxwell PDE simulations.
Attributes:
is_active: Whether the node tree should be considered to be in a usable state, capable of updating Blender data.
In general, only one `MaxwellSimTree` should be active at a time.
"""
bl_idname = ct.TreeType.MaxwellSim.value
bl_label = 'Maxwell Sim Editor'
bl_icon = ct.Icon.SimNodeEditor
@ -219,63 +312,6 @@ class MaxwellSimTree(bpy.types.NodeTree):
default=True,
)
####################
# - Lock Methods
####################
def unlock_all(self) -> None:
"""Unlock all nodes in the node tree, making them editable."""
log.info('Unlocking All Nodes in NodeTree "%s"', self.bl_label)
for node in self.nodes:
if node.type in ['REROUTE', 'FRAME']:
continue
node.locked = False
for bl_socket in [*node.inputs, *node.outputs]:
bl_socket.locked = False
@contextlib.contextmanager
def replot(self) -> None:
self.is_currently_replotting = True
self.something_plotted = False
try:
yield
finally:
self.is_currently_replotting = False
if not self.something_plotted:
ManagedBLImage.hide_preview()
def report_show_plot(self, node: bpy.types.Node) -> None:
if hasattr(self, 'is_currently_replotting') and self.is_currently_replotting:
self.something_plotted = True
@contextlib.contextmanager
def repreview_all(self) -> None:
all_nodes_with_preview_active = {
node.instance_id: node
for node in self.nodes
if node.type not in ['REROUTE', 'FRAME'] and node.preview_active
}
self.is_currently_repreviewing = True
self.newly_previewed_nodes = {}
try:
yield
finally:
self.is_currently_repreviewing = False
for dangling_previewed_node in [
node
for node_instance_id, node in all_nodes_with_preview_active.items()
if node_instance_id not in self.newly_previewed_nodes
]:
dangling_previewed_node.preview_active = False
def report_show_preview(self, node: bpy.types.Node) -> None:
if (
hasattr(self, 'is_currently_repreviewing')
and self.is_currently_repreviewing
):
self.newly_previewed_nodes[node.instance_id] = node
####################
# - Init Methods
####################
@ -290,7 +326,54 @@ class MaxwellSimTree(bpy.types.NodeTree):
self.node_link_cache = NodeLinkCache(self)
####################
# - Update Methods
# - Lock Methods
####################
def unlock_all(self) -> None:
"""Unlock all nodes in the node tree, making them editable.
Notes:
All `MaxwellSimNode`s have a `.locked` attribute, which prevents the entire UI from being modified.
This method simply sets the `locked` attribute to `False` on all nodes.
"""
log.info('Unlocking All Nodes in NodeTree "%s"', self.bl_label)
for node in self.nodes:
if node.type in ['REROUTE', 'FRAME']:
continue
# Unlock Node
if node.locked:
node.locked = False
# Unlock Node Sockets
for bl_socket in [*node.inputs, *node.outputs]:
if bl_socket.locked:
bl_socket.locked = False
####################
# - Link Update Methods
####################
def report_link_validity(self, link: bpy.types.NodeLink, validity: bool) -> None:
"""Report that a particular `NodeLink` should be considered to be either valid or invalid.
The `NodeLink.is_valid` attribute is generally (and automatically) used to indicate the detection of cycles in the node tree.
However, visually, it causes a very clear "error red" highlight to appear on the node link, which can extremely useful when determining the reasons behind unexpected outout.
Notes:
Run by `MaxwellSimSocket` when a link should be shown to be "invalid".
"""
## TODO: Doesn't quite work.
# log.debug(
# 'Reported Link Validity %s (is_valid=%s, from_socket=%s, to_socket=%s)',
# validity,
# link.is_valid,
# link.from_socket,
# link.to_socket,
# )
# self.node_link_cache.report_validity(link.as_pointer(), validity)
####################
# - Node Update Methods
####################
def on_node_removed(self, node: bpy.types.Node):
"""Run by `MaxwellSimNode.free()` when a node is being removed.
@ -327,32 +410,36 @@ class MaxwellSimTree(bpy.types.NodeTree):
self.node_link_cache.remove_link(link_ptr)
self.node_link_cache.remove_sockets_by_link_ptr(link_ptr)
def update(self) -> None:
def update(self) -> None: # noqa: PLR0912, C901
"""Monitors all changes to the node tree, potentially responding with appropriate callbacks.
Notes:
- Run by Blender when "anything" changes in the node tree.
- Responds to node link changes with callbacks, with the help of a performant node link cache.
"""
# Perform Initial Load
## -> Presume update() is run before the first link is altered.
## -> Else, the first link of the session will not update caches.
## -> We still remain slightly unsure of the exact semantics.
## -> Therefore, self.on_load() is also called as a load_post handler.
if not hasattr(self, 'node_link_cache'):
self.on_load()
return
# Register Validity Updater
## -> They will be run after the update() method.
## -> Between update() and set_validities, all is_valid=True are cleared.
## -> Therefore, 'set_validities' only needs to set all is_valid=False.
bpy.app.timers.register(self.node_link_cache.set_validities)
# Ignore Updates
## -> Certain corrective processes require suppressing the next update.
## -> Otherwise, link corrections may trigger some nasty recursions.
if not hasattr(self, 'ignore_update'):
self.ignore_update = False
if not hasattr(self, 'node_link_cache'):
self.on_load()
## We presume update() is run before the first link is altered.
## - Else, the first link of the session will not update caches.
## - We remain slightly unsure of the semantics.
## - Therefore, self.on_load() is also called as a load_post handler.
return
# Ignore Update
## Manually set to implement link corrections w/o recursion.
if self.ignore_update:
return
# Compute Changes to Node Links
# Regenerate NodeLinkCache
delta_links = self.node_link_cache.regenerate()
link_corrections = {
'to_remove': [],
'to_add': [],

View File

@ -358,6 +358,11 @@ class ExtractDataNode(base.MaxwellSimNode):
## -> Those string labels explain the integer as ex. Ex, Ey, Hy.
idx_labels = valid_monitor_attrs(sim_data, monitor_name)
# Extract Info
## -> We only need the output symbol.
## -> All labelled outputs have the same output SimSymbol.
info = extract_info(monitor_data, idx_labels[0])
# Generate FuncFlow Per Index Label
## -> We extract each XArray as an attribute of monitor_data.
## -> We then bind its values into a unique func_flow.
@ -377,7 +382,8 @@ class ExtractDataNode(base.MaxwellSimNode):
## -> Then, 'compose_within' lets us stack them along axis=0.
## -> The "new" axis=0 is int-indexed axis w/idx_labels labels!
return functools.reduce(lambda a, b: a | b, func_flows).compose_within(
enclosing_func=lambda data: jnp.stack(data, axis=0)
lambda data: jnp.stack(data, axis=0),
func_output=info.output,
)
return ct.FlowSignal.FlowPending
return ct.FlowSignal.FlowPending

View File

@ -65,12 +65,12 @@ class FilterOperation(enum.StrEnum):
FO = FilterOperation
return {
# Slice
FO.Slice: '=a[i:j]',
FO.SliceIdx: '≈a[v₁:v₂]',
FO.Slice: '≈a[v₁:v₂]',
FO.SliceIdx: '=a[i:j]',
# Pin
FO.PinLen1: 'pinₐ',
FO.Pin: 'pinₐ ≈v',
FO.PinIdx: 'pinₐ =i',
FO.PinLen1: 'a[0] → a',
FO.Pin: 'a[v] ⇝ a',
FO.PinIdx: 'a[i] → a',
# Reinterpret
FO.Swap: 'a₁ ↔ a₂',
}[value]
@ -517,6 +517,7 @@ class FilterMathNode(base.MaxwellSimNode):
return lazy_func.compose_within(
operation.jax_func(axis_0, axis_1, slice_tuple=slice_tuple),
enclosing_func_args=operation.func_args,
enclosing_func_output=info.output,
supports_jax=True,
)
return ct.FlowSignal.FlowPending

View File

@ -547,22 +547,30 @@ class MapMathNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Expr',
# Loaded
kind=ct.FlowKind.Func,
props={'operation'},
input_sockets={'Expr'},
input_socket_kinds={
'Expr': ct.FlowKind.Func,
},
output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Info},
)
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
operation = props['operation']
def compute_func(
self, props, input_sockets, output_sockets
) -> ct.FuncFlow | ct.FlowSignal:
expr = input_sockets['Expr']
output_info = output_sockets['Expr']
has_expr = not ct.FlowSignal.check(expr)
has_output_info = not ct.FlowSignal.check(output_info)
operation = props['operation']
if has_expr and operation is not None:
return expr.compose_within(
operation.jax_func,
enclosing_func_output=output_info.output,
supports_jax=True,
)
return ct.FlowSignal.FlowPending

View File

@ -146,7 +146,6 @@ class BinaryOperation(enum.StrEnum):
outl = info_l.output
outr = info_r.output
match (outl.shape_len, outr.shape_len):
# match (ol.shape_len, info_r.output.shape_len):
# Number | *
## Number | Number
case (0, 0):
@ -154,15 +153,25 @@ class BinaryOperation(enum.StrEnum):
BO.Add,
BO.Sub,
BO.Mul,
BO.Div,
BO.Pow,
]
# Check Non-Zero Right Hand Side
## -> Obviously, we can't ever divide by zero.
## -> Sympy's assumptions system must always guarantee rhs != 0.
## -> If it can't, then we simply don't expose division.
## -> The is_zero assumption must be provided elsewhere.
## -> NOTE: This may prevent some valid uses of division.
## -> Watch out for "division is missing" bugs.
if info_r.output.is_nonzero:
ops.append(BO.Div)
if (
info_l.output.physical_type == spux.PhysicalType.Length
and info_l.output.unit == info_r.output.unit
):
ops += [BO.Atan2]
return ops
return [*ops, BO.Pow]
## Number | Vector
case (0, 1):
@ -336,7 +345,13 @@ class BinaryOperation(enum.StrEnum):
# - InfoFlow Transform
####################
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
"""Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression."""
"""Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression.
Warnings:
`self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r).
If not, bad things will happen.
"""
return info_l.operate_output(
info_r,
lambda a, b: self.sp_func([a, b]),
@ -479,29 +494,35 @@ class OperateMathNode(base.MaxwellSimNode):
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Func,
# Loaded
props={'operation'},
input_sockets={'Expr L', 'Expr R'},
input_socket_kinds={
'Expr L': ct.FlowKind.Func,
'Expr R': ct.FlowKind.Func,
},
output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Info},
)
def compose_func(self, props: dict, input_sockets: dict):
def compute_func(self, props, input_sockets, output_sockets):
operation = props['operation']
if operation is None:
return ct.FlowSignal.FlowPending
expr_l = input_sockets['Expr L']
expr_r = input_sockets['Expr R']
output_info = output_sockets['Expr']
has_expr_l = not ct.FlowSignal.check(expr_l)
has_expr_r = not ct.FlowSignal.check(expr_r)
has_output_info = not ct.FlowSignal.check(output_info)
# Compute Jax Function
## -> The operation enum directly provides the appropriate function.
if has_expr_l and has_expr_r:
if has_expr_l and has_expr_r and has_output_info:
return (expr_l | expr_r).compose_within(
enclosing_func=operation.jax_func,
operation.jax_func,
enclosing_func_output=output_info.output,
supports_jax=True,
)
return ct.FlowSignal.FlowPending
@ -520,6 +541,8 @@ class OperateMathNode(base.MaxwellSimNode):
},
)
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
BO = BinaryOperation
operation = props['operation']
info_l = input_sockets['Expr L']
info_r = input_sockets['Expr R']
@ -533,7 +556,7 @@ class OperateMathNode(base.MaxwellSimNode):
has_info_l
and has_info_r
and operation is not None
and operation in BinaryOperation.by_infos(info_l, info_r)
and operation in BO.by_infos(info_l, info_r)
):
return operation.transform_infos(info_l, info_r)

View File

@ -606,27 +606,32 @@ class TransformMathNode(base.MaxwellSimNode):
input_socket_kinds={
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info},
},
output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Info},
)
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
def compute_func(
self, props, input_sockets, output_sockets
) -> ct.FuncFlow | ct.FlowSignal:
"""Transform the input `InfoFlow` depending on the transform operation."""
TO = TransformOperation
operation = props['operation']
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info]
output_info = output_sockets['Expr']
has_info = not ct.FlowSignal.check(info)
has_lazy_func = not ct.FlowSignal.check(lazy_func)
has_output_info = not ct.FlowSignal.check(output_info)
if operation is not None and has_lazy_func and has_info:
# Retrieve Properties
operation = props['operation']
if operation is not None and has_lazy_func and has_info and has_output_info:
dim = props['dim']
# Match Pattern by Operation
match operation:
case TO.FreqToVacWL | TO.VacWLToFreq | TO.FT1D | TO.InvFT1D:
if dim is not None and info.has_idx_discrete(dim):
return lazy_func.compose_within(
operation.jax_func(axis=info.dim_axis(dim)),
enclosing_func_output=output_info.output,
supports_jax=True,
)
return ct.FlowSignal.FlowPending
@ -634,6 +639,7 @@ class TransformMathNode(base.MaxwellSimNode):
case _:
return lazy_func.compose_within(
operation.jax_func(),
enclosing_func_output=output_info.output,
supports_jax=True,
)

View File

@ -406,7 +406,7 @@ class VizNode(base.MaxwellSimNode):
},
all_loose_input_sockets=True,
)
def compute_dummy_value(self, props, input_sockets, loose_input_sockets):
def compute_previews(self, props, input_sockets, loose_input_sockets):
"""Needed for the plot to regenerate in the viewer."""
return ct.PreviewsFlow(bl_image_name=props['sim_node_name'])
@ -433,7 +433,7 @@ class VizNode(base.MaxwellSimNode):
def on_show_plot(
self, managed_objs, props, input_sockets, loose_input_sockets
) -> None:
log.critical('Show Plot (too many times)')
log.debug('Show Plot')
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info]
params = input_sockets['Expr'][ct.FlowKind.Params]
@ -456,6 +456,7 @@ class VizNode(base.MaxwellSimNode):
sym: loose_input_sockets[sym.name] for sym in params.sorted_symbols
},
)
## TODO: CACHE entries that don't change, PLEASEEE
# Match Viz Type & Perform Visualization
## -> Viz Target determines how to plot.

View File

@ -207,12 +207,12 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
stop_propagation=True,
)
def _on_sim_node_name_changed(self, props):
log.debug(
'Changed Sim Node Name of a "%s" to "%s" (self=%s)',
self.bl_idname,
props['sim_node_name'],
str(self),
)
# log.debug(
# 'Changed Sim Node Name of a "%s" to "%s" (self=%s)',
# self.bl_idname,
# props['sim_node_name'],
# str(self),
# )
# (Re)Construct Managed Objects
## -> Due to 'prev_name', the new MObjs will be renamed on construction
@ -360,27 +360,48 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
####################
# - Socket Management
####################
## TODO: Check for namespace collisions in sockets to prevent silent errors
def _prune_inactive_sockets(self):
"""Remove all "inactive" sockets from the node.
"""Remove all inactive sockets from the node, while only updating sockets that can be non-destructively updated.
A socket is considered "inactive" when it shouldn't be defined (per `self.active_socket_defs), but is present nonetheless.
The first step is easy: We determine, by-name, which sockets should no longer be defined, then remove them correctly.
The second step is harder: When new sockets have overlapping names, should they be removed, or should they merely have some properties updated?
Removing and re-adding the same socket is an accurate, generally robust approach, but it comes with a big caveat: **Existing node links will be cut**, even when it might semantically make sense to simply alter the socket's properties, keeping the links.
Different `bl_socket.socket_type`s can never be updated - they must be removed.
Otherwise, `SocketDef.compare(bl_socket)` allows us to granularly determine whether a particular `bl_socket` has changed with respect to the desired specification.
When the comparison is `False`, we can carefully utilize `SocketDef.init()` to re-initialize the socket, guaranteeing that the altered socket is up to the new specification.
"""
node_tree = self.id_data
for direc in ['input', 'output']:
all_bl_sockets = self._bl_sockets(direc)
active_bl_socket_defs = self.active_socket_defs(direc)
bl_sockets = self._bl_sockets(direc)
active_socket_defs = self.active_socket_defs(direc)
# Determine Sockets to Remove
## -> Name: If the existing socket name isn't "active".
## -> Type: If the existing socket_type != "active" SocketDef.
bl_sockets_to_remove = [
bl_socket
for socket_name, bl_socket in all_bl_sockets.items()
if socket_name not in active_bl_socket_defs
or socket_name
in (
self.loose_input_sockets
if direc == 'input'
else self.loose_output_sockets
for socket_name, bl_socket in bl_sockets.items()
if (
socket_name not in active_socket_defs
or bl_socket.socket_type
is not active_socket_defs[socket_name].socket_type
)
]
# Determine Sockets to Update
## -> Name: If the existing socket name is "active".
## -> Type: If the existing socket_type == "active" SocketDef.
## -> Compare: If the existing socket differs from the SocketDef.
bl_sockets_to_update = [
bl_socket
for socket_name, bl_socket in bl_sockets.items()
if (
socket_name in active_socket_defs
and bl_socket.socket_type
is active_socket_defs[socket_name].socket_type
and not active_socket_defs[socket_name].compare(bl_socket)
)
]
@ -392,24 +413,25 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
## -> The NodeLinkCache needs to be adjusted manually.
node_tree.on_node_socket_removed(bl_socket)
# 2. Invalidate the input socket cache across all kinds.
# 2. Perform the removal using Blender's API.
## -> Actually removes the socket.
bl_sockets.remove(bl_socket)
# 3. Invalidate the input socket cache across all kinds.
## -> Prevents phantom values from remaining available.
## -> Done after socket removal to protect from race condition.
self._compute_input.invalidate(
input_socket_name=bl_socket_name,
kind=...,
unit_system=...,
)
# 3. Perform the removal using Blender's API.
## -> Actually removes the socket.
all_bl_sockets.remove(bl_socket)
if direc == 'input':
# 4. Run all trigger-only `on_value_changed` callbacks.
## -> Runs any event methods that relied on the socket.
## -> Only methods that don't **require** the socket.
## Trigger-Only: If method loads no socket data, it runs.
## `optional`: If method optional-loads socket, it runs.
## Only Trigger: If method loads no socket data, it runs.
## Optional: If method optional-loads socket, it runs.
triggered_event_methods = [
event_method
for event_method in self.filtered_event_methods_by_event(
@ -419,32 +441,52 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
not in event_method.callback_info.must_load_sockets
]
for event_method in triggered_event_methods:
log.critical(
'%s: Running %s',
self.sim_node_name,
str(event_method),
)
event_method(self)
# Update Sockets
for bl_socket in bl_sockets_to_update:
bl_socket_name = bl_socket.name
socket_def = active_socket_defs[bl_socket_name]
# 1. Pretend to Initialize for the First Time
## -> NOTE: The socket's caches will be completely regenerated.
## -> NOTE: A full FlowKind update will occur, but only one.
bl_socket.is_initializing = True
socket_def.preinit(bl_socket)
socket_def.init(bl_socket)
socket_def.postinit(bl_socket)
# 2. Re-Test Socket Capabilities
## -> Factors influencing CapabilitiesFlow may have changed.
## -> Therefore, we must re-test all link capabilities.
bl_socket.remove_invalidated_links()
# 3. Invalidate the input socket cache across all kinds.
## -> Prevents phantom values from remaining available.
self._compute_input.invalidate(
input_socket_name=bl_socket_name,
kind=...,
unit_system=...,
)
def _add_new_active_sockets(self):
"""Add and initialize all "active" sockets that aren't on the node.
Existing sockets within the given direction are not re-created.
"""
for direc in ['input', 'output']:
all_bl_sockets = self._bl_sockets(direc)
active_bl_socket_defs = self.active_socket_defs(direc)
bl_sockets = self._bl_sockets(direc)
active_socket_defs = self.active_socket_defs(direc)
# Define BL Sockets
created_sockets = {}
for socket_name, socket_def in active_bl_socket_defs.items():
for socket_name, socket_def in active_socket_defs.items():
# Skip Existing Sockets
if socket_name in all_bl_sockets:
if socket_name in bl_sockets:
continue
# Create BL Socket from Socket
## Set 'display_shape' from 'socket_shape'
all_bl_sockets.new(
bl_sockets.new(
str(socket_def.socket_type.value),
socket_name,
)
@ -454,9 +496,9 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
# Initialize Just-Created BL Sockets
for bl_socket_name, socket_def in created_sockets.items():
socket_def.preinit(all_bl_sockets[bl_socket_name])
socket_def.init(all_bl_sockets[bl_socket_name])
socket_def.postinit(all_bl_sockets[bl_socket_name])
socket_def.preinit(bl_sockets[bl_socket_name])
socket_def.init(bl_sockets[bl_socket_name])
socket_def.postinit(bl_sockets[bl_socket_name])
# Invalidate Cached NoFlows
self._compute_input.invalidate(
@ -637,9 +679,10 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
lambda a, b: a | b,
[
self._compute_input(
socket, kind=ct.FlowKind.Previews, unit_system=None
socket_name,
kind=ct.FlowKind.Previews,
)
for socket in [bl_socket.name for bl_socket in self.inputs]
for socket_name in [bl_socket.name for bl_socket in self.inputs]
],
ct.PreviewsFlow(),
)
@ -897,9 +940,19 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
)
altered_socket_kinds[dep_out_sckname].add(dep_out_kind)
# Clear Output Socket Cache(s)
## -> We aggregate it manually, so it needs a special invl.
## -> See self.compute_output()
if socket_kinds is not None and ct.FlowKind.Previews in socket_kinds:
for out_sckname in self.outputs.keys(): # noqa: SIM118
self.compute_output.invalidate(
output_socket_name=out_sckname,
kind=ct.FlowKind.Previews,
)
altered_socket_kinds[out_sckname].add(ct.FlowKind.Previews)
# Run Triggered Event Methods
## -> A triggered event method may request to stop propagation.
## -> A triggered event method may request to stop propagation.
stop_propagation = False
triggered_event_methods = self.filtered_event_methods_by_event(
event, (socket_name, prop_names, None)

View File

@ -266,7 +266,7 @@ def event_decorator( # noqa: PLR0913
)
# Loose Sockets
## Compute All Loose Input Sockets
## -> Determined by the active_kind of each loose input socket.
method_kw_args |= (
{
'loose_input_sockets': {

View File

@ -29,6 +29,8 @@ from ... import base, events
class ScientificConstantNode(base.MaxwellSimNode):
"""A well-known constant usable as itself, or as a symbol."""
node_type = ct.NodeType.ScientificConstant
bl_label = 'Scientific Constant'
@ -88,6 +90,11 @@ class ScientificConstantNode(base.MaxwellSimNode):
####################
# - UI
####################
def draw_label(self):
if self.sci_constant_str:
return f'Const: {self.sci_constant_str}'
return self.bl_label
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
col.prop(self, self.blfields['sci_constant_str'], text='')
@ -156,6 +163,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
props={'sci_constant', 'sci_constant_sym'},
)
def compute_lazy_func(self, props) -> typ.Any:
"""Simple `FuncFlow` that computes the symbol value, with output units tracked correctly."""
sci_constant = props['sci_constant']
sci_constant_sym = props['sci_constant_sym']
@ -165,6 +173,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
[sci_constant_sym.sp_symbol], sci_constant_sym.sp_symbol, 'jax'
),
func_args=[sci_constant_sym],
func_output=sci_constant_sym,
supports_jax=True,
)
return ct.FlowSignal.FlowPending
@ -175,6 +184,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
props={'sci_constant_sym'},
)
def compute_info(self, props: dict) -> typ.Any:
"""Simple `FuncFlow` that computes the symbol value, with output units tracked correctly."""
sci_constant_sym = props['sci_constant_sym']
if sci_constant_sym is not None:
@ -193,8 +203,12 @@ class ScientificConstantNode(base.MaxwellSimNode):
if sci_constant is not None and sci_constant_sym is not None:
return ct.ParamsFlow(
arg_targets=[sci_constant_sym],
func_args=[sci_constant],
is_differentiable=True,
func_args=[sci_constant_sym.sp_symbol],
symbols={sci_constant_sym},
).realize_partial(
{
sci_constant_sym: sci_constant,
}
)
return ct.FlowSignal.FlowPending

View File

@ -216,10 +216,11 @@ class SymbolConstantNode(base.MaxwellSimNode):
props={'symbol'},
)
def compute_lazy_func(self, props) -> typ.Any:
sp_sym = props['symbol'].sp_symbol
sym = props['symbol']
return ct.FuncFlow(
func=sp.lambdify(sp_sym, sp_sym, 'jax'),
func_args=[sp_sym],
func=sp.lambdify(sym.sp_symbol_matsym, sym.sp_symbol_matsym, 'jax'),
func_args=[sym],
func_output=sym,
supports_jax=True,
)
@ -235,6 +236,7 @@ class SymbolConstantNode(base.MaxwellSimNode):
)
def compute_info(self, props) -> typ.Any:
return ct.InfoFlow(
dims={props['symbol']: None},
output=props['symbol'],
)
@ -251,9 +253,6 @@ class SymbolConstantNode(base.MaxwellSimNode):
arg_targets=[sym],
func_args=[sym.sp_symbol],
symbols={sym},
is_differentiable=(
sym.mathtype in [spux.MathType.Real, spux.MathType.Complex]
),
)

View File

@ -198,9 +198,10 @@ class DataFileImporterNode(base.MaxwellSimNode):
'Expr',
kind=ct.FlowKind.Func,
# Loaded
props={'output_name', 'output_mathtype', 'output_physical_type', 'output_unit'},
input_sockets={'File Path'},
)
def compute_func(self, input_sockets) -> td.Simulation:
def compute_func(self, props, input_sockets) -> td.Simulation:
"""Declare a lazy, composable function that returns the loaded data.
Returns:
@ -209,6 +210,12 @@ class DataFileImporterNode(base.MaxwellSimNode):
file_path = input_sockets['File Path']
has_file_path = not ct.FlowSignal.check(file_path)
func_output = sim_symbols.SimSymbol(
sym_name=props['output_name'],
mathtype=props['output_mathtype'],
physical_type=props['output_physical_type'],
unit=props['output_unit'],
)
if has_file_path and file_path is not None:
data_file_format = ct.DataFileFormat.from_path(file_path)
if data_file_format is not None:
@ -217,13 +224,18 @@ class DataFileImporterNode(base.MaxwellSimNode):
if data_file_format.loader_is_jax_compatible:
return ct.FuncFlow(
func=lambda: data_file_format.loader(file_path),
func_output=func_output,
supports_jax=True,
)
# No Jax Compatibility: Eager Data Loading
## -> Load the data now and bind it.
data = data_file_format.loader(file_path)
return ct.FuncFlow(func=lambda: data, supports_jax=True)
return ct.FuncFlow(
func=lambda: data,
func_output=func_output,
supports_jax=True,
)
return ct.FlowSignal.FlowPending
return ct.FlowSignal.FlowPending

View File

@ -86,22 +86,45 @@ class ViewerNode(base.MaxwellSimNode):
# - Properties: Computed FlowKinds
####################
@events.on_value_changed(
socket_name='Any',
# Trigger
prop_name='console_print_kind',
# Loaded
props={'auto_expr', 'console_print_kind'},
)
def on_input_changed(self) -> None:
def on_print_kind_changed(self, props) -> None:
self.inputs['Any'].active_kind = props['console_print_kind']
if props['auto_expr']:
setattr(
self,
'input_' + props['console_print_kind'].property_name,
bl_cache.Signal.InvalidateCache,
)
@events.on_value_changed(
# Trugger
socket_name='Any',
# Loaded
props={'auto_expr', 'console_print_kind'},
)
def on_input_changed(self, props) -> None:
"""Lightweight invalidator, which invalidates the more specific `cached_bl_property` used to determine when something ex. plot-related has changed.
Calls `get_flow`, which will be called again when regenerating the `cached_bl_property`s.
This **does not** call the flow twice, as `self._compute_input()` will be cached the first time.
"""
for flow_kind in list(ct.FlowKind):
flow = self.get_flow(
flow_kind, always_load=flow_kind is ct.FlowKind.Previews
)
if flow is not None:
# Invalidate PreviewsFlow
setattr(
self,
'input_' + flow_kind.property_name,
'input_' + ct.FlowKind.Previews.property_name,
bl_cache.Signal.InvalidateCache,
)
# Invalidate PreviewsFlow
if props['auto_expr']:
setattr(
self,
'input_' + props['console_print_kind'].property_name,
bl_cache.Signal.InvalidateCache,
)

View File

@ -14,8 +14,10 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import functools
import typing as typ
import bpy
import sympy as sp
from blender_maxwell.utils import bl_cache
@ -26,6 +28,8 @@ from .. import base, events
class CombineNode(base.MaxwellSimNode):
"""Combine single objects (ex. Source, Monitor, Structure) into a list."""
node_type = ct.NodeType.Combine
bl_label = 'Combine'
@ -33,112 +37,222 @@ class CombineNode(base.MaxwellSimNode):
# - Sockets
####################
input_socket_sets: typ.ClassVar = {
'Maxwell Sources': {},
'Maxwell Structures': {},
'Maxwell Monitors': {},
'Sources': {},
'Structures': {},
'Monitors': {},
}
output_socket_sets: typ.ClassVar = {
'Maxwell Sources': {
'Sources': {
'Sources': sockets.MaxwellSourceSocketDef(
is_list=True,
active_kind=ct.FlowKind.Array,
),
},
'Maxwell Structures': {
'Structures': {
'Structures': sockets.MaxwellStructureSocketDef(
is_list=True,
active_kind=ct.FlowKind.Array,
),
},
'Maxwell Monitors': {
'Monitors': {
'Monitors': sockets.MaxwellMonitorSocketDef(
is_list=True,
active_kind=ct.FlowKind.Array,
),
},
}
####################
# - Draw
# - Properties
####################
amount: int = bl_cache.BLField(2, abs_min=1, prop_ui=True)
concatenate_first: bool = bl_cache.BLField(False)
value_or_func: ct.FlowKind = bl_cache.BLField(
enum_cb=lambda self, _: self._value_or_func(),
)
def _value_or_func(self):
return [
flow_kind.bl_enum_element(i)
for i, flow_kind in enumerate([ct.FlowKind.Value, ct.FlowKind.Func])
]
####################
# - Draw
####################
def draw_props(self, context, layout):
layout.prop(self, self.blfields['amount'], text='')
def draw_props(self, _, layout: bpy.types.UILayout):
layout.prop(self, self.blfields['value_or_func'], text='')
if self.value_or_func is ct.FlowKind.Value:
layout.prop(
self,
self.blfields['concatenate_first'],
text='Concatenate',
toggle=True,
)
####################
# - Events
####################
@events.on_value_changed(
# Trigger
prop_name={'active_socket_set', 'amount'},
props={'active_socket_set', 'amount'},
any_loose_input_socket=True,
prop_name={'active_socket_set', 'concatenate_first', 'value_or_func'},
run_on_init=True,
# Loaded
props={'active_socket_set', 'concatenate_first', 'value_or_func'},
)
def on_inputs_changed(self, props):
if props['active_socket_set'] == 'Maxwell Sources':
if (
not self.loose_input_sockets
or not next(iter(self.loose_input_sockets)).startswith('Source')
or len(self.loose_input_sockets) != props['amount']
):
self.loose_input_sockets = {
f'Source #{i}': sockets.MaxwellSourceSocketDef()
for i in range(props['amount'])
}
def on_inputs_changed(self, props) -> None:
"""Always create one extra loose input socket."""
active_socket_set = props['active_socket_set']
# Deduce SocketDef
## -> Cheat by retrieving the class from the output sockets.
SocketDef = self.output_socket_sets[active_socket_set][
active_socket_set
].__class__
# Deduce Current "Filled"
## -> The first linked socket from the end bounds the "filled" region.
## -> The length of that region, plus one, will be the new amount.
reverse_linked_idxs = [
i
for i, bl_socket in enumerate(reversed(self.inputs.values()))
if bl_socket.is_linked
]
current_filled = len(self.inputs) - (
reverse_linked_idxs[0] if reverse_linked_idxs else len(self.inputs)
)
new_amount = current_filled + 1
# Deduce SocketDef | Current Amount
concatenate_first = props['concatenate_first']
flow_kind = props['value_or_func']
elif props['active_socket_set'] == 'Maxwell Structures':
if (
not self.loose_input_sockets
or not next(iter(self.loose_input_sockets)).startswith('Structure')
or len(self.loose_input_sockets) != props['amount']
):
self.loose_input_sockets = {
f'Structure #{i}': sockets.MaxwellStructureSocketDef()
for i in range(props['amount'])
}
elif props['active_socket_set'] == 'Maxwell Monitors':
if (
not self.loose_input_sockets
or not next(iter(self.loose_input_sockets)).startswith('Monitor')
or len(self.loose_input_sockets) != props['amount']
):
self.loose_input_sockets = {
f'Monitor #{i}': sockets.MaxwellMonitorSocketDef()
for i in range(props['amount'])
}
elif self.loose_input_sockets:
self.loose_input_sockets = {}
'#0': SocketDef(
active_kind=flow_kind
if flow_kind is ct.FlowKind.Func or not concatenate_first
else ct.FlowKind.Array
)
} | {f'#{i}': SocketDef(active_kind=flow_kind) for i in range(1, new_amount)}
####################
# - Output Socket Computation
# - FlowKind.Array|Func
####################
def compute_combined(
self,
loose_input_sockets,
input_flow_kind: typ.Literal[ct.FlowKind.Value, ct.FlowKind.Func],
output_flow_kind: typ.Literal[ct.FlowKind.Array, ct.FlowKind.Func],
) -> list[typ.Any] | ct.FuncFlow | ct.FlowSignal:
"""Correctly compute the combined loose input sockets, given a valid combination of input and output `FlowKind`s.
If there is no output, or the flows aren't compatible, return `FlowPending`.
"""
match (input_flow_kind, output_flow_kind):
case (ct.FlowKind.Value, ct.FlowKind.Array):
value_flows = [
inp
for inp in loose_input_sockets.values()
if not ct.FlowSignal.check(inp)
]
if value_flows:
return value_flows
return ct.FlowSignal.FlowPending
case (ct.FlowKind.Func, ct.FlowKind.Func):
func_flows = [
inp
for inp in loose_input_sockets.values()
if not ct.FlowSignal.check(inp)
]
if func_flows:
return functools.reduce(
lambda a, b: a | b,
func_flows,
)
return ct.FlowSignal.FlowPending
return ct.FlowSignal.FlowPending
####################
# - Output: Sources
####################
@events.computes_output_socket(
'Sources',
kind=ct.FlowKind.Array,
all_loose_input_sockets=True,
props={'amount'},
props={'value_or_func'},
)
def compute_sources_array(
self, props, loose_input_sockets
) -> list[typ.Any] | ct.FlowSignal:
"""Compute sources."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
)
def compute_sources(self, loose_input_sockets, props) -> sp.Expr:
return [loose_input_sockets[f'Source #{i}'] for i in range(props['amount'])]
@events.computes_output_socket(
'Sources',
kind=ct.FlowKind.Func,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_sources_func(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) sources."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
)
####################
# - Output: Structures
####################
@events.computes_output_socket(
'Structures',
kind=ct.FlowKind.Array,
all_loose_input_sockets=True,
props={'amount'},
props={'value_or_func'},
)
def compute_structures_array(self, props, loose_input_sockets) -> sp.Expr:
"""Compute structures."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
)
def compute_structures(self, loose_input_sockets, props) -> sp.Expr:
return [loose_input_sockets[f'Structure #{i}'] for i in range(props['amount'])]
@events.computes_output_socket(
'Structures',
kind=ct.FlowKind.Func,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_structures_func(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) structures."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
)
####################
# - Output: Monitors
####################
@events.computes_output_socket(
'Monitors',
kind=ct.FlowKind.Array,
all_loose_input_sockets=True,
props={'amount'},
props={'value_or_func'},
)
def compute_monitors_array(self, props, loose_input_sockets) -> sp.Expr:
"""Compute monitors."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
)
@events.computes_output_socket(
'Monitors',
kind=ct.FlowKind.Func,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_monitors_func(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) monitors."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
)
def compute_monitors(self, loose_input_sockets, props) -> sp.Expr:
return [loose_input_sockets[f'Monitor #{i}'] for i in range(props['amount'])]
####################

View File

@ -14,17 +14,26 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `FDTDSimNode`."""
import typing as typ
import sympy as sp
import bpy
import tidy3d as td
import tidy3d.plugins.adjoint as tdadj
from blender_maxwell.utils import bl_cache, logger
from ... import contracts as ct
from ... import sockets
from .. import base, events
log = logger.get(__name__)
class FDTDSimNode(base.MaxwellSimNode):
"""Definition of a complete FDTD simulation, including boundary conditions, domain, sources, structures, monitors, and other configuration."""
node_type = ct.NodeType.FDTDSim
bl_label = 'FDTD Simulation'
@ -35,51 +44,255 @@ class FDTDSimNode(base.MaxwellSimNode):
'BCs': sockets.MaxwellBoundCondsSocketDef(),
'Domain': sockets.MaxwellSimDomainSocketDef(),
'Sources': sockets.MaxwellSourceSocketDef(
is_list=True,
active_kind=ct.FlowKind.Array,
),
'Structures': sockets.MaxwellStructureSocketDef(
is_list=True,
active_kind=ct.FlowKind.Array,
),
'Monitors': sockets.MaxwellMonitorSocketDef(
is_list=True,
active_kind=ct.FlowKind.Array,
),
}
output_sockets: typ.ClassVar = {
'Sim': sockets.MaxwellFDTDSimSocketDef(),
output_socket_sets: typ.ClassVar = {
'Single': {
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Value),
},
'Batch': {
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Array),
},
'Lazy': {
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Func),
},
}
####################
# - Output Socket Computation
# - Properties
####################
differentiable: bool = bl_cache.BLField(False)
####################
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
layout.prop(
self,
self.blfields['differentiable'],
text='Differentiable',
toggle=True,
)
####################
# - Events
####################
@events.on_value_changed(
# Trigger
socket_name={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
run_on_init=True,
# Loaded
props={'active_socket_set'},
output_sockets={'Sim'},
output_socket_kinds={'Sim': ct.FlowKind.Params},
)
def on_any_changed(self, props, output_sockets) -> None:
"""Create loose input sockets."""
params = output_sockets['Sim']
has_params = not ct.FlowSignal.check(params)
# Declare Loose Sockets that Realize Symbols
## -> This happens if Params contains not-yet-realized symbols.
active_socket_set = props['active_socket_set']
if active_socket_set in ['Value', 'Batch'] and has_params and params.symbols:
if set(self.loose_input_sockets) != {sym.name for sym in params.symbols}:
self.loose_input_sockets = {
sym.name: sockets.ExprSocketDef(
**(
expr_info
| {
'active_kind': ct.FlowKind.Value,
'use_value_range_swapper': (
active_socket_set == 'Value'
),
}
)
)
for sym, expr_info in params.sym_expr_infos.items()
}
elif self.loose_input_sockets:
self.loose_input_sockets = {}
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Sim',
kind=ct.FlowKind.Value,
# Loaded
props={'differentiable'},
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
input_socket_kinds={
'Sources': ct.FlowKind.Array,
'Structures': ct.FlowKind.Array,
'Domain': ct.FlowKind.Value,
'BCs': ct.FlowKind.Value,
'Monitors': ct.FlowKind.Array,
},
output_sockets={'Sim'},
output_socket_kinds={'Sim': ct.FlowKind.Params},
)
def compute_fdtd_sim(self, input_sockets: dict) -> sp.Expr:
if any(ct.FlowSignal.check(inp) for inp in input_sockets):
return ct.FlowSignal.FlowPending
def compute_fdtd_sim_value(
self, props, input_sockets, output_sockets
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
"""Compute a single FDTD simulation definition, so long as the inputs are neither symbolic or differentiable."""
sim_domain = input_sockets['Domain']
sources = input_sockets['Sources']
structures = input_sockets['Structures']
bounds = input_sockets['BCs']
monitors = input_sockets['Monitors']
output_params = output_sockets['Sim']
has_sim_domain = not ct.FlowSignal.check(sim_domain)
has_sources = not ct.FlowSignal.check(sources)
has_structures = not ct.FlowSignal.check(structures)
has_bounds = not ct.FlowSignal.check(bounds)
has_monitors = not ct.FlowSignal.check(monitors)
has_output_params = not ct.FlowSignal.check(output_params)
differentiable = props['differentiable']
if (
has_sim_domain
and has_sources
and has_structures
and has_bounds
and has_monitors
and has_output_params
and not differentiable
):
return td.Simulation(
**sim_domain,
structures=structures,
sources=sources,
monitors=monitors,
structures=structures,
boundary_spec=bounds,
monitors=monitors,
)
## TODO: Visualize the boundary conditions on top of the sim domain
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Sim',
kind=ct.FlowKind.Func,
# Loaded
props={'differentiable'},
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
input_socket_kinds={
'Sources': ct.FlowKind.Func,
'Structures': ct.FlowKind.Func,
'Monitors': ct.FlowKind.Func,
},
output_sockets={'Sim'},
output_socket_kinds={'Sim': ct.FlowKind.Params},
)
def compute_fdtd_sim_func(
self, props, input_sockets, output_sockets
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
"""Compute a single simulation, given that all inputs are non-symbolic."""
sim_domain = input_sockets['Domain']
sources = input_sockets['Sources']
structures = input_sockets['Structures']
bounds = input_sockets['BCs']
monitors = input_sockets['Monitors']
output_params = output_sockets['Sim']
has_sim_domain = not ct.FlowSignal.check(sim_domain)
has_sources = not ct.FlowSignal.check(sources)
has_structures = not ct.FlowSignal.check(structures)
has_bounds = not ct.FlowSignal.check(bounds)
has_monitors = not ct.FlowSignal.check(monitors)
has_output_params = not ct.FlowSignal.check(output_params)
if (
has_sim_domain
and has_sources
and has_structures
and has_bounds
and has_monitors
and has_output_params
):
differentiable = props['differentiable']
if differentiable:
return (
sim_domain | sources | structures | bounds | monitors
).compose_within(
enclosing_func=lambda els: tdadj.JaxSimulation(
**els[0],
sources=els[1],
structures=els[2]['static'],
input_structures=els[2]['differentiable'],
boundary_spec=els[3],
monitors=els[4]['static'],
output_monitors=els[4]['differentiable'],
),
supports_jax=True,
)
return (
sim_domain | sources | structures | bounds | monitors
).compose_within(
enclosing_func=lambda els: td.Simulation(
**els[0],
sources=els[1],
structures=els[2],
boundary_spec=els[3],
monitors=els[4],
),
supports_jax=False,
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Sim',
kind=ct.FlowKind.Params,
# Loaded
props={'differentiable'},
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
input_socket_kinds={
'Sources': ct.FlowKind.Params,
'Structures': ct.FlowKind.Params,
'Monitors': ct.FlowKind.Params,
},
)
def compute_fdtd_sim_params(
self, props, input_sockets
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
"""Compute a single simulation, given that all inputs are non-symbolic."""
sim_domain = input_sockets['Domain']
sources = input_sockets['Sources']
structures = input_sockets['Structures']
bounds = input_sockets['BCs']
monitors = input_sockets['Monitors']
has_sim_domain = not ct.FlowSignal.check(sim_domain)
has_sources = not ct.FlowSignal.check(sources)
has_structures = not ct.FlowSignal.check(structures)
has_bounds = not ct.FlowSignal.check(bounds)
has_monitors = not ct.FlowSignal.check(monitors)
if (
has_sim_domain
and has_sources
and has_structures
and has_bounds
and has_monitors
):
# Determine Differentiable Match
## -> 'structures' is diff when **any** are diff.
## -> 'monitors' is also diff when **any** are diff.
## -> Only parameters through diff structs can be diff'ed by.
## -> Similarly, only diff monitors will have gradients computed.
return sim_domain | sources | structures | bounds | monitors
return ct.FlowSignal.FlowPending
####################

View File

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `SimDomainNode`."""
import typing as typ
import sympy as sp
@ -31,6 +33,8 @@ log = logger.get(__name__)
class SimDomainNode(base.MaxwellSimNode):
"""The domain of a simulation in space and time, including bounds, discretization strategy, and the ambient medium."""
node_type = ct.NodeType.SimDomain
bl_label = 'Sim Domain'
use_sim_node_name = True
@ -69,26 +73,109 @@ class SimDomainNode(base.MaxwellSimNode):
}
####################
# - Outputs
# - FlowKind.Value
####################
@events.computes_output_socket(
'Domain',
kind=ct.FlowKind.Value,
# Loaded
output_sockets={'Domain'},
output_socket_kinds={'Domain': {ct.FlowKind.Func, ct.FlowKind.Params}},
)
def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
output_func = output_sockets['Domain'][ct.FlowKind.Func]
output_params = output_sockets['Domain'][ct.FlowKind.Params]
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Domain',
kind=ct.FlowKind.Func,
# Loaded
input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'},
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
scale_input_sockets={
'Duration': 'Tidy3DUnits',
'Center': 'Tidy3DUnits',
'Size': 'Tidy3DUnits',
input_socket_kinds={
'Duration': ct.FlowKind.Func,
'Center': ct.FlowKind.Func,
'Size': ct.FlowKind.Func,
'Grid': ct.FlowKind.Func,
'Ambient Medium': ct.FlowKind.Func,
},
)
def compute_domain(self, input_sockets, unit_systems) -> sp.Expr:
return {
'run_time': input_sockets['Duration'],
'center': input_sockets['Center'],
'size': input_sockets['Size'],
'grid_spec': input_sockets['Grid'],
'medium': input_sockets['Ambient Medium'],
}
def compute_domain_func(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
duration = input_sockets['Duration']
center = input_sockets['Center']
size = input_sockets['Size']
grid = input_sockets['Grid']
medium = input_sockets['Ambient Medium']
has_duration = not ct.FlowSignal.check(duration)
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_grid = not ct.FlowSignal.check(grid)
has_medium = not ct.FlowSignal.check(medium)
if has_duration and has_center and has_size and has_grid and has_medium:
return (
duration.scale_to_unit_system(ct.UNITS_TIDY3D)
| center.scale_to_unit_system(ct.UNITS_TIDY3D)
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
| grid
| medium
).compose_within(
enclosing_func=lambda els: {
'run_time': els[0],
'center': tuple(els[1].flatten()),
'size': tuple(els[2].flatten()),
'grid_spec': els[3],
'medium': els[4],
},
supports_jax=False,
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Domain',
kind=ct.FlowKind.Params,
# Loaded
input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'},
input_socket_kinds={
'Duration': ct.FlowKind.Params,
'Center': ct.FlowKind.Params,
'Size': ct.FlowKind.Params,
'Grid': ct.FlowKind.Params,
'Ambient Medium': ct.FlowKind.Params,
},
)
def compute_domain_params(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
"""Compute the output `ParamsFlow` of the simulation domain from strictly non-symbolic inputs."""
duration = input_sockets['Duration']
center = input_sockets['Center']
size = input_sockets['Size']
grid = input_sockets['Grid']
medium = input_sockets['Ambient Medium']
has_duration = not ct.FlowSignal.check(duration)
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_grid = not ct.FlowSignal.check(grid)
has_medium = not ct.FlowSignal.check(medium)
if has_duration and has_center and has_size and has_grid and has_medium:
return duration | center | size | grid | medium
return ct.FlowSignal.FlowPending
####################
# - Preview
@ -100,37 +187,39 @@ class SimDomainNode(base.MaxwellSimNode):
props={'sim_node_name'},
)
def compute_previews(self, props):
"""Mark the managed preview object for preview when `Domain` is linked to a viewer."""
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
## Trigger
# Trigger
socket_name={'Center', 'Size'},
run_on_init=True,
# Loaded
input_sockets={'Center', 'Size'},
managed_objs={'modifier'},
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
scale_input_sockets={
'Center': 'BlenderUnits',
},
output_sockets={'Domain'},
output_socket_kinds={'Domain': ct.FlowKind.Params},
)
def on_input_changed(
self,
managed_objs,
input_sockets,
unit_systems,
):
def on_input_changed(self, managed_objs, input_sockets, output_sockets) -> None:
"""Preview the simulation domain based on input parameters, so long as they are not dependent on unrealized symbols."""
output_params = output_sockets['Domain']
center = input_sockets['Center']
has_output_params = not ct.FlowSignal.check(output_params)
has_center = not ct.FlowSignal.check(center)
if has_center and has_output_params and not output_params.symbols:
# Push Loose Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.SimulationSimDomain),
'unit_system': unit_systems['BlenderUnits'],
'unit_system': ct.UNITS_BLENDER,
'inputs': {
'Size': input_sockets['Size'],
},
},
location=input_sockets['Center'],
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
)

View File

@ -71,35 +71,129 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
layout.prop(self, self.blfields['pol_axis'], expand=True)
####################
# - Outputs
# - FlowKind.Value
####################
@events.computes_output_socket(
'Source',
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
# Loaded
props={'pol_axis'},
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
scale_input_sockets={
'Center': 'Tidy3DUnits',
},
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
output_sockets={'Source'},
output_socket_kinds={'Source': ct.FlowKind.Params},
)
def compute_source(
self,
input_sockets: dict[str, typ.Any],
props: dict[str, typ.Any],
unit_systems: dict,
) -> td.PointDipole:
def compute_source_value(
self, input_sockets, props, output_sockets
) -> td.PointDipole | ct.FlowSignal:
"""Compute the point dipole source, given that all inputs are non-symbolic."""
temporal_shape = input_sockets['Temporal Shape']
center = input_sockets['Center']
interpolate = input_sockets['Interpolate']
output_params = output_sockets['Source']
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
has_center = not ct.FlowSignal.check(center)
has_interpolate = not ct.FlowSignal.check(interpolate)
has_output_params = not ct.FlowSignal.check(output_params)
if (
has_temporal_shape
and has_center
and has_interpolate
and has_output_params
and not output_params.symbols
):
pol_axis = {
ct.SimSpaceAxis.X: 'Ex',
ct.SimSpaceAxis.Y: 'Ey',
ct.SimSpaceAxis.Z: 'Ez',
}[props['pol_axis']]
## TODO: Need Hx, Hy, Hz too?
return td.PointDipole(
center=input_sockets['Center'],
source_time=input_sockets['Temporal Shape'],
interpolate=input_sockets['Interpolate'],
center=spux.convert_to_unit_system(center, ct.UNITS_TIDY3D),
source_time=temporal_shape,
interpolate=interpolate,
polarization=pol_axis,
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Source',
kind=ct.FlowKind.Func,
# Loaded
props={'pol_axis'},
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
input_socket_kinds={
'Temporal Shape': ct.FlowKind.Func,
'Center': ct.FlowKind.Func,
'Interpolate': ct.FlowKind.Func,
},
output_sockets={'Source'},
output_socket_kinds={'Source': ct.FlowKind.Params},
)
def compute_source_func(self, props, input_sockets, output_sockets) -> td.Box:
"""Compute a lazy function for the point dipole source."""
center = input_sockets['Center']
temporal_shape = input_sockets['Temporal Shape']
interpolate = input_sockets['Interpolate']
output_params = output_sockets['Source']
has_center = not ct.FlowSignal.check(center)
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
has_interpolate = not ct.FlowSignal.check(interpolate)
has_output_params = not ct.FlowSignal.check(output_params)
if has_temporal_shape and has_center and has_interpolate and has_output_params:
pol_axis = {
ct.SimSpaceAxis.X: 'Ex',
ct.SimSpaceAxis.Y: 'Ey',
ct.SimSpaceAxis.Z: 'Ez',
}[props['pol_axis']]
## TODO: Need Hx, Hy, Hz too?
return (center | temporal_shape | interpolate).compose_within(
enclosing_func=lambda els: td.PointDipole(
center=els[0],
source_time=els[1],
interpolate=els[2],
polarization=pol_axis,
)
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Source',
kind=ct.FlowKind.Params,
# Loaded
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
input_socket_kinds={
'Temporal Shape': ct.FlowKind.Params,
'Center': ct.FlowKind.Params,
'Interpolate': ct.FlowKind.Params,
},
)
def compute_params(
self,
input_sockets,
) -> td.PointDipole | ct.FlowSignal:
"""Compute the point dipole source, given that all inputs are non-symbolic."""
temporal_shape = input_sockets['Temporal Shape']
center = input_sockets['Center']
interpolate = input_sockets['Interpolate']
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
has_center = not ct.FlowSignal.check(center)
has_interpolate = not ct.FlowSignal.check(interpolate)
if has_temporal_shape and has_center and has_interpolate:
return temporal_shape | center | interpolate
return ct.FlowSignal.FlowPending
####################
# - Preview

View File

@ -16,15 +16,19 @@
"""Implements the `TemporalShapeNode`."""
import enum
import typing as typ
import bpy
import numpy as np
import sympy as sp
import sympy.physics.units as spu
import tidy3d as td
from tidy3d.components.data.data_array import TimeDataArray as td_TimeDataArray
from tidy3d.components.data.dataset import TimeDataset as td_TimeDataset
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger, sim_symbols
from ... import contracts as ct
from ... import managed_objs, sockets
@ -33,14 +37,10 @@ from .. import base, events
log = logger.get(__name__)
_max_e_socket_def = sockets.ExprSocketDef(
mathtype=spux.MathType.Complex,
physical_type=spux.PhysicalType.EField,
default_value=1 + 0j,
)
_offset_socket_def = sockets.ExprSocketDef(default_value=5, abs_min=2.5)
t_ps = sim_symbols.t(spu.picosecond)
# Select Default Time Unit for Envelope
## -> Chosen to align with the default envelope_time_unit.
## -> This causes it to be correct from the start.
t_def = sim_symbols.t(spux.PhysicalType.Time.valid_units[0])
class TemporalShapeNode(base.MaxwellSimNode):
@ -63,17 +63,18 @@ class TemporalShapeNode(base.MaxwellSimNode):
default_unit=spux.THz,
default_value=200,
),
'max E': sockets.ExprSocketDef(
mathtype=spux.MathType.Complex,
physical_type=spux.PhysicalType.EField,
default_value=1 + 0j,
),
'Offset Time': sockets.ExprSocketDef(default_value=5, abs_min=2.5),
}
input_socket_sets: typ.ClassVar = {
'Pulse': {
'max E': _max_e_socket_def,
'Offset Time': _offset_socket_def,
'Remove DC': sockets.BoolSocketDef(default_value=True),
},
'Constant': {
'max E': _max_e_socket_def,
'Offset Time': _offset_socket_def,
},
'Constant': {},
'Symbolic': {
't Range': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Range,
@ -84,8 +85,8 @@ class TemporalShapeNode(base.MaxwellSimNode):
default_steps=100,
),
'Envelope': sockets.ExprSocketDef(
default_symbols=[t_ps],
default_value=10 * t_ps.sp_symbol,
default_symbols=[t_def],
default_value=10 * t_def.sp_symbol,
),
},
}
@ -98,6 +99,55 @@ class TemporalShapeNode(base.MaxwellSimNode):
'plot': managed_objs.ManagedBLImage,
}
####################
# - Properties
####################
active_envelope_time_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_time_units(),
)
def search_time_units(self) -> list[ct.BLEnumElement]:
"""Compute all valid time units."""
return [
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
for i, unit in enumerate(spux.PhysicalType.Time.valid_units)
]
@bl_cache.cached_bl_property(depends_on={'active_envelope_time_unit'})
def envelope_time_unit(self) -> spux.Unit | None:
"""Gets the current active unit for the envelope time symbol.
Returns:
The current active `sympy` unit.
If the socket expression is unitless, this returns `None`.
"""
if self.active_envelope_time_unit is not None:
return spux.unit_str_to_unit(self.active_envelope_time_unit)
return None
####################
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
if (
self.active_socket_set == 'Symbolic'
and self.inputs.get('Envelope')
and not self.inputs['Envelope'].is_linked
):
row = layout.row()
row.alignment = 'CENTER'
row.label(text='Envelope Time Unit')
row = layout.row()
row.prop(
self,
self.blfields['active_envelope_time_unit'],
text='',
toggle=True,
)
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
if self.active_socket_set != 'Symbolic':
box = layout.box()
@ -118,10 +168,53 @@ class TemporalShapeNode(base.MaxwellSimNode):
col.label(text='1 / 2π·σ(𝑓)')
####################
# - FlowKind: Value
# - Events
####################
@events.on_value_changed(
# Trigger
prop_name={'active_socket_set', 'envelope_time_unit'},
# Loaded
props={'active_socket_set', 'envelope_time_unit'},
)
def on_envelope_time_unit_changed(self, props) -> None:
"""Ensure the envelope expression's time symbol has the time unit defined by the node."""
active_socket_set = props['active_socket_set']
envelope_time_unit = props['envelope_time_unit']
if active_socket_set == 'Symbolic':
bl_socket = self.inputs['Envelope']
wanted_t_sym = sim_symbols.t(envelope_time_unit)
if not bl_socket.symbols or bl_socket.symbols[0] != wanted_t_sym:
bl_socket.symbols = [wanted_t_sym]
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Temporal Shape',
kind=ct.FlowKind.Value,
# Loaded
output_sockets={'Temporal Shape'},
output_socket_kinds={'Temporal Shape': {ct.FlowKind.Func, ct.FlowKind.Params}},
)
def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
"""Compute a single temporal shape."""
output_func = output_sockets['Temporal Shape'][ct.FlowKind.Func]
output_params = output_sockets['Temporal Shape'][ct.FlowKind.Params]
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params)
return ct.FlowSignal.FlowPending
####################
# - FlowKind: Func
####################
@events.computes_output_socket(
'Temporal Shape',
kind=ct.FlowKind.Func,
# Loaded
props={'active_socket_set'},
input_sockets={
@ -134,60 +227,178 @@ class TemporalShapeNode(base.MaxwellSimNode):
'Envelope',
},
input_socket_kinds={
't Range': ct.FlowKind.Range,
'Envelope': ct.FlowKind.Func,
},
input_sockets_optional={
'max E': True,
'Offset Time': True,
'Remove DC': True,
't Range': True,
'Envelope': True,
},
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
scale_input_sockets={
'max E': 'Tidy3DUnits',
'μ Freq': 'Tidy3DUnits',
'σ Freq': 'Tidy3DUnits',
't Range': 'Tidy3DUnits',
'Offset Time': 'Tidy3DUnits',
'max E': ct.FlowKind.Func,
'μ Freq': ct.FlowKind.Func,
'σ Freq': ct.FlowKind.Func,
'Offset Time': ct.FlowKind.Func,
'Remove DC': ct.FlowKind.Value,
't Range': ct.FlowKind.Func,
'Envelope': {ct.FlowKind.Func, ct.FlowKind.Params},
},
)
def compute_temporal_shape(
self, props, input_sockets, unit_systems
def compute_temporal_shape_func(
self,
props,
input_sockets,
) -> td.GaussianPulse:
"""Compute a single temporal shape from non-parameterized inputs."""
mean_freq = input_sockets['μ Freq']
std_freq = input_sockets['σ Freq']
max_e = input_sockets['max E']
offset = input_sockets['Offset Time']
has_mean_freq = not ct.FlowSignal.check(mean_freq)
has_std_freq = not ct.FlowSignal.check(std_freq)
has_max_e = not ct.FlowSignal.check(max_e)
has_offset = not ct.FlowSignal.check(offset)
if has_mean_freq and has_std_freq and has_max_e and has_offset:
common_func = (
max_e.scale_to_unit_system(ct.UNITS_TIDY3D)
| mean_freq.scale_to_unit_system(ct.UNITS_TIDY3D)
| std_freq.scale_to_unit_system(ct.UNITS_TIDY3D)
| offset ## Already unitless
)
match props['active_socket_set']:
case 'Pulse':
return td.GaussianPulse(
amplitude=sp.re(input_sockets['max E']),
phase=sp.im(input_sockets['max E']),
freq0=input_sockets['μ Freq'],
fwidth=input_sockets['σ Freq'],
offset=input_sockets['Offset Time'],
remove_dc_component=input_sockets['Remove DC'],
remove_dc = input_sockets['Remove DC']
has_remove_dc = not ct.FlowSignal.check(remove_dc)
if has_remove_dc:
return common_func.compose_within(
lambda els: td.GaussianPulse(
amplitude=complex(els[0]).real,
phase=complex(els[0]).imag,
freq0=els[1],
fwidth=els[2],
offset=els[3],
remove_dc_component=remove_dc,
),
)
case 'Constant':
return td.ContinuousWave(
amplitude=sp.re(input_sockets['max E']),
phase=sp.im(input_sockets['max E']),
freq0=input_sockets['μ Freq'],
fwidth=input_sockets['σ Freq'],
offset=input_sockets['Offset Time'],
return common_func.compose_within(
lambda els: td.GaussianPulse(
amplitude=complex(els[0]).real,
phase=complex(els[0]).imag,
freq0=els[1],
fwidth=els[2],
offset=els[3],
),
)
case 'Symbolic':
lzrange = input_sockets['t Range']
envelope_ps = input_sockets['Envelope'].func_jax
t_range = input_sockets['t Range']
envelope = input_sockets['Envelope'][ct.FlowKind.Func]
envelope_params = input_sockets['Envelope'][ct.FlowKind.Params]
return td.CustomSourceTime.from_values(
freq0=input_sockets['μ Freq'],
fwidth=input_sockets['σ Freq'],
values=envelope_ps(
lzrange.rescale_to_unit(spu.ps).realize_array.values
),
dt=input_sockets['t Range'].realize_step_size(),
has_t_range = not ct.FlowSignal.check(t_range)
has_envelope = not ct.FlowSignal.check(envelope)
has_envelope_params = not ct.FlowSignal.check(envelope_params)
if (
has_t_range
and has_envelope
and has_envelope_params
and len(envelope_params.symbols) == 1
## TODO: Allow unrealized envelope symbols
and any(
sym.physical_type is spux.PhysicalType.Time
for sym in envelope_params.symbols
)
):
envelope_time_unit = next(
sym.unit
for sym in envelope_params.symbols
if sym.physical_type is spux.PhysicalType.Time
)
# Deduce Partially Realized Envelope Function
## -> We need a pure-numerical function w/pre-realized stuff baked in.
## -> 'realize_partial' does this for us.
envelope_realizer = envelope.realize_partial(envelope_params)
# Compose w/Envelope Function
## -> First, the numerical time values must be converted.
## -> This ensures that the raw array is compatible w/the envelope.
## -> Then, we can compose w/the purely numerical 'envelope_realizer'.
## -> Because of the checks, we've guaranteed that all this is correct.
return (
common_func ## 1 | freq0, 2 | fwidth, 3 | offset
| t_range.scale_to_unit_system(ct.UNITS_TIDY3D) ## 4
| t_range.scale_to_unit(envelope_time_unit).compose_within(
lambda t: envelope_realizer(t)
) ## 5
).compose_within(
lambda els: td.CustomSourceTime(
amplitude=complex(els[0]).real,
phase=complex(els[0]).imag,
freq0=els[1],
fwidth=els[2],
offset=els[3],
source_time_dataset=td_TimeDataset(
values=td_TimeDataArray(
els[5], coords={'t': np.array(els[4])}
)
),
)
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind: Params
####################
@events.computes_output_socket(
'Temporal Shape',
kind=ct.FlowKind.Params,
# Loaded
props={'active_socket_set', 'envelope_time_unit'},
input_sockets={
'max E',
'μ Freq',
'σ Freq',
'Offset Time',
't Range',
},
input_socket_kinds={
'max E': ct.FlowKind.Params,
'μ Freq': ct.FlowKind.Params,
'σ Freq': ct.FlowKind.Params,
'Offset Time': ct.FlowKind.Params,
't Range': ct.FlowKind.Params,
},
)
def compute_temporal_shape_params(
self,
props,
input_sockets,
) -> td.GaussianPulse:
"""Compute a single temporal shape from non-parameterized inputs."""
mean_freq = input_sockets['μ Freq']
std_freq = input_sockets['σ Freq']
max_e = input_sockets['max E']
offset = input_sockets['Offset Time']
has_mean_freq = not ct.FlowSignal.check(mean_freq)
has_std_freq = not ct.FlowSignal.check(std_freq)
has_max_e = not ct.FlowSignal.check(max_e)
has_offset = not ct.FlowSignal.check(offset)
if has_mean_freq and has_std_freq and has_max_e and has_offset:
common_params = max_e | mean_freq | std_freq | offset
match props['active_socket_set']:
case 'Pulse' | 'Constant':
return common_params
case 'Symbolic':
t_range = input_sockets['t Range']
has_t_range = not ct.FlowSignal.check(t_range)
if has_t_range:
return common_params | t_range | t_range
return ct.FlowSignal.FlowPending
####################

View File

@ -88,28 +88,27 @@ class BoxStructureNode(base.MaxwellSimNode):
'Structure',
kind=ct.FlowKind.Value,
# Loaded
props={'differentiable'},
input_sockets={'Medium', 'Center', 'Size'},
output_sockets={'Structure'},
output_socket_kinds={'Structure': ct.FlowKind.Params},
)
def compute_value(self, props, input_sockets, output_sockets) -> td.Box:
output_params = output_sockets['Structure']
def compute_value(self, input_sockets, output_sockets) -> td.Box:
"""Compute a single box structure object, given that all inputs are non-symbolic."""
center = input_sockets['Center']
size = input_sockets['Size']
medium = input_sockets['Medium']
output_params = output_sockets['Structure']
has_output_params = not ct.FlowSignal.check(output_params)
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_medium = not ct.FlowSignal.check(medium)
has_output_params = not ct.FlowSignal.check(output_params)
if (
has_center
and has_size
and has_medium
and has_output_params
and not props['differentiable']
and not output_params.symbols
):
return td.Structure(
@ -138,7 +137,8 @@ class BoxStructureNode(base.MaxwellSimNode):
output_sockets={'Structure'},
output_socket_kinds={'Structure': ct.FlowKind.Params},
)
def compute_lazy_structure(self, props, input_sockets, output_sockets) -> td.Box:
def compute_structure_func(self, props, input_sockets, output_sockets) -> td.Box:
"""Compute a possibly-differentiable function, producing a box structure from the input parameters."""
output_params = output_sockets['Structure']
center = input_sockets['Center']
size = input_sockets['Size']
@ -149,14 +149,8 @@ class BoxStructureNode(base.MaxwellSimNode):
has_size = not ct.FlowSignal.check(size)
has_medium = not ct.FlowSignal.check(medium)
if has_output_params and has_center and has_size and has_medium:
differentiable = props['differentiable']
if (
has_output_params
and has_center
and has_size
and has_medium
and differentiable == output_params.is_differentiable
):
if differentiable:
return (center | size | medium).compose_within(
enclosing_func=lambda els: tdadj.JaxStructure(
@ -169,6 +163,12 @@ class BoxStructureNode(base.MaxwellSimNode):
supports_jax=True,
)
return (center | size | medium).compose_within(
## TODO: Unit conversion within the composed function??
## -- We do need Tidy3D to be given ex. micrometers in particular.
## -- But the previous numerical output might not be micrometers.
## -- There must be a way to add a conversion in, without strangeness.
## -- Ex. can compose_within() take a unit system?
## -- This would require
enclosing_func=lambda els: td.Structure(
geometry=td.Box(
center=tuple(els[0].flatten()),
@ -205,14 +205,8 @@ class BoxStructureNode(base.MaxwellSimNode):
has_medium = not ct.FlowSignal.check(medium)
if has_center and has_size and has_medium:
if props['differentiable'] == (
center.is_differentiable
and size.is_differentiable
and medium.is_differentiable
):
return center | size | medium
return ct.FlowSignal.FlowPending
return ct.FlowSignal.FlowPending
####################
# - Events: Preview
@ -226,6 +220,7 @@ class BoxStructureNode(base.MaxwellSimNode):
output_socket_kinds={'Structure': ct.FlowKind.Params},
)
def compute_previews(self, props, output_sockets):
"""Mark the managed preview object when recursively linked to a viewer."""
output_params = output_sockets['Structure']
has_output_params = not ct.FlowSignal.check(output_params)
@ -245,10 +240,14 @@ class BoxStructureNode(base.MaxwellSimNode):
)
def on_inputs_changed(self, managed_objs, input_sockets, output_sockets):
output_params = output_sockets['Structure']
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_params and not output_params.symbols:
# Push Loose Input Values to GeoNodes Modifier
center = input_sockets['Center']
has_output_params = not ct.FlowSignal.check(output_params)
has_center = not ct.FlowSignal.check(center)
if has_center and has_output_params and not output_params.symbols:
## TODO: There are strategies for handling examples of symbol values.
# Push Loose Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
'NODES',
{

View File

@ -43,17 +43,28 @@ class SocketDef(pyd.BaseModel, abc.ABC):
"""
socket_type: ct.SocketType
active_kind: typ.Literal[
ct.FlowKind.Value,
ct.FlowKind.Array,
ct.FlowKind.Range,
ct.FlowKind.Func,
] = ct.FlowKind.Value
####################
# - Socket Interaction
####################
def preinit(self, bl_socket: bpy.types.NodeSocket) -> None:
"""Pre-initialize a real Blender node socket from this socket definition.
Parameters:
bl_socket: The Blender node socket to alter using data from this SocketDef.
"""
log.debug('%s: Start Socket Preinit', bl_socket.bl_label)
# log.debug('%s: Start Socket Preinit', bl_socket.bl_label)
bl_socket.reset_instance_id()
bl_socket.regenerate_dynamic_field_persistance()
log.debug('%s: End Socket Preinit', bl_socket.bl_label)
bl_socket.active_kind = self.active_kind
# log.debug('%s: End Socket Preinit', bl_socket.bl_label)
def postinit(self, bl_socket: bpy.types.NodeSocket) -> None:
"""Pre-initialize a real Blender node socket from this socket definition.
@ -61,12 +72,12 @@ class SocketDef(pyd.BaseModel, abc.ABC):
Parameters:
bl_socket: The Blender node socket to alter using data from this SocketDef.
"""
log.debug('%s: Start Socket Postinit', bl_socket.bl_label)
# log.debug('%s: Start Socket Postinit', bl_socket.bl_label)
bl_socket.is_initializing = False
bl_socket.on_active_kind_changed()
bl_socket.on_socket_props_changed(set(bl_socket.blfields))
bl_socket.on_data_changed(set(ct.FlowKind))
log.debug('%s: End Socket Postinit', bl_socket.bl_label)
# log.debug('%s: End Socket Postinit', bl_socket.bl_label)
@abc.abstractmethod
def init(self, bl_socket: bpy.types.NodeSocket) -> None:
@ -76,6 +87,43 @@ class SocketDef(pyd.BaseModel, abc.ABC):
bl_socket: The Blender node socket to alter using data from this SocketDef.
"""
####################
# - Comparison
####################
def compare(self, bl_socket: bpy.types.NodeSocket) -> bool:
"""Whether this `SocketDef` can be considered to uniquely define the given `bl_socket`.
The general criteria for "uniquely defines" is whether **the same `bl_socket`** could be created using this `SocketDef`.
The extent to which user-altered properties are considered in this regard is a matter of taste, encapsulated entirely within `self.local_compare()`.
Notes:
Used when determining whether to replace sockets with newer variants when synchronizing changes.
**NOTE**: Removing/replacing loose input sockets
Parameters:
bl_socket: The Blender node socket to alter using data from this SocketDef.
"""
return (
bl_socket.socket_type is self.socket_type
and bl_socket.active_kind is self.active_kind
and self.local_compare(bl_socket)
)
def local_compare(self, bl_socket: bpy.types.NodeSocket) -> None:
"""Compare this `SocketDef` to an established `bl_socket` in a manner specific to the node.
Notes:
Run by `self.compare()`.
Optionally overriden by individual sockets.
When not overridden, it will always return `False`, indicating that the socket is _never_ uniquely defined by this `SocketDef`.
Parameters:
bl_socket: The Blender node socket to alter using data from this SocketDef.
"""
return False
####################
# - Serialization
####################
@ -426,8 +474,34 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Parameters:
socket_kinds: The altered `ct.FlowKind`s flowing through.
"""
# Run Socket Callbacks
self.on_socket_data_changed(socket_kinds)
# Mark Active FlowKind Links as Invalid
## -> Mark link as invalid (very red) if a FlowSignal is traveling.
## -> This helps explain why whatever isn't working isn't working.
## -> TODO: We need a different approach.
# log.debug(
# '[%s] Checking FlowKind Validity (socket_kinds=%s)',
# self.name,
# str(socket_kinds),
# )
# if self.is_linked and not self.is_output:
# link = self.links[0]
# linked_flow = self.compute_data(kind=self.active_kind)
# if (
# link.is_valid
# and self.active_kind in socket_kinds
# and ct.FlowSignal.check_single(linked_flow, ct.FlowSignal.FlowPending)
# ):
# node_tree = self.id_data
# node_tree.report_link_validity(link, False)
# elif not link.is_valid:
# node_tree = self.id_data
# node_tree.report_link_validity(link, True)
def on_socket_data_changed(self, socket_kinds: set[ct.FlowKind]) -> None:
"""Called when `ct.FlowEvent.DataChanged` flows through this socket.
@ -479,7 +553,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
The value of `ct.FlowEvent.flow_direction[event]` (`input` or `output`) determines the direction that an event flows.
"""
# log.debug(
# '[%s] [%s] Triggered (socket_kinds=%s)',
# '[%s] [%s] Socket Triggered (socket_kinds=%s)',
# self.name,
# event,
# str(socket_kinds),
@ -757,7 +831,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
linked_values = [link.from_socket.compute_data(kind) for link in self.links]
# Return Single Value / List of Values
## -> Multi-input sockets are not yet supported.
## -> Multi-input sockets are not (yet) supported.
if linked_values:
return linked_values[0]
@ -891,10 +965,14 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
# FlowKind Draw Row
col = row.column(align=True)
{
ct.FlowKind.Capabilities: lambda *_: None,
ct.FlowKind.Previews: lambda *_: None,
ct.FlowKind.Value: self.draw_value,
ct.FlowKind.Array: self.draw_array,
ct.FlowKind.Range: self.draw_lazy_range,
ct.FlowKind.Func: self.draw_lazy_func,
ct.FlowKind.Params: lambda *_: None,
ct.FlowKind.Info: lambda *_: None,
}[self.active_kind](col)
# Info Drawing

View File

@ -51,6 +51,16 @@ class BoolBLSocket(base.MaxwellSimSocket):
def value(self, value: bool) -> None:
self.raw_value = value
@bl_cache.cached_bl_property(depends_on={'value'})
def lazy_func(self) -> ct.FuncFlow:
return ct.FuncFlow(
func=lambda: self.value,
)
@bl_cache.cached_bl_property()
def params(self) -> ct.FuncFlow:
return ct.ParamsFlow()
####################
# - Socket Configuration

View File

@ -130,6 +130,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
'physical_type',
'unit',
'size',
'value',
}
)
def output_sym(self) -> sim_symbols.SimSymbol | None:
@ -140,13 +141,29 @@ class ExprBLSocket(base.MaxwellSimSocket):
Raises:
NotImplementedError: When `active_kind` is neither `Value`, `Func`, or `Range`.
"""
if self.symbols:
if self.active_kind in [ct.FlowKind.Value, ct.FlowKind.Func]:
match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func if self.symbols:
return self._parse_expr_symbol(
self._parse_expr_str(self.raw_value_spstr)
)
if self.active_kind is ct.FlowKind.Range:
case ct.FlowKind.Value | ct.FlowKind.Func if not self.symbols:
return sim_symbols.SimSymbol(
sym_name=self.output_name,
mathtype=self.mathtype,
physical_type=self.physical_type,
unit=self.unit,
rows=self.size.rows,
cols=self.size.cols,
exclude_zero=(
not self.value.is_zero
if self.value.is_zero is not None
else False
),
## TODO: Does this work for matrix elements?
)
case ct.FlowKind.Range if self.symbols:
## TODO: Support RangeFlow
## -- It's hard; we need a min-span set over bound domains.
## -- We... Don't use this anywhere. Yet?
@ -159,20 +176,37 @@ class ExprBLSocket(base.MaxwellSimSocket):
msg = 'RangeFlow support not yet implemented for when self.symbols is not empty'
raise NotImplementedError(msg)
raise NotImplementedError
case ct.FlowKind.Range if not self.symbols:
return sim_symbols.SimSymbol(
sym_name=self.output_name,
mathtype=self.mathtype,
physical_type=self.physical_type,
unit=self.unit,
rows=self.size.rows,
cols=self.size.cols,
rows=self.lazy_range.steps,
cols=1,
exclude_zero=not self.lazy_range.is_always_nonzero,
)
####################
# - Value|Range Swapper
####################
use_value_range_swapper: bool = bl_cache.BLField(False)
selected_value_range: ct.FlowKind = bl_cache.BLField(
enum_cb=lambda self, _: self._value_or_range(),
)
def _value_or_range(self):
return [
flow_kind.bl_enum_element(i)
for i, flow_kind in enumerate([ct.FlowKind.Value, ct.FlowKind.Range])
]
####################
# - Symbols
####################
lazy_range_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.Expr
)
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.Expr
)
@ -343,7 +377,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
See `MaxwellSimTree` for more detail on the link callbacks.
"""
## NODE: Depends on suppressed on_prop_changed
## NOTE: Depends on suppressed on_prop_changed
if ct.FlowKind.Info in socket_kinds:
info = self.compute_data(kind=ct.FlowKind.Info)
@ -371,7 +405,10 @@ class ExprBLSocket(base.MaxwellSimSocket):
See `MaxwellSimTree` for more detail on the link callbacks.
"""
## NODE: Depends on suppressed on_prop_changed
## NOTE: Depends on suppressed on_prop_changed
if ('selected_value_range', 'invalidate') in cleared_blfields:
self.active_kind = self.selected_value_range
self.on_active_kind_changed()
# Conditional Unit-Conversion
## -> This is niche functionality, but the only way to convert units.
@ -757,7 +794,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
@bl_cache.cached_bl_property(
depends_on={
'value',
'symbols',
'sorted_sp_symbols',
'sorted_symbols',
'output_sym',
@ -769,83 +805,88 @@ class ExprBLSocket(base.MaxwellSimSocket):
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be the arguments of the `FuncFlow`.
Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`.
"""
# Symbolic
## -> `self.value` is guaranteed to be an expression with unknowns.
## -> The function computes `self.value` with unknowns as arguments.
if self.symbols:
value = self.value
has_value = not ct.FlowSignal.check(value)
output_sym = self.output_sym
if output_sym is not None and has_value:
if self.output_sym is not None:
match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func if (
self.sorted_symbols and not ct.FlowSignal.check(self.value)
):
return ct.FuncFlow(
func=sp.lambdify(
self.sorted_sp_symbols,
output_sym.conform(value, strip_unit=True),
self.output_sym.conform(self.value, strip_unit=True),
'jax',
),
func_args=list(self.sorted_symbols),
func_output=self.output_sym,
supports_jax=True,
)
return ct.FlowSignal.FlowPending
# Constant
## -> When a `self.value` has no unknowns, use a dummy function.
## -> ("Dummy" as in returns the same argument that it takes).
## -> This is an excuse to let `ParamsFlow` pass `self.value` verbatim.
## -> Generally only useful for operations with other expressions.
case ct.FlowKind.Value | ct.FlowKind.Func if not self.sorted_symbols:
return ct.FuncFlow(
func=lambda v: v,
func_args=[self.output_sym],
func_output=self.output_sym,
supports_jax=True,
)
@bl_cache.cached_bl_property(depends_on={'sorted_symbols'})
def is_differentiable(self) -> bool:
"""Whether all symbols are differentiable.
case ct.FlowKind.Range if self.sorted_symbols:
msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty'
raise NotImplementedError(msg)
If there are no symbols, then there is nothing to differentiate, and thus the expression is differentiable.
"""
if not self.sorted_symbols:
return True
return all(
sym.mathtype in [spux.MathType.Real, spux.MathType.Complex]
for sym in self.sorted_symbols
case ct.FlowKind.Range if (
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.FuncFlow(
func=lambda v: v,
func_args=[self.output_sym],
func_output=self.output_sym,
supports_jax=True,
)
@bl_cache.cached_bl_property(depends_on={'sorted_symbols', 'output_sym', 'value'})
return ct.FlowSignal.FlowPending
@bl_cache.cached_bl_property(
depends_on={'sorted_symbols', 'output_sym', 'value', 'lazy_range'}
)
def params(self) -> ct.ParamsFlow:
"""Returns parameter symbols/values to accompany `self.lazy_func`.
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use).
Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`.
"""
# Symbolic
## -> The Expr socket does not declare actual values for the symbols.
## -> They should be realized later, ex. in a Viz node.
## -> Therefore, we just dump the symbols. Easy!
## -> NOTE: func_args must have the same symbol order as was lambdified.
if self.sorted_symbols:
output_sym = self.output_sym
if output_sym is not None:
match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
return ct.ParamsFlow(
arg_targets=list(self.sorted_symbols),
func_args=[sym.sp_symbol for sym in self.sorted_symbols],
symbols=self.sorted_symbols,
is_differentiable=self.is_differentiable,
symbols=set(self.sorted_symbols),
)
return ct.FlowSignal.FlowPending
# Constant
## -> Simply pass self.value verbatim as a function argument.
## -> Easy dice, easy life!
case ct.FlowKind.Value | ct.FlowKind.Func if (
not self.sorted_symbols and not ct.FlowSignal.check(self.value)
):
return ct.ParamsFlow(
arg_targets=[self.output_sym],
func_args=[self.value],
is_differentiable=self.is_differentiable,
)
case ct.FlowKind.Range if self.sorted_symbols:
msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty'
raise NotImplementedError(msg)
case ct.FlowKind.Range if (
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.ParamsFlow(
arg_targets=[self.output_sym],
func_args=[self.output_sym.sp_symbol_matsym],
symbols={self.output_sym},
).realize_partial({self.output_sym: self.lazy_range})
return ct.FlowSignal.FlowPending
@bl_cache.cached_bl_property(depends_on={'sorted_symbols', 'output_sym'})
def info(self) -> ct.InfoFlow:
r"""Returns parameter symbols/values to accompany `self.lazy_func`.
@ -858,22 +899,34 @@ class ExprBLSocket(base.MaxwellSimSocket):
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
"""
# Constant
## -> The input SimSymbols become continuous dimensional indices.
## -> All domain validity information is defined on the SimSymbol keys.
if self.sorted_symbols:
output_sym = self.output_sym
if output_sym is not None:
match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
return ct.InfoFlow(
dims={sym: None for sym in self.sorted_symbols},
output=self.output_sym,
)
return ct.FlowSignal.FlowPending
# Constant
## -> We only need the output symbol to describe the raw data.
case ct.FlowKind.Value | ct.FlowKind.Func if (
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.InfoFlow(output=self.output_sym)
case ct.FlowKind.Range if self.sorted_symbols:
msg = 'InfoFlow support not yet implemented for when self.sorted_symbols is not empty'
raise NotImplementedError(msg)
case ct.FlowKind.Range if (
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.InfoFlow(
dims={self.output_sym: self.lazy_range},
output=self.output_sym.update(rows=1),
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind: Capabilities
####################
@ -1039,6 +1092,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
However, `draw_value` may also be called by the `draw_*` methods of other `FlowKinds`, who may choose to layer more flexibility around this base UI.
"""
if self.use_value_range_swapper:
col.prop(self, self.blfields['selected_value_range'], text='')
if self.symbols:
col.prop(self, self.blfields['raw_value_spstr'], text='')
@ -1097,6 +1153,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
If `self.steps == 0`, then the `Range` is considered to have a to-be-determined number of steps.
As such, `self.steps` won't be exposed in the UI.
"""
if self.use_value_range_swapper:
col.prop(self, self.blfields['selected_value_range'], text='')
if self.symbols:
col.prop(self, self.blfields['raw_min_spstr'], text='')
col.prop(self, self.blfields['raw_max_spstr'], text='')
@ -1198,13 +1257,11 @@ class ExprBLSocket(base.MaxwellSimSocket):
# - Socket Configuration
####################
class ExprSocketDef(base.SocketDef):
"""Interface for defining an `ExprSocket`."""
socket_type: ct.SocketType = ct.SocketType.Expr
active_kind: typ.Literal[
ct.FlowKind.Value,
ct.FlowKind.Range,
ct.FlowKind.Func,
] = ct.FlowKind.Value
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName.Expr
use_value_range_swapper: bool = False
# Socket Interface
size: spux.NumberSize1D = spux.NumberSize1D.Scalar
@ -1458,7 +1515,7 @@ class ExprSocketDef(base.SocketDef):
# Check ActiveKind and Size
## -> NOTE: This doesn't protect against dynamic changes to either.
if (
self.active_kind == ct.FlowKind.Range
self.active_kind is ct.FlowKind.Range
and self.size is not spux.NumberSize1D.Scalar
):
msg = "Can't have a non-Scalar size when Range is set as the active kind."
@ -1504,9 +1561,9 @@ class ExprSocketDef(base.SocketDef):
# - Initialization
####################
def init(self, bl_socket: ExprBLSocket) -> None:
bl_socket.active_kind = self.active_kind
bl_socket.output_name = self.output_name
bl_socket.use_linked_capabilities = True
bl_socket.use_value_range_swapper = self.use_value_range_swapper
# Socket Interface
## -> Recall that auto-updates are turned off during init()
@ -1543,6 +1600,25 @@ class ExprSocketDef(base.SocketDef):
# Info Draw
bl_socket.use_info_draw = True
def local_compare(self, bl_socket: ExprBLSocket) -> None:
"""Determine whether an updateable socket should be re-initialized from this `SocketDef`."""
def cmp(attr: str):
return getattr(bl_socket, attr) == getattr(self, attr)
return (
bl_socket.use_linked_capabilities
and cmp('output_name')
and cmp('use_value_range_swapper')
and cmp('size')
and cmp('mathtype')
and cmp('physical_type')
and cmp('show_func_ui')
and cmp('show_info_columns')
and cmp('show_name_selector')
and bl_socket.use_info_draw
)
####################
# - Blender Registration

View File

@ -117,6 +117,16 @@ class MaxwellBoundCondsBLSocket(base.MaxwellSimSocket):
),
)
@bl_cache.cached_bl_property(depends_on={'value'})
def lazy_func(self) -> ct.FuncFlow:
return ct.FuncFlow(
func=lambda: self.value,
)
@bl_cache.cached_bl_property()
def params(self) -> ct.FuncFlow:
return ct.ParamsFlow()
####################
# - Socket Configuration

View File

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import typing as typ
from ... import contracts as ct
from .. import base
@ -32,6 +34,9 @@ class MaxwellFDTDSimSocketDef(base.SocketDef):
def init(self, bl_socket: MaxwellFDTDSimBLSocket) -> None:
pass
def local_compare(self, _: MaxwellFDTDSimBLSocket) -> None:
return True
####################
# - Blender Registration

View File

@ -73,7 +73,7 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket):
def value(self, eps_rel: tuple[float, float]) -> None:
self.eps_rel = eps_rel
@bl_cache.cached_bl_property(depends_on={'value', 'differentiable'})
@bl_cache.cached_bl_property(depends_on={'value'})
def lazy_func(self) -> ct.FuncFlow:
return ct.FuncFlow(
func=lambda: self.value,
@ -82,7 +82,7 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket):
@bl_cache.cached_bl_property(depends_on={'differentiable'})
def params(self) -> ct.FuncFlow:
return ct.ParamsFlow(is_differentiable=self.differentiable)
return ct.ParamsFlow()
####################
# - UI

View File

@ -29,11 +29,11 @@ class MaxwellMonitorBLSocket(base.MaxwellSimSocket):
class MaxwellMonitorSocketDef(base.SocketDef):
socket_type: ct.SocketType = ct.SocketType.MaxwellMonitor
is_list: bool = False
def init(self, bl_socket: MaxwellMonitorBLSocket) -> None:
if self.is_list:
bl_socket.active_kind = ct.FlowKind.Array
pass
def local_compare(self, _: MaxwellMonitorBLSocket) -> None:
return True
####################

View File

@ -55,6 +55,16 @@ class MaxwellSimGridBLSocket(base.MaxwellSimSocket):
min_steps_per_wvl=self.min_steps_per_wl,
)
@bl_cache.cached_bl_property(depends_on={'value'})
def lazy_func(self) -> ct.FuncFlow:
return ct.FuncFlow(
func=lambda: self.value,
)
@bl_cache.cached_bl_property()
def params(self) -> ct.FuncFlow:
return ct.ParamsFlow()
####################
# - Socket Configuration

View File

@ -29,11 +29,11 @@ class MaxwellSourceBLSocket(base.MaxwellSimSocket):
class MaxwellSourceSocketDef(base.SocketDef):
socket_type: ct.SocketType = ct.SocketType.MaxwellSource
is_list: bool = False
def init(self, bl_socket: MaxwellSourceBLSocket) -> None:
if self.is_list:
bl_socket.active_kind = ct.FlowKind.Array
pass
def local_compare(self, _: MaxwellSourceBLSocket) -> None:
return True
####################

View File

@ -29,11 +29,11 @@ class MaxwellStructureBLSocket(base.MaxwellSimSocket):
class MaxwellStructureSocketDef(base.SocketDef):
socket_type: ct.SocketType = ct.SocketType.MaxwellStructure
is_list: bool = False
def init(self, bl_socket: MaxwellStructureBLSocket) -> None:
if self.is_list:
bl_socket.active_kind = ct.FlowKind.Array
pass
def local_compare(self, _: MaxwellStructureBLSocket) -> None:
return True
####################

View File

@ -16,10 +16,10 @@
"""Package providing various tools to handle cached data on Blender objects, especially nodes and node socket classes."""
from ..keyed_cache import KeyedCache, keyed_cache
from .bl_field import BLField
from .bl_prop import BLProp, BLPropType
from .cached_bl_property import CachedBLProperty, cached_bl_property
from .keyed_cache import KeyedCache, keyed_cache
from .managed_cache import invalidate_nonpersist_instance_id
from .signal import Signal

View File

@ -21,6 +21,7 @@ from types import MappingProxyType
import bpy
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils.keyed_cache import keyed_cache
InstanceID: typ.TypeAlias = str ## Stringified UUID4
@ -220,11 +221,14 @@ class BLInstance:
for str_search_prop_name in self.blfields_str_search:
setattr(self, str_search_prop_name, bl_cache.Signal.ResetStrSearch)
@keyed_cache(
exclude={'self'}, ## No dynamic elements of 'self' can be used.
)
def trace_blfields_to_clear(
self,
prop_name: str,
prev_blfields_to_clear: list[
tuple[str, typ.Literal['invalidate', 'reset_enum', 'reset_strsearch']]
prev_blfields_to_clear: tuple[
tuple[str, typ.Literal['invalidate', 'reset_enum', 'reset_strsearch']], ...
] = (),
) -> list[str]:
"""Invalidates all properties that depend on `prop_name`.
@ -239,7 +243,7 @@ class BLInstance:
All of these are filled when creating the `BLInstance` subclass, using `self.declare_blfield_dep()`, generally via the `BLField` descriptor (which internally uses `BLProp`).
"""
if prev_blfields_to_clear:
blfields_to_clear = prev_blfields_to_clear.copy()
blfields_to_clear = list(prev_blfields_to_clear)
else:
blfields_to_clear = []
@ -268,7 +272,7 @@ class BLInstance:
if dst_prop_name in self.blfields:
blfields_to_clear += self.trace_blfields_to_clear(
dst_prop_name,
prev_blfields_to_clear=blfields_to_clear,
prev_blfields_to_clear=tuple(blfields_to_clear),
)
match (bool(prev_blfields_to_clear), bool(blfields_to_clear)):
@ -297,7 +301,7 @@ class BLInstance:
## -> As such, deduplication would not be wrong, just extraneous.
## -> Since invalidation is in a hot-loop, don't do such things.
case (True, True):
return blfields_to_clear
return list(reversed(dict.fromkeys(reversed(blfields_to_clear))))
def clear_blfields_after(self, prop_name: str) -> list[str]:
"""Clear (invalidate) all `BLField`s that have become invalid as a result of a change to `prop_name`.

View File

@ -17,6 +17,7 @@
"""Useful image processing operations for use in the addon."""
import enum
import functools
import typing as typ
import jax
@ -26,7 +27,6 @@ import matplotlib
import matplotlib.axis as mpl_ax
import matplotlib.backends.backend_agg
import matplotlib.figure
import numpy as np
import seaborn as sns
from blender_maxwell import contracts as ct
@ -138,7 +138,7 @@ def rgba_image_from_2d_map(
####################
# - MPL Helpers
####################
# @functools.lru_cache(maxsize=16)
@functools.lru_cache(maxsize=4)
def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
fig = matplotlib.figure.Figure(
figsize=[width_inches, height_inches], dpi=dpi, layout='tight'

View File

@ -18,11 +18,15 @@ import functools
import inspect
import typing as typ
from blender_maxwell.utils import bl_instance, logger, serialize
from blender_maxwell.utils import logger, serialize
log = logger.get(__name__)
class BLInstance(typ.Protocol):
instance_id: str
class KeyedCache:
def __init__(
self,
@ -75,8 +79,8 @@ class KeyedCache:
def __get__(
self,
bl_instance: bl_instance.BLInstance | None,
owner: type[bl_instance.BLInstance],
bl_instance: BLInstance | None,
owner: type[BLInstance],
) -> typ.Callable:
_func = functools.partial(self, bl_instance)
_func.invalidate = functools.partial(
@ -110,7 +114,7 @@ class KeyedCache:
def invalidate(
self,
bl_instance: bl_instance.BLInstance | None,
bl_instance: BLInstance | None,
**arguments: dict[str, typ.Any],
) -> dict[str, typ.Any]:
# Determine Wildcard Arguments

View File

@ -264,13 +264,16 @@ class SimSymbol(pyd.BaseModel):
interval_closed_im: tuple[bool, bool] = (False, False)
####################
# - Labels
# - Core
####################
@functools.cached_property
def name(self) -> str:
"""Usable name for the symbol."""
return self.sym_name.name
####################
# - Labels
####################
@functools.cached_property
def name_pretty(self) -> str:
"""Pretty (possibly unicode) name for the thing."""
@ -307,6 +310,8 @@ class SimSymbol(pyd.BaseModel):
@functools.cached_property
def plot_label(self) -> str:
"""Pretty plot-oriented label."""
if self.unit is None:
return self.name_pretty
return f'{self.name_pretty} ({self.unit_label})'
####################
@ -420,6 +425,11 @@ class SimSymbol(pyd.BaseModel):
@functools.cached_property
def is_nonzero(self) -> bool:
"""Whether or not the value of this symbol can ever be $0$.
Notes:
Most notably, this symbol cannot be used as the right hand side of a division operation when this property is `False`.
"""
if self.exclude_zero:
return True
@ -441,6 +451,18 @@ class SimSymbol(pyd.BaseModel):
)
return check_real_domain(self.domain)
@functools.cached_property
def can_diff(self) -> bool:
"""Whether this symbol can be used as the input / output variable when differentiating."""
# Check Constants
## -> Constants (w/pinned values) are never differentiable.
if self.is_constant:
return False
# TODO: Discontinuities (especially across 0)?
return self.mathtype in [spux.MathType.Real, spux.MathType.Complex]
####################
# - Properties
####################
@ -664,8 +686,10 @@ class SimSymbol(pyd.BaseModel):
res = spux.strip_unit_system(sp_obj)
# Broadcast Expansion
if self.rows > 1 or self.cols > 1 and not isinstance(res, spux.MatrixBase):
res = sp_obj * sp.ImmutableMatrix.ones(self.rows, self.cols)
if (self.rows > 1 or self.cols > 1) and not isinstance(
res, sp.MatrixBase | sp.MatrixSymbol
):
res = res * sp.ImmutableMatrix.ones(self.rows, self.cols)
return res
@ -753,7 +777,9 @@ class SimSymbol(pyd.BaseModel):
unit = None
# Rows/Cols from Expr (if Matrix)
rows, cols = expr.shape if isinstance(expr, sp.MatrixBase) else (1, 1)
rows, cols = (
expr.shape if isinstance(expr, sp.MatrixBase | sp.MatrixSymbol) else (1, 1)
)
return SimSymbol(
sym_name=sym_name,