feat: various sym-flow modifications

main
Sofus Albert Høgsbro Rose 2024-05-30 18:41:06 +02:00
parent 830b316e01
commit 38e70a60d3
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
40 changed files with 2226 additions and 835 deletions

View File

@ -14,12 +14,12 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import dataclasses
import functools import functools
import typing as typ import typing as typ
import jaxtyping as jtyp import jaxtyping as jtyp
import numpy as np import numpy as np
import pydantic as pyd
import sympy as sp import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
@ -29,8 +29,7 @@ log = logger.get(__name__)
# TODO: Our handling of 'is_sorted' is sloppy and probably wrong. # TODO: Our handling of 'is_sorted' is sloppy and probably wrong.
@dataclasses.dataclass(frozen=True, kw_only=True) class ArrayFlow(pyd.BaseModel):
class ArrayFlow:
"""A homogeneous, realized array of numerical values with an optionally-attached unit and sort-tracking. """A homogeneous, realized array of numerical values with an optionally-attached unit and sort-tracking.
While the principle is simple, arrays-with-units ends up being a powerful basis for derived and computed features/methods/processing. While the principle is simple, arrays-with-units ends up being a powerful basis for derived and computed features/methods/processing.
@ -41,7 +40,9 @@ class ArrayFlow:
None if unitless. None if unitless.
""" """
values: jtyp.Shaped[jtyp.Array, '...'] model_config = pyd.ConfigDict(frozen=True, arbitrary_types_allowed=True)
values: jtyp.Inexact[jtyp.Array, '...'] ## TODO: Custom field type
unit: spux.Unit | None = None unit: spux.Unit | None = None
is_sorted: bool = False is_sorted: bool = False

View File

@ -18,6 +18,7 @@ import enum
import functools import functools
import typing as typ import typing as typ
from blender_maxwell.contracts import BLEnumElement
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger from blender_maxwell.utils import logger
from blender_maxwell.utils.staticproperty import staticproperty from blender_maxwell.utils.staticproperty import staticproperty
@ -99,6 +100,17 @@ class FlowKind(enum.StrEnum):
def to_icon(_: typ.Self) -> str: def to_icon(_: typ.Self) -> str:
return '' return ''
@property
def name(self) -> str:
return FlowKind.to_name(self)
@property
def icon(self) -> str:
return FlowKind.to_icon(self)
def bl_enum_element(self, i) -> BLEnumElement:
return (str(self), self.name, self.name, self.icon, i)
#################### ####################
# - Static Properties # - Static Properties
#################### ####################
@ -162,7 +174,7 @@ class FlowKind(enum.StrEnum):
def socket_shape(self) -> str: def socket_shape(self) -> str:
"""Return the socket shape associated with this `FlowKind`. """Return the socket shape associated with this `FlowKind`.
**ONLY** valid for `FlowKind`s that can be considered "active". Should generally only be used with `active_kinds`.
Raises: Raises:
ValueError: If this `FlowKind` cannot ever be considered "active". ValueError: If this `FlowKind` cannot ever be considered "active".
@ -172,7 +184,7 @@ class FlowKind(enum.StrEnum):
FlowKind.Array: 'SQUARE', FlowKind.Array: 'SQUARE',
FlowKind.Range: 'SQUARE', FlowKind.Range: 'SQUARE',
FlowKind.Func: 'DIAMOND', FlowKind.Func: 'DIAMOND',
}[self] }.get(self, 'CIRCLE')
#################### ####################
# - Class Methods # - Class Methods

View File

@ -14,40 +14,17 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import dataclasses r"""Implements the core of the math system via `FuncFlow`, which allows high-performance, fully-expressive workflows with data that can be "very large", and/or whose input parameters are not yet fully known.
import functools
import typing as typ
from types import MappingProxyType
import jax
import jaxtyping as jtyp
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow
from .info import InfoFlow
from .lazy_range import RangeFlow
from .params import ParamsFlow
log = logger.get(__name__)
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
@dataclasses.dataclass(frozen=True, kw_only=True)
class FuncFlow:
r"""Defines a flow of data as incremental function composition.
For specific math system usage instructions, please consult the documentation of relevant nodes.
# Introduction # Introduction
When using nodes to do math, it becomes immediately obvious to express **flows of data as composed function chains**. When using nodes to do math, it becomes immediately obvious to express **flows of data as composed function chains**.
Doing so has several advantages: Doing so has several advantages:
- **Interactive**: Since no large-array math is being done, the UI can be designed to feel fast and snappy. - **Interactive**: Since no large-array math is being done, the UI can be designed to feel fast and snappy, greatly boosting the will to experiment and ultimate productivity.
- **Symbolic**: Since no numerical math is being done yet, we can choose to keep our input parameters as symbolic variables with no performance impact. - **Symbolic**: Since no numerical math is being done yet, we can inject symbolic variables at-will, enabling effortless ex. parameter sweeping, band-structure generation, differentiable can choose to keep our input parameters as symbolic variables with no performance impact.
- **Performant**: Since no operations are happening, the UI feels fast and snappy. - **Performant**: Since the data pipeline is built as a single function w/o side effects, that function can be often be JIT-optimized for highly-optimized execution on the instruction sets used by modern massively-parallel devices, like modern CPUs (SSE/AVX), GPUs (ex. PTX), and HPC clusters (network-sharding to ARM/x86).
The result is a math system optimized for the analysis typically needed in electrodynamic contexts, prioritizing clarity and flexibility at soft-real-time, even with gigabytes of data on relatively weak hardware.
## Strongly Related FlowKinds ## Strongly Related FlowKinds
For doing math, `Func` relies on two other `FlowKind`s, which must run in parallel: For doing math, `Func` relies on two other `FlowKind`s, which must run in parallel:
@ -238,6 +215,35 @@ class FuncFlow:
This enables our humble interface to implement fully functional **inverse design** of parameterized structures, using only nodes. This enables our humble interface to implement fully functional **inverse design** of parameterized structures, using only nodes.
But above all, we hope that this math system is fun, practical, and maybe even interesting. But above all, we hope that this math system is fun, practical, and maybe even interesting.
"""
import functools
import typing as typ
from types import MappingProxyType
import jax
import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow
from .info import InfoFlow
from .lazy_range import RangeFlow
from .params import ParamsFlow
log = logger.get(__name__)
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
class FuncFlow(pyd.BaseModel):
r"""Defines a flow of data as incremental function composition.
For theoretical information, please see the documentation of this module.
For specific math system usage instructions, please consult the documentation of relevant nodes.
Attributes: Attributes:
func: The function that generates the represented value. func: The function that generates the represented value.
@ -247,14 +253,16 @@ class FuncFlow:
See the documentation of `self.func_jax()`. See the documentation of `self.func_jax()`.
""" """
model_config = pyd.ConfigDict(frozen=True)
func: LazyFunction func: LazyFunction
func_args: list[sim_symbols.SimSymbol] = dataclasses.field(default_factory=list) func_args: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
func_kwargs: dict[str, sim_symbols.SimSymbol] = dataclasses.field( func_kwargs: dict[str, sim_symbols.SimSymbol] = pyd.Field(default_factory=dict)
default_factory=dict func_output: sim_symbols.SimSymbol | None = None
)
supports_jax: bool = False supports_jax: bool = False
concatenated: bool = False is_concatenated: bool = False
#################### ####################
# - Functions # - Functions
@ -318,6 +326,7 @@ class FuncFlow:
{} {}
), ),
) -> typ.Self: ) -> typ.Self:
"""Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols."""
if self.supports_jax: if self.supports_jax:
return self.func_jax( return self.func_jax(
*params.scaled_func_args(symbol_values), *params.scaled_func_args(symbol_values),
@ -371,14 +380,55 @@ class FuncFlow:
return data | {info.output: self.realize(params, symbol_values=symbol_values)} return data | {info.output: self.realize(params, symbol_values=symbol_values)}
def realize_partial(
self, params: ParamsFlow
) -> typ.Callable[
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
jtyp.Inexact[jtyp.Array, '...'],
]:
"""Create a purely-numerical function, which takes only numerical.
The units/types/shape/etc. of the returned numerical type conforms to the `SimSymbol` specification of relevant `self.func_args` entries and `self.func_output`.
This function should be used whenever the unrealized result of a `FuncFlow` needs to be used as the argument to another `FuncFlow`.
By using `realize_partial()`, two things are ensured:
- Since the function defined in `.compose_within()` must be purely numerical, the usual `.realize()` mechanism can't be used to sweep away the pre-realized symbol values.
- Since this `FuncFlow` is completely consumed, with no symbols / arguments / etc. explicitly surviving, its impact on the data flow can be considered to have been effectively terminated after using this function.
Notes:
Be **very careful about units**.
Ideally, the bottom function should use `.scale_to_unit()` before invoking `.compose_within()` with the output of this function.
"""
pre_realized_syms = list(
params.realize_symbols(params.realized_symbols, allow_partial=True).values()
)
def realizer(
*sym_args: int | float | complex | jtyp.Inexact[jtyp.Array, '...'],
) -> jtyp.Inexact[jtyp.Array, '...']:
return self.func(
*[
func_arg_n(*sym_args, *pre_realized_syms)
for func_arg_n in params.func_args_n
],
**{
func_arg_name: func_kwarg_n(*sym_args, *pre_realized_syms)
for func_arg_name, func_kwarg_n in params.func_kwargs_n.items()
},
)
return realizer
#################### ####################
# - Composition Operations # - Operations
#################### ####################
def compose_within( def compose_within(
self, self,
enclosing_func: LazyFunction, enclosing_func: LazyFunction,
enclosing_func_args: list[type] = (), enclosing_func_args: list[sim_symbols.SimSymbol] = (),
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}), enclosing_func_kwargs: dict[str, sim_symbols.SimSymbol] = MappingProxyType({}),
enclosing_func_output: sim_symbols.SimSymbol | None = None,
supports_jax: bool = False, supports_jax: bool = False,
) -> typ.Self: ) -> typ.Self:
"""Compose `self.func` within the given enclosing function, which itself takes arguments, and create a new `FuncFlow` to contain it. """Compose `self.func` within the given enclosing function, which itself takes arguments, and create a new `FuncFlow` to contain it.
@ -415,6 +465,10 @@ class FuncFlow:
Returns: Returns:
A lazy function that takes both the enclosed and enclosing arguments, and returns the value of the enclosing function (whose first argument is the output value of the enclosed function). A lazy function that takes both the enclosed and enclosing arguments, and returns the value of the enclosing function (whose first argument is the output value of the enclosed function).
""" """
## TODO: Support unit system conversion at the point of composition.
## -- This may require us to track the units of the function output.
## TODO: Support JAX-evaluation when jax support changes from True to False.
## -- This would allow big data flows to compose performantly as arguments into non-JAX functions.
return FuncFlow( return FuncFlow(
func=lambda *args, **kwargs: enclosing_func( func=lambda *args, **kwargs: enclosing_func(
self.func( self.func(
@ -426,6 +480,7 @@ class FuncFlow:
), ),
func_args=self.func_args + list(enclosing_func_args), func_args=self.func_args + list(enclosing_func_args),
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs), func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
func_output=enclosing_func_output,
supports_jax=self.supports_jax and supports_jax, supports_jax=self.supports_jax and supports_jax,
) )
@ -472,7 +527,7 @@ class FuncFlow:
*list(args[: len(self.func_args)]), *list(args[: len(self.func_args)]),
**{k: v for k, v in kwargs.items() if k in self.func_kwargs}, **{k: v for k, v in kwargs.items() if k in self.func_kwargs},
) )
if not self.concatenated: if not self.is_concatenated:
return (ret,) return (ret,)
return ret return ret
@ -487,5 +542,83 @@ class FuncFlow:
func_args=self.func_args + other.func_args, func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs, func_kwargs=self.func_kwargs | other.func_kwargs,
supports_jax=self.supports_jax and other.supports_jax, supports_jax=self.supports_jax and other.supports_jax,
concatenated=True, is_concatenated=True,
) )
def scale_to_unit(self, unit: spux.Unit | None = None) -> typ.Self:
"""Encloses this function in a unit-converting function, whose output is a converted, unitless scalar.
`unit` must be manually guaranteed to be compatible with `self.unit`.
"""
if self.func_output is not None:
# Retrieve Output Unit
output_unit = self.func_output.unit
# Compile Efficient Unit-Conversion Function
a = self.func_output.mathtype.sp_symbol_a
unit_convert_expr = (
spux.scale_to_unit(a * output_unit, unit)
if self.func_output.unit is not None
else a
)
unit_convert_func = sp.lambdify(a, unit_convert_expr.n(), 'jax')
# Compose Unit-Converted FuncFlow
return self.compose_within(
enclosing_func=unit_convert_func,
supports_jax=True,
enclosing_func_output=self.func_output.update(unit=unit),
)
msg = f'Tried to scale a FuncFlow to a unit system, but it has no tracked output SimSymbol. ({self})'
raise ValueError(msg)
def scale_to_unit_system(
self, unit_system: spux.UnitSystem | None = None
) -> typ.Self:
"""Encloses this function in a unit-converting function, whose output is a converted, unitless scalar.
Using `self.output_symbol`, which tracks the units of the output, we can determine a scaling factor to multiply the (numerical) function output by in order to conform it to the given unit system.
In general, **don't use this**.
Any superfluous numerical operations in a data pipeline can enhance instabilities and interfere with JIT-optimization (floating-point arithmetic isn't commutative, for example).
However, occasionally, we need to "intercept" a lazy data flow, for example when realizing a `FlowKind.Value` that doesn't understand symbols or units - but which only accepts a float/complex scalar/array with pre-determined unit convention.
For this purpose alone, this method is provided to pre-scale a `FuncFlow`, just before using `realize()` / `__or__` and then `realize()`.
**To encourage proper usage** (and ease implementation), the output unit in `self.func_output` of the output will be reset to `None` - indicating that the output can only be handled as a unitless scalar w/semantic meaning tracked elsewhere.
Notes:
**ONLY** use with output types that support meaningful arbitrary multiplication.
A scale-only sympy expression will be used to produce an optimized JAX function of a single variable, which will then be composed onto the existing `FuncFlow`.
Parameters:
unit_system: The unit system to conform the function output to.
Returns:
A new `FuncFlow` that conforms to the new unit, but is itself now considered unitless.
"""
if self.func_output is not None:
# Retrieve Output Unit
output_unit = self.func_output.unit
# Compile Efficient Unit-Conversion Function
a = self.func_output.mathtype.sp_symbol_a
unit_convert_expr = (
spux.strip_unit_system(
spux.convert_to_unit_system(a * output_unit, unit_system)
)
if self.func_output.unit is not None
else a
)
unit_convert_func = sp.lambdify(a, unit_convert_expr.n(), 'jax')
# Compose Unit-Converted FuncFlow
return self.compose_within(
enclosing_func=unit_convert_func,
supports_jax=True,
enclosing_func_output=self.func_output.update(unit=None),
)
msg = f'Tried to scale a FuncFlow to a unit system, but it has no tracked output SimSymbol. ({self})'
raise ValueError(msg)

View File

@ -14,17 +14,15 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import dataclasses
import enum import enum
import functools import functools
import typing as typ import typing as typ
from fractions import Fraction
from types import MappingProxyType from types import MappingProxyType
import jax.numpy as jnp import jax.numpy as jnp
import jaxtyping as jtyp import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger, sim_symbols from blender_maxwell.utils import logger, sim_symbols
@ -61,8 +59,7 @@ class ScalingMode(enum.StrEnum):
return '' return ''
@dataclasses.dataclass(frozen=True, kw_only=True) class RangeFlow(pyd.BaseModel):
class RangeFlow:
r"""Represents a finite spaced array using symbolic boundary expressions. r"""Represents a finite spaced array using symbolic boundary expressions.
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous. Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
@ -92,8 +89,10 @@ class RangeFlow:
symbols: Set of variables from which `start` and/or `stop` are determined. symbols: Set of variables from which `start` and/or `stop` are determined.
""" """
start: spux.ScalarUnitlessComplexExpr model_config = pyd.ConfigDict(frozen=True)
stop: spux.ScalarUnitlessComplexExpr
start: spux.ScalarUnitlessRealExpr
stop: spux.ScalarUnitlessRealExpr
steps: int = 0 steps: int = 0
scaling: ScalingMode = ScalingMode.Lin scaling: ScalingMode = ScalingMode.Lin
@ -102,7 +101,7 @@ class RangeFlow:
symbols: frozenset[sim_symbols.SimSymbol] = frozenset() symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
# Helper Attributes # Helper Attributes
pre_fourier_ideal_midpoint: spux.ScalarUnitlessComplexExpr | None = None pre_fourier_ideal_midpoint: spux.ScalarUnitlessRealExpr | None = None
#################### ####################
# - SimSymbol Interop # - SimSymbol Interop
@ -218,14 +217,26 @@ class RangeFlow:
) )
return combined_mathtype return combined_mathtype
@property @functools.cached_property
def ideal_midpoint(self) -> spux.SympyExpr: def ideal_midpoint(self) -> spux.SympyExpr:
return (self.stop + self.start) / 2 return (self.stop + self.start) / 2
@property @functools.cached_property
def ideal_range(self) -> spux.SympyExpr: def ideal_range(self) -> spux.SympyExpr:
return self.stop - self.start return self.stop - self.start
@functools.cached_property
def ideal_step_size(self) -> spux.SympyExpr:
return self.ideal_range / (self.steps - 1)
@functools.cached_property
def is_always_nonzero(self) -> spux.SympyExpr:
if self.start > 0 or self.stop < 0:
return True
is_zero = (self.start % self.ideal_step_size).is_zero
return is_zero if is_zero is not None else False
#################### ####################
# - Methods # - Methods
#################### ####################
@ -452,7 +463,7 @@ class RangeFlow:
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType( symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{} {}
), ),
) -> dict[sp.Symbol, spux.ScalarUnitlessComplexExpr]: ) -> dict[sp.Symbol, spux.ScalarUnitlessRealExpr]:
"""Realize **all** input symbols to the `RangeFlow`. """Realize **all** input symbols to the `RangeFlow`.
Parameters: Parameters:
@ -480,7 +491,7 @@ class RangeFlow:
raise NotImplementedError(msg) raise NotImplementedError(msg)
realized_syms |= {sym: v} realized_syms |= {sym: v}
return realized_syms
msg = f'RangeFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})' msg = f'RangeFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
raise ValueError(msg) raise ValueError(msg)

View File

@ -14,13 +14,13 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import dataclasses
import functools import functools
import typing as typ import typing as typ
from fractions import Fraction from fractions import Fraction
from types import MappingProxyType from types import MappingProxyType
import jaxtyping as jtyp import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
@ -28,31 +28,35 @@ from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow from .array import ArrayFlow
from .expr_info import ExprInfo from .expr_info import ExprInfo
from .flow_kinds import FlowKind
from .lazy_range import RangeFlow from .lazy_range import RangeFlow
log = logger.get(__name__) log = logger.get(__name__)
@dataclasses.dataclass(frozen=True, kw_only=True) class ParamsFlow(pyd.BaseModel):
class ParamsFlow:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name. """Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
Returns: Returns:
All symbols valid for use in the expression. All symbols valid for use in the expression.
""" """
arg_targets: list[sim_symbols.SimSymbol] = dataclasses.field(default_factory=list) model_config = pyd.ConfigDict(frozen=True)
kwarg_targets: list[str, sim_symbols.SimSymbol] = dataclasses.field(
default_factory=dict
)
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list) arg_targets: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict) kwarg_targets: dict[str, sim_symbols.SimSymbol] = pyd.Field(default_factory=dict)
func_args: list[spux.SympyExpr] = pyd.Field(default_factory=list)
func_kwargs: dict[str, spux.SympyExpr] = pyd.Field(default_factory=dict)
symbols: frozenset[sim_symbols.SimSymbol] = frozenset() symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
realized_symbols: dict[
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
] = pyd.Field(default_factory=dict)
is_differentiable: bool = False @functools.cached_property
def diff_symbols(self) -> set[sim_symbols.SimSymbol]:
"""Set of all unrealized `SimSymbol`s that can act as inputs when differentiating the function for which this `ParamsFlow` tracks arguments."""
return {sym for sym in self.symbols if sym.can_diff}
#################### ####################
# - Symbols # - Symbols
@ -78,6 +82,27 @@ class ParamsFlow:
""" """
return [sym.sp_symbol_matsym for sym in self.sorted_symbols] return [sym.sp_symbol_matsym for sym in self.sorted_symbols]
@functools.cached_property
def all_sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
Returns:
All symbols valid for use in the expression.
"""
key_func = lambda sym: sym.name # noqa: E731
return sorted(self.symbols, key=key_func) + sorted(
self.realized_symbols.keys(), key=key_func
)
@functools.cached_property
def all_sorted_sp_symbols(self) -> list[sim_symbols.SimSymbol]:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
Returns:
All symbols valid for use in the expression.
"""
return [sym.sp_symbol_matsym for sym in self.all_sorted_symbols]
#################### ####################
# - JIT'ed Callables for Numerical Function Arguments # - JIT'ed Callables for Numerical Function Arguments
#################### ####################
@ -101,7 +126,7 @@ class ParamsFlow:
""" """
return [ return [
sp.lambdify( sp.lambdify(
self.sorted_sp_symbols, self.all_sorted_sp_symbols,
target_sym.conform(func_arg, strip_unit=True), target_sym.conform(func_arg, strip_unit=True),
'jax', 'jax',
) )
@ -127,7 +152,7 @@ class ParamsFlow:
""" """
return { return {
key: sp.lambdify( key: sp.lambdify(
self.sorted_sp_symbols, self.all_sorted_sp_symbols,
self.kwarg_targets[key].conform(func_arg, strip_unit=True), self.kwarg_targets[key].conform(func_arg, strip_unit=True),
'jax', 'jax',
) )
@ -142,8 +167,9 @@ class ParamsFlow:
symbol_values: dict[ symbol_values: dict[
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
] = MappingProxyType({}), ] = MappingProxyType({}),
allow_partial: bool = False,
) -> dict[ ) -> dict[
sp.Symbol, sim_symbols.SimSymbol,
int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...'] :, int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...'] :,
]: ]:
"""Fully realize all symbols by assigning them a value. """Fully realize all symbols by assigning them a value.
@ -160,10 +186,12 @@ class ParamsFlow:
Returns: Returns:
A dictionary almost with `.subs()`, other than `jax` arrays. A dictionary almost with `.subs()`, other than `jax` arrays.
""" """
if set(self.symbols) == set(symbol_values.keys()): if allow_partial or set(self.all_sorted_symbols) == set(symbol_values.keys()):
realized_syms = {} realized_syms = {}
for sym in self.sorted_symbols: for sym in self.all_sorted_symbols:
sym_value = symbol_values[sym] sym_value = symbol_values.get(sym)
if sym_value is None and allow_partial:
continue
if isinstance(sym_value, spux.SympyType): if isinstance(sym_value, spux.SympyType):
v = sym.scale(sym_value) v = sym.scale(sym_value)
@ -214,7 +242,9 @@ class ParamsFlow:
Parameters: Parameters:
symbol_values: Particular values for all symbols in `self.symbols`, which will be conformed and used to compute the function arguments (before they are conformed to `self.target_syms`). symbol_values: Particular values for all symbols in `self.symbols`, which will be conformed and used to compute the function arguments (before they are conformed to `self.target_syms`).
""" """
realized_symbols = list(self.realize_symbols(symbol_values).values()) realized_symbols = list(
self.realize_symbols(symbol_values | self.realized_symbols).values()
)
return [func_arg_n(*realized_symbols) for func_arg_n in self.func_args_n] return [func_arg_n(*realized_symbols) for func_arg_n in self.func_args_n]
def scaled_func_kwargs( def scaled_func_kwargs(
@ -227,10 +257,11 @@ class ParamsFlow:
Other than the `dict[str, ...]` key, the semantics are identical to `self.scaled_func_args()`. Other than the `dict[str, ...]` key, the semantics are identical to `self.scaled_func_args()`.
""" """
realized_symbols = self.realize_symbols(symbol_values) realized_symbols = self.realize_symbols(symbol_values | self.realized_symbols)
return { return {
func_arg_name: func_arg_n(**realized_symbols) func_arg_name: func_kwarg_n(**realized_symbols)
for func_arg_name, func_arg_n in self.func_kwargs_n.items() for func_arg_name, func_kwarg_n in self.func_kwargs_n.items()
} }
#################### ####################
@ -251,7 +282,7 @@ class ParamsFlow:
func_args=self.func_args + other.func_args, func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs, func_kwargs=self.func_kwargs | other.func_kwargs,
symbols=self.symbols | other.symbols, symbols=self.symbols | other.symbols,
is_differentiable=self.is_differentiable and other.is_differentiable, realized_symbols=self.realized_symbols | other.realized_symbols,
) )
def compose_within( def compose_within(
@ -261,7 +292,6 @@ class ParamsFlow:
enclosing_func_args: list[spux.SympyExpr] = (), enclosing_func_args: list[spux.SympyExpr] = (),
enclosing_func_kwargs: dict[str, spux.SympyExpr] = MappingProxyType({}), enclosing_func_kwargs: dict[str, spux.SympyExpr] = MappingProxyType({}),
enclosing_symbols: frozenset[sim_symbols.SimSymbol] = frozenset(), enclosing_symbols: frozenset[sim_symbols.SimSymbol] = frozenset(),
enclosing_is_differentiable: bool = False,
) -> typ.Self: ) -> typ.Self:
return ParamsFlow( return ParamsFlow(
arg_targets=self.arg_targets + list(enclosing_arg_targets), arg_targets=self.arg_targets + list(enclosing_arg_targets),
@ -269,13 +299,41 @@ class ParamsFlow:
func_args=self.func_args + list(enclosing_func_args), func_args=self.func_args + list(enclosing_func_args),
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs), func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
symbols=self.symbols | enclosing_symbols, symbols=self.symbols | enclosing_symbols,
is_differentiable=( realized_symbols=self.realized_symbols,
self.is_differentiable
if not enclosing_symbols
else (self.is_differentiable & enclosing_is_differentiable)
),
) )
def realize_partial(
self,
symbol_values: dict[
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
],
) -> typ.Self:
"""Provide a particular expression/range/array to realize some symbols.
Essentially removes symbols from `self.symbols`, and adds the symbol w/value to `self.realized_symbols`.
As a result, only the still-unrealized symbols need to be passed at the time of realization (using ex. `self.scaled_func_args()`).
Parameters:
symbol_values: The value to realize for each `SimSymbol`.
**All keys** must be identically matched to a single element of `self.symbol`.
Can be empty, in which case an identical new `ParamsFlow` will be returned.
Raises:
ValueError: If any symbol in `symbol_values`
"""
syms = set(symbol_values.keys())
if syms.issubset(self.symbols) or not syms:
return ParamsFlow(
arg_targets=self.arg_targets,
kwarg_targets=self.kwarg_targets,
func_args=self.func_args,
func_kwargs=self.func_kwargs,
symbols=self.symbols - syms,
realized_symbols=self.realized_symbols | symbol_values,
)
msg = f'ParamsFlow: Not all partially realized symbols are defined on the ParamsFlow (symbols={self.symbols}, symbol_values={symbol_values})'
raise ValueError(msg)
#################### ####################
# - Generate ExprSocketDef # - Generate ExprSocketDef
#################### ####################

View File

@ -16,7 +16,7 @@
"""Declares `ManagedBLImage`.""" """Declares `ManagedBLImage`."""
# import time import time
import typing as typ import typing as typ
import bpy import bpy
@ -261,7 +261,7 @@ class ManagedBLImage(base.ManagedObj):
dpi: int | None = None, dpi: int | None = None,
bl_select: bool = False, bl_select: bool = False,
): ):
# times = [time.perf_counter()] times = ['START', time.perf_counter()]
# Compute Plot Dimensions # Compute Plot Dimensions
# aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = ( # aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = (
@ -277,22 +277,22 @@ class ManagedBLImage(base.ManagedObj):
# _width_inches, _height_inches, _dpi # _width_inches, _height_inches, _dpi
# ) # )
fig, canvas, ax = image_ops.mpl_fig_canvas_ax(width_inches, height_inches, dpi) fig, canvas, ax = image_ops.mpl_fig_canvas_ax(width_inches, height_inches, dpi)
# times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]]) times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]])
# fig.clear() # fig.clear()
ax.clear() ax.clear()
# times.append(['Clear Axis', time.perf_counter() - times[0]]) times.append(['Clear Axis', time.perf_counter() - times[0]])
# Plot w/User Parameter # Plot w/User Parameter
func_plotter(ax) func_plotter(ax)
# times.append(['Plot!', time.perf_counter() - times[0]]) times.append(['Plot!', time.perf_counter() - times[0]])
# Save Figure to BytesIO # Save Figure to BytesIO
canvas.draw() canvas.draw()
# times.append(['Draw Pixels', time.perf_counter() - times[0]]) times.append(['Draw Pixels', time.perf_counter() - times[0]])
canvas_width_px, canvas_height_px = fig.canvas.get_width_height() canvas_width_px, canvas_height_px = fig.canvas.get_width_height()
# times.append(['Get Canvas Dims', time.perf_counter() - times[0]]) times.append(['Get Canvas Dims', time.perf_counter() - times[0]])
image_data = ( image_data = (
np.float32( np.float32(
np.flipud( np.flipud(
@ -303,7 +303,7 @@ class ManagedBLImage(base.ManagedObj):
) )
/ 255 / 255
) )
# times.append(['Load Data from Canvas', time.perf_counter() - times[0]]) times.append(['Load Data from Canvas', time.perf_counter() - times[0]])
# Optimized Write to Blender Image # Optimized Write to Blender Image
bl_image = self.bl_image(canvas_width_px, canvas_height_px, 'RGBA', 'uint8') bl_image = self.bl_image(canvas_width_px, canvas_height_px, 'RGBA', 'uint8')

View File

@ -15,6 +15,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import contextlib import contextlib
import functools
import queue
import typing as typ import typing as typ
import bpy import bpy
@ -26,6 +28,14 @@ from .managed_objs.managed_bl_image import ManagedBLImage
log = logger.get(__name__) log = logger.get(__name__)
link_action_queue = queue.Queue()
def set_link_validity(link: bpy.types.NodeLink, validity: bool) -> None:
log.critical('Set %s validity to %s', str(link), str(validity))
link.is_valid = validity
#################### ####################
# - Cache Management # - Cache Management
#################### ####################
@ -45,58 +55,93 @@ class DeltaNodeLinkCache(typ.TypedDict):
class NodeLinkCache: class NodeLinkCache:
"""A pointer-based cache of node links in a node tree. """A volatile pointer-based cache of node links in a node tree.
Warnings:
Everything here is **extremely** unsafe.
Even a single mistake **will** cause a use-after-free crash of Blender.
Used perfectly, it allows for powerful features; anything less, and it's an epic liability.
Attributes: Attributes:
_node_tree: Reference to the owning node tree. _node_tree: Reference to the node tree for which this cache is valid.
link_ptrs_as_links: link_ptrs: Memory-address identifiers for all node links that currently exist in `_node_tree`.
link_ptrs: Pointers (as in integer memory adresses) to `NodeLink`s. link_ptrs_as_links: Mapping from pointers (integers) to actual `NodeLink` objects.
link_ptrs_as_links: Map from pointers to actual `NodeLink`s. **WARNING**: If the pointer-referenced object no longer exists, then Blender **will crash immediately** upon attempting to use it. There is no way to mitigate this.
link_ptrs_from_sockets: Map from pointers to `NodeSocket`s, representing the source of each `NodeLink`. socket_ptrs: Memory-address identifiers for all sockets that currently exist in `_node_tree`.
link_ptrs_from_sockets: Map from pointers to `NodeSocket`s, representing the destination of each `NodeLink`. socket_ptrs_as_sockets: Mapping from pointers (integers) to actual `NodeSocket` objects.
**WARNING**: If the pointer-referenced object no longer exists, then Blender **will crash immediately** upon attempting to use it. There is no way to mitigate this.
socket_ptr_refcount: The amount of links currently connected to a given socket pointer.
Used to drive the deletion of socket pointers using only knowledge about `link_ptr` removal.
link_ptrs_as_from_socket_ptrs: The pointer of the source socket, defined for every node link pointer.
link_ptrs_as_to_socket_ptrs: The pointer of the destination socket, defined for every node link pointer.
""" """
def __init__(self, node_tree: bpy.types.NodeTree): def __init__(self, node_tree: bpy.types.NodeTree):
"""Initialize the cache from a node tree. """Defines and fills the cache from a live node tree."""
Parameters:
node_tree: The Blender node tree whose `NodeLink`s will be cached.
"""
self._node_tree = node_tree self._node_tree = node_tree
# Link PTR and PTR->REF
self.link_ptrs: set[MemAddr] = set() self.link_ptrs: set[MemAddr] = set()
self.link_ptrs_as_links: dict[MemAddr, bpy.types.NodeLink] = {} self.link_ptrs_as_links: dict[MemAddr, bpy.types.NodeLink] = {}
# Socket PTR and PTR->REF
self.socket_ptrs: set[MemAddr] = set() self.socket_ptrs: set[MemAddr] = set()
self.socket_ptrs_as_sockets: dict[MemAddr, bpy.types.NodeSocket] = {} self.socket_ptrs_as_sockets: dict[MemAddr, bpy.types.NodeSocket] = {}
self.socket_ptr_refcount: dict[MemAddr, int] = {} self.socket_ptr_refcount: dict[MemAddr, int] = {}
# Link PTR -> Socket PTR
self.link_ptrs_as_from_socket_ptrs: dict[MemAddr, MemAddr] = {} self.link_ptrs_as_from_socket_ptrs: dict[MemAddr, MemAddr] = {}
self.link_ptrs_as_to_socket_ptrs: dict[MemAddr, MemAddr] = {} self.link_ptrs_as_to_socket_ptrs: dict[MemAddr, MemAddr] = {}
self.link_ptrs_invalid: set[MemAddr] = set()
# Fill Cache # Fill Cache
self.regenerate() self.regenerate()
def remove_link(self, link_ptr: MemAddr) -> None: def remove_link(self, link_ptr: MemAddr) -> None:
"""Removes a link pointer from the cache, indicating that the link doesn't exist anymore. """Reports a link as removed, causing it to be removed from the cache.
This **must** be run whenever a node link is deleted.
**Failure to to so WILL result in segmentation fault** at an unknown future time.
In particular, the following actions are taken:
- The entry in `self.link_ptrs_as_links` is deleted.
- Any entry in `self.link_ptrs_invalid` is deleted (if exists).
Notes: Notes:
- **DOES NOT** remove PTR->REF dictionary entries Invoking this method directly causes the removed node links to not be reported as "removed" by `NodeLinkCache.regenerate()`.
- Invoking this method directly causes the removed node links to not be reported as "removed" by `NodeLinkCache.regenerate()`. In some cases, this may be desirable, ex. for internal methods that shouldn't trip a `DataChanged` flow event.
- This **must** be done whenever a node link is deleted.
- Failure to do so may result in a segmentation fault at arbitrary future time.
Parameters: Parameters:
link_ptr: Pointer to remove from the cache. link_ptr: The pointer (integer) to remove from the cache.
Raises:
KeyError: If `link_ptr` is not a member of either `self.link_ptrs`, or of `self.link_ptrs_as_links`.
""" """
self.link_ptrs.remove(link_ptr) self.link_ptrs.remove(link_ptr)
self.link_ptrs_as_links.pop(link_ptr) self.link_ptrs_as_links.pop(link_ptr)
if link_ptr in self.link_ptrs_invalid:
self.link_ptrs_invalid.remove(link_ptr)
def remove_sockets_by_link_ptr(self, link_ptr: MemAddr) -> None: def remove_sockets_by_link_ptr(self, link_ptr: MemAddr) -> None:
"""Removes a single pointer's reference to its from/to sockets.""" """Deassociate from all sockets referenced by a link, respecting the socket pointer reference-count.
The `NodeLinkCache` stores references to all socket pointers referenced by any link.
Since several links can be associated with each socket, we must keep a "reference count" per-socket.
When the "reference count" drops to zero, then there are no longer any `NodeLink`s that refer to it, and therefore it should be removed from the `NodeLinkCache`.
This method facilitates that process by:
- Extracting (with removal) the from / to socket pointers associated with `link_ptr`.
- If the socket pointer has a reference count of `1`, then it is **completely removed**.
- If the socket pointer has a reference count of `>1`, then the reference count is decremented by `1`.
Notes:
In general, this should be called together with `remove_link`.
However, in certain cases, this process also needs to happen by itself.
Parameters:
link_ptr: The pointer (integer) to remove from the cache.
"""
# Remove Socket Pointers
from_socket_ptr = self.link_ptrs_as_from_socket_ptrs.pop(link_ptr, None) from_socket_ptr = self.link_ptrs_as_from_socket_ptrs.pop(link_ptr, None)
to_socket_ptr = self.link_ptrs_as_to_socket_ptrs.pop(link_ptr, None) to_socket_ptr = self.link_ptrs_as_to_socket_ptrs.pop(link_ptr, None)
@ -113,31 +158,40 @@ class NodeLinkCache:
self.socket_ptr_refcount[socket_ptr] -= 1 self.socket_ptr_refcount[socket_ptr] -= 1
def regenerate(self) -> DeltaNodeLinkCache: def regenerate(self) -> DeltaNodeLinkCache:
"""Regenerates the cache from the internally-linked node tree. """Efficiently scans the internally referenced node tree to thoroughly update all attributes of this `NodeLinkCache`.
Notes: Notes:
- This is designed to run within the `update()` invocation of the node tree. This runs in a **very** hot loop, within the `update()` function of the node tree.
- This should be a very fast function, since it is called so much. Anytime anything happens in the node tree, `update()` (and therefore this method) is called.
Thus, performance is of the utmost importance.
Just a few microseconds too much may be amplified dozens of times over in practice, causing big stutters.
""" """
# Compute All NodeLink Pointers # Compute All NodeLink Pointers
## -> It can be very inefficient to do any full-scan of the node tree.
## -> However, simply extracting the pointer: link ends up being fast.
## -> This pattern seems to be the best we can do, efficiency-wise.
all_link_ptrs_as_links = { all_link_ptrs_as_links = {
link.as_pointer(): link for link in self._node_tree.links link.as_pointer(): link for link in self._node_tree.links
} }
all_link_ptrs = set(all_link_ptrs_as_links.keys()) all_link_ptrs = set(all_link_ptrs_as_links.keys())
# Compute Added/Removed Links # Compute Added/Removed Links
## -> In essence, we've created a 'diff' here.
## -> Set operations are fast, and expressive!
added_link_ptrs = all_link_ptrs - self.link_ptrs added_link_ptrs = all_link_ptrs - self.link_ptrs
removed_link_ptrs = self.link_ptrs - all_link_ptrs removed_link_ptrs = self.link_ptrs - all_link_ptrs
# Edge Case: 'from_socket' Reassignment # Edge Case: 'from_socket' Reassignment
## (Reverse engineered) When all: ## (Reverse Engineered) When all are true:
## - Created a new link between the same two nodes. ## - Created a new link between the same nodes as previous link.
## - Matching 'to_socket'. ## - Matching 'to_socket' as the previous link.
## - Non-matching 'from_socket' on the same node. ## - Non-matching 'from_socket', but on the same node.
## -> THEN the link_ptr will not change, but the from_socket ptr should. ## -> THEN the link_ptr will not change, but the from_socket ptr does.
if len(added_link_ptrs) == 0 and len(removed_link_ptrs) == 0: if not added_link_ptrs and not removed_link_ptrs:
# Find the Link w/Reassigned 'from_socket' PTR # Find the Link w/Reassigned 'from_socket' PTR
## A bit of a performance hit from the search, but it's an edge case. ## -> This isn't very fast, but the edge case isn't so common.
## -> Comprehensions are still quite optimized.
_link_ptr_as_from_socket_ptrs = { _link_ptr_as_from_socket_ptrs = {
link_ptr: ( link_ptr: (
from_socket_ptr, from_socket_ptr,
@ -149,9 +203,9 @@ class NodeLinkCache:
} }
# Completely Remove the Old Link (w/Reassigned 'from_socket') # Completely Remove the Old Link (w/Reassigned 'from_socket')
## This effectively reclassifies the edge case as a normal 're-add'. ## -> Casts the edge case to look like a typical 're-add'.
for link_ptr in _link_ptr_as_from_socket_ptrs: for link_ptr in _link_ptr_as_from_socket_ptrs:
log.info( log.debug(
'Edge-Case - "from_socket" Reassigned in NodeLink w/o New NodeLink Pointer: %s', 'Edge-Case - "from_socket" Reassigned in NodeLink w/o New NodeLink Pointer: %s',
link_ptr, link_ptr,
) )
@ -159,21 +213,25 @@ class NodeLinkCache:
self.remove_sockets_by_link_ptr(link_ptr) self.remove_sockets_by_link_ptr(link_ptr)
# Recompute Added/Removed Links # Recompute Added/Removed Links
## The algorithm will now detect an "added link". ## -> Guide the usual algorithm to detect an "added link".
added_link_ptrs = all_link_ptrs - self.link_ptrs added_link_ptrs = all_link_ptrs - self.link_ptrs
removed_link_ptrs = self.link_ptrs - all_link_ptrs removed_link_ptrs = self.link_ptrs - all_link_ptrs
# Shuffle Cache based on Change in Links # Delete Removed Links
## Remove Entries for Removed Pointers ## -> NOTE: We leave dangling socket information on purpose.
## -> This information will be used to ask for 'removal consent'.
## -> To truly remove, must call 'remove_socket_by_link_ptr' later.
for removed_link_ptr in removed_link_ptrs: for removed_link_ptr in removed_link_ptrs:
self.remove_link(removed_link_ptr) self.remove_link(removed_link_ptr)
## User must manually call 'remove_socket_by_link_ptr' later.
## For now, leave dangling socket information by-link.
# Add New Link Pointers # Create Added Links
## -> First, simply concatenate the added link pointers.
self.link_ptrs |= added_link_ptrs self.link_ptrs |= added_link_ptrs
for link_ptr in added_link_ptrs: for link_ptr in added_link_ptrs:
# Add Link PTR->REF # Create Pointer -> Reference Entry
## -> This allows us to efficiently access the link by-pointer.
## -> Doing so otherwise requires a full search.
## -> **If link is deleted w/o report, access will cause crash**.
new_link = all_link_ptrs_as_links[link_ptr] new_link = all_link_ptrs_as_links[link_ptr]
self.link_ptrs_as_links[link_ptr] = new_link self.link_ptrs_as_links[link_ptr] = new_link
@ -183,34 +241,69 @@ class NodeLinkCache:
to_socket = new_link.to_socket to_socket = new_link.to_socket
to_socket_ptr = to_socket.as_pointer() to_socket_ptr = to_socket.as_pointer()
# Add Socket PTR, PTR -> REF # Add Socket Information
for socket_ptr, bl_socket in zip( # noqa: B905 for socket_ptr, bl_socket in zip( # noqa: B905
[from_socket_ptr, to_socket_ptr], [from_socket_ptr, to_socket_ptr],
[from_socket, to_socket], [from_socket, to_socket],
): ):
# Increment RefCount of Socket PTR # RefCount > 0: Increment RefCount of Socket PTR
## This happens if another link also uses the same socket. ## This happens if another link also uses the same socket.
## 1. An output socket links to several inputs. ## 1. An output socket links to several inputs.
## 2. A multi-input socket links from several inputs. ## 2. A multi-input socket links from several inputs.
if socket_ptr in self.socket_ptr_refcount: if socket_ptr in self.socket_ptr_refcount:
self.socket_ptr_refcount[socket_ptr] += 1 self.socket_ptr_refcount[socket_ptr] += 1
# RefCount == 0: Create Socket Pointer w/Reference
## -> Also initialize the refcount for the socket pointer.
else: else:
## RefCount == 0: Add PTR, PTR -> REF
self.socket_ptrs.add(socket_ptr) self.socket_ptrs.add(socket_ptr)
self.socket_ptrs_as_sockets[socket_ptr] = bl_socket self.socket_ptrs_as_sockets[socket_ptr] = bl_socket
self.socket_ptr_refcount[socket_ptr] = 1 self.socket_ptr_refcount[socket_ptr] = 1
# Add Link PTR -> Socket PTR # Add Entry from Link Pointer -> Socket Pointer
self.link_ptrs_as_from_socket_ptrs[link_ptr] = from_socket_ptr self.link_ptrs_as_from_socket_ptrs[link_ptr] = from_socket_ptr
self.link_ptrs_as_to_socket_ptrs[link_ptr] = to_socket_ptr self.link_ptrs_as_to_socket_ptrs[link_ptr] = to_socket_ptr
return {'added': added_link_ptrs, 'removed': removed_link_ptrs} return {'added': added_link_ptrs, 'removed': removed_link_ptrs}
def update_validity(self) -> DeltaNodeLinkCache:
"""Query all cached links to determine whether they are valid."""
self.link_ptrs_invalid = {
link_ptr for link_ptr, link in self.link_ptrs_as_links if not link.is_valid
}
def report_validity(self, link_ptr: MemAddr, validity: bool) -> None:
"""Report a link as invalid."""
if validity and link_ptr in self.link_ptrs_invalid:
self.link_ptrs_invalid.remove(link_ptr)
elif not validity and link_ptr not in self.link_ptrs_invalid:
self.link_ptrs_invalid.add(link_ptr)
def set_validities(self) -> None:
"""Set the validity of links in the node tree according to the internal cache.
Validity doesn't need to be removed, as update() automatically cleans up by default.
"""
for link in [
link
for link_ptr, link in self.link_ptrs_as_links.items()
if link_ptr in self.link_ptrs_invalid
]:
if link.is_valid:
link.is_valid = False
#################### ####################
# - Node Tree Definition # - Node Tree Definition
#################### ####################
class MaxwellSimTree(bpy.types.NodeTree): class MaxwellSimTree(bpy.types.NodeTree):
"""Node tree containing a node-based program for design and analysis of Maxwell PDE simulations.
Attributes:
is_active: Whether the node tree should be considered to be in a usable state, capable of updating Blender data.
In general, only one `MaxwellSimTree` should be active at a time.
"""
bl_idname = ct.TreeType.MaxwellSim.value bl_idname = ct.TreeType.MaxwellSim.value
bl_label = 'Maxwell Sim Editor' bl_label = 'Maxwell Sim Editor'
bl_icon = ct.Icon.SimNodeEditor bl_icon = ct.Icon.SimNodeEditor
@ -219,63 +312,6 @@ class MaxwellSimTree(bpy.types.NodeTree):
default=True, default=True,
) )
####################
# - Lock Methods
####################
def unlock_all(self) -> None:
"""Unlock all nodes in the node tree, making them editable."""
log.info('Unlocking All Nodes in NodeTree "%s"', self.bl_label)
for node in self.nodes:
if node.type in ['REROUTE', 'FRAME']:
continue
node.locked = False
for bl_socket in [*node.inputs, *node.outputs]:
bl_socket.locked = False
@contextlib.contextmanager
def replot(self) -> None:
self.is_currently_replotting = True
self.something_plotted = False
try:
yield
finally:
self.is_currently_replotting = False
if not self.something_plotted:
ManagedBLImage.hide_preview()
def report_show_plot(self, node: bpy.types.Node) -> None:
if hasattr(self, 'is_currently_replotting') and self.is_currently_replotting:
self.something_plotted = True
@contextlib.contextmanager
def repreview_all(self) -> None:
all_nodes_with_preview_active = {
node.instance_id: node
for node in self.nodes
if node.type not in ['REROUTE', 'FRAME'] and node.preview_active
}
self.is_currently_repreviewing = True
self.newly_previewed_nodes = {}
try:
yield
finally:
self.is_currently_repreviewing = False
for dangling_previewed_node in [
node
for node_instance_id, node in all_nodes_with_preview_active.items()
if node_instance_id not in self.newly_previewed_nodes
]:
dangling_previewed_node.preview_active = False
def report_show_preview(self, node: bpy.types.Node) -> None:
if (
hasattr(self, 'is_currently_repreviewing')
and self.is_currently_repreviewing
):
self.newly_previewed_nodes[node.instance_id] = node
#################### ####################
# - Init Methods # - Init Methods
#################### ####################
@ -290,7 +326,54 @@ class MaxwellSimTree(bpy.types.NodeTree):
self.node_link_cache = NodeLinkCache(self) self.node_link_cache = NodeLinkCache(self)
#################### ####################
# - Update Methods # - Lock Methods
####################
def unlock_all(self) -> None:
"""Unlock all nodes in the node tree, making them editable.
Notes:
All `MaxwellSimNode`s have a `.locked` attribute, which prevents the entire UI from being modified.
This method simply sets the `locked` attribute to `False` on all nodes.
"""
log.info('Unlocking All Nodes in NodeTree "%s"', self.bl_label)
for node in self.nodes:
if node.type in ['REROUTE', 'FRAME']:
continue
# Unlock Node
if node.locked:
node.locked = False
# Unlock Node Sockets
for bl_socket in [*node.inputs, *node.outputs]:
if bl_socket.locked:
bl_socket.locked = False
####################
# - Link Update Methods
####################
def report_link_validity(self, link: bpy.types.NodeLink, validity: bool) -> None:
"""Report that a particular `NodeLink` should be considered to be either valid or invalid.
The `NodeLink.is_valid` attribute is generally (and automatically) used to indicate the detection of cycles in the node tree.
However, visually, it causes a very clear "error red" highlight to appear on the node link, which can extremely useful when determining the reasons behind unexpected outout.
Notes:
Run by `MaxwellSimSocket` when a link should be shown to be "invalid".
"""
## TODO: Doesn't quite work.
# log.debug(
# 'Reported Link Validity %s (is_valid=%s, from_socket=%s, to_socket=%s)',
# validity,
# link.is_valid,
# link.from_socket,
# link.to_socket,
# )
# self.node_link_cache.report_validity(link.as_pointer(), validity)
####################
# - Node Update Methods
#################### ####################
def on_node_removed(self, node: bpy.types.Node): def on_node_removed(self, node: bpy.types.Node):
"""Run by `MaxwellSimNode.free()` when a node is being removed. """Run by `MaxwellSimNode.free()` when a node is being removed.
@ -327,32 +410,36 @@ class MaxwellSimTree(bpy.types.NodeTree):
self.node_link_cache.remove_link(link_ptr) self.node_link_cache.remove_link(link_ptr)
self.node_link_cache.remove_sockets_by_link_ptr(link_ptr) self.node_link_cache.remove_sockets_by_link_ptr(link_ptr)
def update(self) -> None: def update(self) -> None: # noqa: PLR0912, C901
"""Monitors all changes to the node tree, potentially responding with appropriate callbacks. """Monitors all changes to the node tree, potentially responding with appropriate callbacks.
Notes: Notes:
- Run by Blender when "anything" changes in the node tree. - Run by Blender when "anything" changes in the node tree.
- Responds to node link changes with callbacks, with the help of a performant node link cache. - Responds to node link changes with callbacks, with the help of a performant node link cache.
""" """
# Perform Initial Load
## -> Presume update() is run before the first link is altered.
## -> Else, the first link of the session will not update caches.
## -> We still remain slightly unsure of the exact semantics.
## -> Therefore, self.on_load() is also called as a load_post handler.
if not hasattr(self, 'node_link_cache'):
self.on_load()
return
# Register Validity Updater
## -> They will be run after the update() method.
## -> Between update() and set_validities, all is_valid=True are cleared.
## -> Therefore, 'set_validities' only needs to set all is_valid=False.
bpy.app.timers.register(self.node_link_cache.set_validities)
# Ignore Updates
## -> Certain corrective processes require suppressing the next update.
## -> Otherwise, link corrections may trigger some nasty recursions.
if not hasattr(self, 'ignore_update'): if not hasattr(self, 'ignore_update'):
self.ignore_update = False self.ignore_update = False
if not hasattr(self, 'node_link_cache'): # Regenerate NodeLinkCache
self.on_load()
## We presume update() is run before the first link is altered.
## - Else, the first link of the session will not update caches.
## - We remain slightly unsure of the semantics.
## - Therefore, self.on_load() is also called as a load_post handler.
return
# Ignore Update
## Manually set to implement link corrections w/o recursion.
if self.ignore_update:
return
# Compute Changes to Node Links
delta_links = self.node_link_cache.regenerate() delta_links = self.node_link_cache.regenerate()
link_corrections = { link_corrections = {
'to_remove': [], 'to_remove': [],
'to_add': [], 'to_add': [],

View File

@ -358,6 +358,11 @@ class ExtractDataNode(base.MaxwellSimNode):
## -> Those string labels explain the integer as ex. Ex, Ey, Hy. ## -> Those string labels explain the integer as ex. Ex, Ey, Hy.
idx_labels = valid_monitor_attrs(sim_data, monitor_name) idx_labels = valid_monitor_attrs(sim_data, monitor_name)
# Extract Info
## -> We only need the output symbol.
## -> All labelled outputs have the same output SimSymbol.
info = extract_info(monitor_data, idx_labels[0])
# Generate FuncFlow Per Index Label # Generate FuncFlow Per Index Label
## -> We extract each XArray as an attribute of monitor_data. ## -> We extract each XArray as an attribute of monitor_data.
## -> We then bind its values into a unique func_flow. ## -> We then bind its values into a unique func_flow.
@ -377,7 +382,8 @@ class ExtractDataNode(base.MaxwellSimNode):
## -> Then, 'compose_within' lets us stack them along axis=0. ## -> Then, 'compose_within' lets us stack them along axis=0.
## -> The "new" axis=0 is int-indexed axis w/idx_labels labels! ## -> The "new" axis=0 is int-indexed axis w/idx_labels labels!
return functools.reduce(lambda a, b: a | b, func_flows).compose_within( return functools.reduce(lambda a, b: a | b, func_flows).compose_within(
enclosing_func=lambda data: jnp.stack(data, axis=0) lambda data: jnp.stack(data, axis=0),
func_output=info.output,
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending

View File

@ -65,12 +65,12 @@ class FilterOperation(enum.StrEnum):
FO = FilterOperation FO = FilterOperation
return { return {
# Slice # Slice
FO.Slice: '=a[i:j]', FO.Slice: '≈a[v₁:v₂]',
FO.SliceIdx: '≈a[v₁:v₂]', FO.SliceIdx: '=a[i:j]',
# Pin # Pin
FO.PinLen1: 'pinₐ', FO.PinLen1: 'a[0] → a',
FO.Pin: 'pinₐ ≈v', FO.Pin: 'a[v] ⇝ a',
FO.PinIdx: 'pinₐ =i', FO.PinIdx: 'a[i] → a',
# Reinterpret # Reinterpret
FO.Swap: 'a₁ ↔ a₂', FO.Swap: 'a₁ ↔ a₂',
}[value] }[value]
@ -517,6 +517,7 @@ class FilterMathNode(base.MaxwellSimNode):
return lazy_func.compose_within( return lazy_func.compose_within(
operation.jax_func(axis_0, axis_1, slice_tuple=slice_tuple), operation.jax_func(axis_0, axis_1, slice_tuple=slice_tuple),
enclosing_func_args=operation.func_args, enclosing_func_args=operation.func_args,
enclosing_func_output=info.output,
supports_jax=True, supports_jax=True,
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending

View File

@ -547,22 +547,30 @@ class MapMathNode(base.MaxwellSimNode):
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
# Loaded
kind=ct.FlowKind.Func, kind=ct.FlowKind.Func,
props={'operation'}, props={'operation'},
input_sockets={'Expr'}, input_sockets={'Expr'},
input_socket_kinds={ input_socket_kinds={
'Expr': ct.FlowKind.Func, 'Expr': ct.FlowKind.Func,
}, },
output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Info},
) )
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal: def compute_func(
operation = props['operation'] self, props, input_sockets, output_sockets
) -> ct.FuncFlow | ct.FlowSignal:
expr = input_sockets['Expr'] expr = input_sockets['Expr']
output_info = output_sockets['Expr']
has_expr = not ct.FlowSignal.check(expr) has_expr = not ct.FlowSignal.check(expr)
has_output_info = not ct.FlowSignal.check(output_info)
operation = props['operation']
if has_expr and operation is not None: if has_expr and operation is not None:
return expr.compose_within( return expr.compose_within(
operation.jax_func, operation.jax_func,
enclosing_func_output=output_info.output,
supports_jax=True, supports_jax=True,
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending

View File

@ -146,7 +146,6 @@ class BinaryOperation(enum.StrEnum):
outl = info_l.output outl = info_l.output
outr = info_r.output outr = info_r.output
match (outl.shape_len, outr.shape_len): match (outl.shape_len, outr.shape_len):
# match (ol.shape_len, info_r.output.shape_len):
# Number | * # Number | *
## Number | Number ## Number | Number
case (0, 0): case (0, 0):
@ -154,15 +153,25 @@ class BinaryOperation(enum.StrEnum):
BO.Add, BO.Add,
BO.Sub, BO.Sub,
BO.Mul, BO.Mul,
BO.Div,
BO.Pow,
] ]
# Check Non-Zero Right Hand Side
## -> Obviously, we can't ever divide by zero.
## -> Sympy's assumptions system must always guarantee rhs != 0.
## -> If it can't, then we simply don't expose division.
## -> The is_zero assumption must be provided elsewhere.
## -> NOTE: This may prevent some valid uses of division.
## -> Watch out for "division is missing" bugs.
if info_r.output.is_nonzero:
ops.append(BO.Div)
if ( if (
info_l.output.physical_type == spux.PhysicalType.Length info_l.output.physical_type == spux.PhysicalType.Length
and info_l.output.unit == info_r.output.unit and info_l.output.unit == info_r.output.unit
): ):
ops += [BO.Atan2] ops += [BO.Atan2]
return ops
return [*ops, BO.Pow]
## Number | Vector ## Number | Vector
case (0, 1): case (0, 1):
@ -336,7 +345,13 @@ class BinaryOperation(enum.StrEnum):
# - InfoFlow Transform # - InfoFlow Transform
#################### ####################
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow): def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
"""Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression.""" """Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression.
Warnings:
`self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r).
If not, bad things will happen.
"""
return info_l.operate_output( return info_l.operate_output(
info_r, info_r,
lambda a, b: self.sp_func([a, b]), lambda a, b: self.sp_func([a, b]),
@ -479,29 +494,35 @@ class OperateMathNode(base.MaxwellSimNode):
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Func, kind=ct.FlowKind.Func,
# Loaded
props={'operation'}, props={'operation'},
input_sockets={'Expr L', 'Expr R'}, input_sockets={'Expr L', 'Expr R'},
input_socket_kinds={ input_socket_kinds={
'Expr L': ct.FlowKind.Func, 'Expr L': ct.FlowKind.Func,
'Expr R': ct.FlowKind.Func, 'Expr R': ct.FlowKind.Func,
}, },
output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Info},
) )
def compose_func(self, props: dict, input_sockets: dict): def compute_func(self, props, input_sockets, output_sockets):
operation = props['operation'] operation = props['operation']
if operation is None: if operation is None:
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
expr_l = input_sockets['Expr L'] expr_l = input_sockets['Expr L']
expr_r = input_sockets['Expr R'] expr_r = input_sockets['Expr R']
output_info = output_sockets['Expr']
has_expr_l = not ct.FlowSignal.check(expr_l) has_expr_l = not ct.FlowSignal.check(expr_l)
has_expr_r = not ct.FlowSignal.check(expr_r) has_expr_r = not ct.FlowSignal.check(expr_r)
has_output_info = not ct.FlowSignal.check(output_info)
# Compute Jax Function # Compute Jax Function
## -> The operation enum directly provides the appropriate function. ## -> The operation enum directly provides the appropriate function.
if has_expr_l and has_expr_r: if has_expr_l and has_expr_r and has_output_info:
return (expr_l | expr_r).compose_within( return (expr_l | expr_r).compose_within(
enclosing_func=operation.jax_func, operation.jax_func,
enclosing_func_output=output_info.output,
supports_jax=True, supports_jax=True,
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
@ -520,6 +541,8 @@ class OperateMathNode(base.MaxwellSimNode):
}, },
) )
def compute_info(self, props, input_sockets) -> ct.InfoFlow: def compute_info(self, props, input_sockets) -> ct.InfoFlow:
BO = BinaryOperation
operation = props['operation'] operation = props['operation']
info_l = input_sockets['Expr L'] info_l = input_sockets['Expr L']
info_r = input_sockets['Expr R'] info_r = input_sockets['Expr R']
@ -533,7 +556,7 @@ class OperateMathNode(base.MaxwellSimNode):
has_info_l has_info_l
and has_info_r and has_info_r
and operation is not None and operation is not None
and operation in BinaryOperation.by_infos(info_l, info_r) and operation in BO.by_infos(info_l, info_r)
): ):
return operation.transform_infos(info_l, info_r) return operation.transform_infos(info_l, info_r)

View File

@ -606,27 +606,32 @@ class TransformMathNode(base.MaxwellSimNode):
input_socket_kinds={ input_socket_kinds={
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info}, 'Expr': {ct.FlowKind.Func, ct.FlowKind.Info},
}, },
output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Info},
) )
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal: def compute_func(
self, props, input_sockets, output_sockets
) -> ct.FuncFlow | ct.FlowSignal:
"""Transform the input `InfoFlow` depending on the transform operation.""" """Transform the input `InfoFlow` depending on the transform operation."""
TO = TransformOperation TO = TransformOperation
operation = props['operation']
lazy_func = input_sockets['Expr'][ct.FlowKind.Func] lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info] info = input_sockets['Expr'][ct.FlowKind.Info]
output_info = output_sockets['Expr']
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
has_lazy_func = not ct.FlowSignal.check(lazy_func) has_lazy_func = not ct.FlowSignal.check(lazy_func)
has_output_info = not ct.FlowSignal.check(output_info)
if operation is not None and has_lazy_func and has_info: operation = props['operation']
# Retrieve Properties if operation is not None and has_lazy_func and has_info and has_output_info:
dim = props['dim'] dim = props['dim']
# Match Pattern by Operation
match operation: match operation:
case TO.FreqToVacWL | TO.VacWLToFreq | TO.FT1D | TO.InvFT1D: case TO.FreqToVacWL | TO.VacWLToFreq | TO.FT1D | TO.InvFT1D:
if dim is not None and info.has_idx_discrete(dim): if dim is not None and info.has_idx_discrete(dim):
return lazy_func.compose_within( return lazy_func.compose_within(
operation.jax_func(axis=info.dim_axis(dim)), operation.jax_func(axis=info.dim_axis(dim)),
enclosing_func_output=output_info.output,
supports_jax=True, supports_jax=True,
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
@ -634,6 +639,7 @@ class TransformMathNode(base.MaxwellSimNode):
case _: case _:
return lazy_func.compose_within( return lazy_func.compose_within(
operation.jax_func(), operation.jax_func(),
enclosing_func_output=output_info.output,
supports_jax=True, supports_jax=True,
) )

View File

@ -406,7 +406,7 @@ class VizNode(base.MaxwellSimNode):
}, },
all_loose_input_sockets=True, all_loose_input_sockets=True,
) )
def compute_dummy_value(self, props, input_sockets, loose_input_sockets): def compute_previews(self, props, input_sockets, loose_input_sockets):
"""Needed for the plot to regenerate in the viewer.""" """Needed for the plot to regenerate in the viewer."""
return ct.PreviewsFlow(bl_image_name=props['sim_node_name']) return ct.PreviewsFlow(bl_image_name=props['sim_node_name'])
@ -433,7 +433,7 @@ class VizNode(base.MaxwellSimNode):
def on_show_plot( def on_show_plot(
self, managed_objs, props, input_sockets, loose_input_sockets self, managed_objs, props, input_sockets, loose_input_sockets
) -> None: ) -> None:
log.critical('Show Plot (too many times)') log.debug('Show Plot')
lazy_func = input_sockets['Expr'][ct.FlowKind.Func] lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info] info = input_sockets['Expr'][ct.FlowKind.Info]
params = input_sockets['Expr'][ct.FlowKind.Params] params = input_sockets['Expr'][ct.FlowKind.Params]
@ -456,6 +456,7 @@ class VizNode(base.MaxwellSimNode):
sym: loose_input_sockets[sym.name] for sym in params.sorted_symbols sym: loose_input_sockets[sym.name] for sym in params.sorted_symbols
}, },
) )
## TODO: CACHE entries that don't change, PLEASEEE
# Match Viz Type & Perform Visualization # Match Viz Type & Perform Visualization
## -> Viz Target determines how to plot. ## -> Viz Target determines how to plot.

View File

@ -207,12 +207,12 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
stop_propagation=True, stop_propagation=True,
) )
def _on_sim_node_name_changed(self, props): def _on_sim_node_name_changed(self, props):
log.debug( # log.debug(
'Changed Sim Node Name of a "%s" to "%s" (self=%s)', # 'Changed Sim Node Name of a "%s" to "%s" (self=%s)',
self.bl_idname, # self.bl_idname,
props['sim_node_name'], # props['sim_node_name'],
str(self), # str(self),
) # )
# (Re)Construct Managed Objects # (Re)Construct Managed Objects
## -> Due to 'prev_name', the new MObjs will be renamed on construction ## -> Due to 'prev_name', the new MObjs will be renamed on construction
@ -360,27 +360,48 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
#################### ####################
# - Socket Management # - Socket Management
#################### ####################
## TODO: Check for namespace collisions in sockets to prevent silent errors
def _prune_inactive_sockets(self): def _prune_inactive_sockets(self):
"""Remove all "inactive" sockets from the node. """Remove all inactive sockets from the node, while only updating sockets that can be non-destructively updated.
A socket is considered "inactive" when it shouldn't be defined (per `self.active_socket_defs), but is present nonetheless. The first step is easy: We determine, by-name, which sockets should no longer be defined, then remove them correctly.
The second step is harder: When new sockets have overlapping names, should they be removed, or should they merely have some properties updated?
Removing and re-adding the same socket is an accurate, generally robust approach, but it comes with a big caveat: **Existing node links will be cut**, even when it might semantically make sense to simply alter the socket's properties, keeping the links.
Different `bl_socket.socket_type`s can never be updated - they must be removed.
Otherwise, `SocketDef.compare(bl_socket)` allows us to granularly determine whether a particular `bl_socket` has changed with respect to the desired specification.
When the comparison is `False`, we can carefully utilize `SocketDef.init()` to re-initialize the socket, guaranteeing that the altered socket is up to the new specification.
""" """
node_tree = self.id_data node_tree = self.id_data
for direc in ['input', 'output']: for direc in ['input', 'output']:
all_bl_sockets = self._bl_sockets(direc) bl_sockets = self._bl_sockets(direc)
active_bl_socket_defs = self.active_socket_defs(direc) active_socket_defs = self.active_socket_defs(direc)
# Determine Sockets to Remove # Determine Sockets to Remove
## -> Name: If the existing socket name isn't "active".
## -> Type: If the existing socket_type != "active" SocketDef.
bl_sockets_to_remove = [ bl_sockets_to_remove = [
bl_socket bl_socket
for socket_name, bl_socket in all_bl_sockets.items() for socket_name, bl_socket in bl_sockets.items()
if socket_name not in active_bl_socket_defs if (
or socket_name socket_name not in active_socket_defs
in ( or bl_socket.socket_type
self.loose_input_sockets is not active_socket_defs[socket_name].socket_type
if direc == 'input' )
else self.loose_output_sockets ]
# Determine Sockets to Update
## -> Name: If the existing socket name is "active".
## -> Type: If the existing socket_type == "active" SocketDef.
## -> Compare: If the existing socket differs from the SocketDef.
bl_sockets_to_update = [
bl_socket
for socket_name, bl_socket in bl_sockets.items()
if (
socket_name in active_socket_defs
and bl_socket.socket_type
is active_socket_defs[socket_name].socket_type
and not active_socket_defs[socket_name].compare(bl_socket)
) )
] ]
@ -392,24 +413,25 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
## -> The NodeLinkCache needs to be adjusted manually. ## -> The NodeLinkCache needs to be adjusted manually.
node_tree.on_node_socket_removed(bl_socket) node_tree.on_node_socket_removed(bl_socket)
# 2. Invalidate the input socket cache across all kinds. # 2. Perform the removal using Blender's API.
## -> Actually removes the socket.
bl_sockets.remove(bl_socket)
# 3. Invalidate the input socket cache across all kinds.
## -> Prevents phantom values from remaining available. ## -> Prevents phantom values from remaining available.
## -> Done after socket removal to protect from race condition.
self._compute_input.invalidate( self._compute_input.invalidate(
input_socket_name=bl_socket_name, input_socket_name=bl_socket_name,
kind=..., kind=...,
unit_system=..., unit_system=...,
) )
# 3. Perform the removal using Blender's API.
## -> Actually removes the socket.
all_bl_sockets.remove(bl_socket)
if direc == 'input': if direc == 'input':
# 4. Run all trigger-only `on_value_changed` callbacks. # 4. Run all trigger-only `on_value_changed` callbacks.
## -> Runs any event methods that relied on the socket. ## -> Runs any event methods that relied on the socket.
## -> Only methods that don't **require** the socket. ## -> Only methods that don't **require** the socket.
## Trigger-Only: If method loads no socket data, it runs. ## Only Trigger: If method loads no socket data, it runs.
## `optional`: If method optional-loads socket, it runs. ## Optional: If method optional-loads socket, it runs.
triggered_event_methods = [ triggered_event_methods = [
event_method event_method
for event_method in self.filtered_event_methods_by_event( for event_method in self.filtered_event_methods_by_event(
@ -419,32 +441,52 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
not in event_method.callback_info.must_load_sockets not in event_method.callback_info.must_load_sockets
] ]
for event_method in triggered_event_methods: for event_method in triggered_event_methods:
log.critical(
'%s: Running %s',
self.sim_node_name,
str(event_method),
)
event_method(self) event_method(self)
# Update Sockets
for bl_socket in bl_sockets_to_update:
bl_socket_name = bl_socket.name
socket_def = active_socket_defs[bl_socket_name]
# 1. Pretend to Initialize for the First Time
## -> NOTE: The socket's caches will be completely regenerated.
## -> NOTE: A full FlowKind update will occur, but only one.
bl_socket.is_initializing = True
socket_def.preinit(bl_socket)
socket_def.init(bl_socket)
socket_def.postinit(bl_socket)
# 2. Re-Test Socket Capabilities
## -> Factors influencing CapabilitiesFlow may have changed.
## -> Therefore, we must re-test all link capabilities.
bl_socket.remove_invalidated_links()
# 3. Invalidate the input socket cache across all kinds.
## -> Prevents phantom values from remaining available.
self._compute_input.invalidate(
input_socket_name=bl_socket_name,
kind=...,
unit_system=...,
)
def _add_new_active_sockets(self): def _add_new_active_sockets(self):
"""Add and initialize all "active" sockets that aren't on the node. """Add and initialize all "active" sockets that aren't on the node.
Existing sockets within the given direction are not re-created. Existing sockets within the given direction are not re-created.
""" """
for direc in ['input', 'output']: for direc in ['input', 'output']:
all_bl_sockets = self._bl_sockets(direc) bl_sockets = self._bl_sockets(direc)
active_bl_socket_defs = self.active_socket_defs(direc) active_socket_defs = self.active_socket_defs(direc)
# Define BL Sockets # Define BL Sockets
created_sockets = {} created_sockets = {}
for socket_name, socket_def in active_bl_socket_defs.items(): for socket_name, socket_def in active_socket_defs.items():
# Skip Existing Sockets # Skip Existing Sockets
if socket_name in all_bl_sockets: if socket_name in bl_sockets:
continue continue
# Create BL Socket from Socket # Create BL Socket from Socket
## Set 'display_shape' from 'socket_shape' bl_sockets.new(
all_bl_sockets.new(
str(socket_def.socket_type.value), str(socket_def.socket_type.value),
socket_name, socket_name,
) )
@ -454,9 +496,9 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
# Initialize Just-Created BL Sockets # Initialize Just-Created BL Sockets
for bl_socket_name, socket_def in created_sockets.items(): for bl_socket_name, socket_def in created_sockets.items():
socket_def.preinit(all_bl_sockets[bl_socket_name]) socket_def.preinit(bl_sockets[bl_socket_name])
socket_def.init(all_bl_sockets[bl_socket_name]) socket_def.init(bl_sockets[bl_socket_name])
socket_def.postinit(all_bl_sockets[bl_socket_name]) socket_def.postinit(bl_sockets[bl_socket_name])
# Invalidate Cached NoFlows # Invalidate Cached NoFlows
self._compute_input.invalidate( self._compute_input.invalidate(
@ -637,9 +679,10 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
lambda a, b: a | b, lambda a, b: a | b,
[ [
self._compute_input( self._compute_input(
socket, kind=ct.FlowKind.Previews, unit_system=None socket_name,
kind=ct.FlowKind.Previews,
) )
for socket in [bl_socket.name for bl_socket in self.inputs] for socket_name in [bl_socket.name for bl_socket in self.inputs]
], ],
ct.PreviewsFlow(), ct.PreviewsFlow(),
) )
@ -897,9 +940,19 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
) )
altered_socket_kinds[dep_out_sckname].add(dep_out_kind) altered_socket_kinds[dep_out_sckname].add(dep_out_kind)
# Clear Output Socket Cache(s)
## -> We aggregate it manually, so it needs a special invl.
## -> See self.compute_output()
if socket_kinds is not None and ct.FlowKind.Previews in socket_kinds:
for out_sckname in self.outputs.keys(): # noqa: SIM118
self.compute_output.invalidate(
output_socket_name=out_sckname,
kind=ct.FlowKind.Previews,
)
altered_socket_kinds[out_sckname].add(ct.FlowKind.Previews)
# Run Triggered Event Methods # Run Triggered Event Methods
## -> A triggered event method may request to stop propagation. ## -> A triggered event method may request to stop propagation.
## -> A triggered event method may request to stop propagation.
stop_propagation = False stop_propagation = False
triggered_event_methods = self.filtered_event_methods_by_event( triggered_event_methods = self.filtered_event_methods_by_event(
event, (socket_name, prop_names, None) event, (socket_name, prop_names, None)

View File

@ -266,7 +266,7 @@ def event_decorator( # noqa: PLR0913
) )
# Loose Sockets # Loose Sockets
## Compute All Loose Input Sockets ## -> Determined by the active_kind of each loose input socket.
method_kw_args |= ( method_kw_args |= (
{ {
'loose_input_sockets': { 'loose_input_sockets': {

View File

@ -29,6 +29,8 @@ from ... import base, events
class ScientificConstantNode(base.MaxwellSimNode): class ScientificConstantNode(base.MaxwellSimNode):
"""A well-known constant usable as itself, or as a symbol."""
node_type = ct.NodeType.ScientificConstant node_type = ct.NodeType.ScientificConstant
bl_label = 'Scientific Constant' bl_label = 'Scientific Constant'
@ -88,6 +90,11 @@ class ScientificConstantNode(base.MaxwellSimNode):
#################### ####################
# - UI # - UI
#################### ####################
def draw_label(self):
if self.sci_constant_str:
return f'Const: {self.sci_constant_str}'
return self.bl_label
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None: def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
col.prop(self, self.blfields['sci_constant_str'], text='') col.prop(self, self.blfields['sci_constant_str'], text='')
@ -156,6 +163,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
props={'sci_constant', 'sci_constant_sym'}, props={'sci_constant', 'sci_constant_sym'},
) )
def compute_lazy_func(self, props) -> typ.Any: def compute_lazy_func(self, props) -> typ.Any:
"""Simple `FuncFlow` that computes the symbol value, with output units tracked correctly."""
sci_constant = props['sci_constant'] sci_constant = props['sci_constant']
sci_constant_sym = props['sci_constant_sym'] sci_constant_sym = props['sci_constant_sym']
@ -165,6 +173,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
[sci_constant_sym.sp_symbol], sci_constant_sym.sp_symbol, 'jax' [sci_constant_sym.sp_symbol], sci_constant_sym.sp_symbol, 'jax'
), ),
func_args=[sci_constant_sym], func_args=[sci_constant_sym],
func_output=sci_constant_sym,
supports_jax=True, supports_jax=True,
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
@ -175,6 +184,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
props={'sci_constant_sym'}, props={'sci_constant_sym'},
) )
def compute_info(self, props: dict) -> typ.Any: def compute_info(self, props: dict) -> typ.Any:
"""Simple `FuncFlow` that computes the symbol value, with output units tracked correctly."""
sci_constant_sym = props['sci_constant_sym'] sci_constant_sym = props['sci_constant_sym']
if sci_constant_sym is not None: if sci_constant_sym is not None:
@ -193,8 +203,12 @@ class ScientificConstantNode(base.MaxwellSimNode):
if sci_constant is not None and sci_constant_sym is not None: if sci_constant is not None and sci_constant_sym is not None:
return ct.ParamsFlow( return ct.ParamsFlow(
arg_targets=[sci_constant_sym], arg_targets=[sci_constant_sym],
func_args=[sci_constant], func_args=[sci_constant_sym.sp_symbol],
is_differentiable=True, symbols={sci_constant_sym},
).realize_partial(
{
sci_constant_sym: sci_constant,
}
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending

View File

@ -216,10 +216,11 @@ class SymbolConstantNode(base.MaxwellSimNode):
props={'symbol'}, props={'symbol'},
) )
def compute_lazy_func(self, props) -> typ.Any: def compute_lazy_func(self, props) -> typ.Any:
sp_sym = props['symbol'].sp_symbol sym = props['symbol']
return ct.FuncFlow( return ct.FuncFlow(
func=sp.lambdify(sp_sym, sp_sym, 'jax'), func=sp.lambdify(sym.sp_symbol_matsym, sym.sp_symbol_matsym, 'jax'),
func_args=[sp_sym], func_args=[sym],
func_output=sym,
supports_jax=True, supports_jax=True,
) )
@ -235,6 +236,7 @@ class SymbolConstantNode(base.MaxwellSimNode):
) )
def compute_info(self, props) -> typ.Any: def compute_info(self, props) -> typ.Any:
return ct.InfoFlow( return ct.InfoFlow(
dims={props['symbol']: None},
output=props['symbol'], output=props['symbol'],
) )
@ -251,9 +253,6 @@ class SymbolConstantNode(base.MaxwellSimNode):
arg_targets=[sym], arg_targets=[sym],
func_args=[sym.sp_symbol], func_args=[sym.sp_symbol],
symbols={sym}, symbols={sym},
is_differentiable=(
sym.mathtype in [spux.MathType.Real, spux.MathType.Complex]
),
) )

View File

@ -198,9 +198,10 @@ class DataFileImporterNode(base.MaxwellSimNode):
'Expr', 'Expr',
kind=ct.FlowKind.Func, kind=ct.FlowKind.Func,
# Loaded # Loaded
props={'output_name', 'output_mathtype', 'output_physical_type', 'output_unit'},
input_sockets={'File Path'}, input_sockets={'File Path'},
) )
def compute_func(self, input_sockets) -> td.Simulation: def compute_func(self, props, input_sockets) -> td.Simulation:
"""Declare a lazy, composable function that returns the loaded data. """Declare a lazy, composable function that returns the loaded data.
Returns: Returns:
@ -209,6 +210,12 @@ class DataFileImporterNode(base.MaxwellSimNode):
file_path = input_sockets['File Path'] file_path = input_sockets['File Path']
has_file_path = not ct.FlowSignal.check(file_path) has_file_path = not ct.FlowSignal.check(file_path)
func_output = sim_symbols.SimSymbol(
sym_name=props['output_name'],
mathtype=props['output_mathtype'],
physical_type=props['output_physical_type'],
unit=props['output_unit'],
)
if has_file_path and file_path is not None: if has_file_path and file_path is not None:
data_file_format = ct.DataFileFormat.from_path(file_path) data_file_format = ct.DataFileFormat.from_path(file_path)
if data_file_format is not None: if data_file_format is not None:
@ -217,13 +224,18 @@ class DataFileImporterNode(base.MaxwellSimNode):
if data_file_format.loader_is_jax_compatible: if data_file_format.loader_is_jax_compatible:
return ct.FuncFlow( return ct.FuncFlow(
func=lambda: data_file_format.loader(file_path), func=lambda: data_file_format.loader(file_path),
func_output=func_output,
supports_jax=True, supports_jax=True,
) )
# No Jax Compatibility: Eager Data Loading # No Jax Compatibility: Eager Data Loading
## -> Load the data now and bind it. ## -> Load the data now and bind it.
data = data_file_format.loader(file_path) data = data_file_format.loader(file_path)
return ct.FuncFlow(func=lambda: data, supports_jax=True) return ct.FuncFlow(
func=lambda: data,
func_output=func_output,
supports_jax=True,
)
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending

View File

@ -86,22 +86,45 @@ class ViewerNode(base.MaxwellSimNode):
# - Properties: Computed FlowKinds # - Properties: Computed FlowKinds
#################### ####################
@events.on_value_changed( @events.on_value_changed(
socket_name='Any', # Trigger
prop_name='console_print_kind',
# Loaded
props={'auto_expr', 'console_print_kind'},
) )
def on_input_changed(self) -> None: def on_print_kind_changed(self, props) -> None:
self.inputs['Any'].active_kind = props['console_print_kind']
if props['auto_expr']:
setattr(
self,
'input_' + props['console_print_kind'].property_name,
bl_cache.Signal.InvalidateCache,
)
@events.on_value_changed(
# Trugger
socket_name='Any',
# Loaded
props={'auto_expr', 'console_print_kind'},
)
def on_input_changed(self, props) -> None:
"""Lightweight invalidator, which invalidates the more specific `cached_bl_property` used to determine when something ex. plot-related has changed. """Lightweight invalidator, which invalidates the more specific `cached_bl_property` used to determine when something ex. plot-related has changed.
Calls `get_flow`, which will be called again when regenerating the `cached_bl_property`s. Calls `get_flow`, which will be called again when regenerating the `cached_bl_property`s.
This **does not** call the flow twice, as `self._compute_input()` will be cached the first time. This **does not** call the flow twice, as `self._compute_input()` will be cached the first time.
""" """
for flow_kind in list(ct.FlowKind): # Invalidate PreviewsFlow
flow = self.get_flow(
flow_kind, always_load=flow_kind is ct.FlowKind.Previews
)
if flow is not None:
setattr( setattr(
self, self,
'input_' + flow_kind.property_name, 'input_' + ct.FlowKind.Previews.property_name,
bl_cache.Signal.InvalidateCache,
)
# Invalidate PreviewsFlow
if props['auto_expr']:
setattr(
self,
'input_' + props['console_print_kind'].property_name,
bl_cache.Signal.InvalidateCache, bl_cache.Signal.InvalidateCache,
) )

View File

@ -14,8 +14,10 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import functools
import typing as typ import typing as typ
import bpy
import sympy as sp import sympy as sp
from blender_maxwell.utils import bl_cache from blender_maxwell.utils import bl_cache
@ -26,6 +28,8 @@ from .. import base, events
class CombineNode(base.MaxwellSimNode): class CombineNode(base.MaxwellSimNode):
"""Combine single objects (ex. Source, Monitor, Structure) into a list."""
node_type = ct.NodeType.Combine node_type = ct.NodeType.Combine
bl_label = 'Combine' bl_label = 'Combine'
@ -33,112 +37,222 @@ class CombineNode(base.MaxwellSimNode):
# - Sockets # - Sockets
#################### ####################
input_socket_sets: typ.ClassVar = { input_socket_sets: typ.ClassVar = {
'Maxwell Sources': {}, 'Sources': {},
'Maxwell Structures': {}, 'Structures': {},
'Maxwell Monitors': {}, 'Monitors': {},
} }
output_socket_sets: typ.ClassVar = { output_socket_sets: typ.ClassVar = {
'Maxwell Sources': { 'Sources': {
'Sources': sockets.MaxwellSourceSocketDef( 'Sources': sockets.MaxwellSourceSocketDef(
is_list=True, active_kind=ct.FlowKind.Array,
), ),
}, },
'Maxwell Structures': { 'Structures': {
'Structures': sockets.MaxwellStructureSocketDef( 'Structures': sockets.MaxwellStructureSocketDef(
is_list=True, active_kind=ct.FlowKind.Array,
), ),
}, },
'Maxwell Monitors': { 'Monitors': {
'Monitors': sockets.MaxwellMonitorSocketDef( 'Monitors': sockets.MaxwellMonitorSocketDef(
is_list=True, active_kind=ct.FlowKind.Array,
), ),
}, },
} }
#################### ####################
# - Draw # - Properties
#################### ####################
amount: int = bl_cache.BLField(2, abs_min=1, prop_ui=True) concatenate_first: bool = bl_cache.BLField(False)
value_or_func: ct.FlowKind = bl_cache.BLField(
enum_cb=lambda self, _: self._value_or_func(),
)
def _value_or_func(self):
return [
flow_kind.bl_enum_element(i)
for i, flow_kind in enumerate([ct.FlowKind.Value, ct.FlowKind.Func])
]
#################### ####################
# - Draw # - Draw
#################### ####################
def draw_props(self, context, layout): def draw_props(self, _, layout: bpy.types.UILayout):
layout.prop(self, self.blfields['amount'], text='') layout.prop(self, self.blfields['value_or_func'], text='')
if self.value_or_func is ct.FlowKind.Value:
layout.prop(
self,
self.blfields['concatenate_first'],
text='Concatenate',
toggle=True,
)
#################### ####################
# - Events # - Events
#################### ####################
@events.on_value_changed( @events.on_value_changed(
# Trigger any_loose_input_socket=True,
prop_name={'active_socket_set', 'amount'}, prop_name={'active_socket_set', 'concatenate_first', 'value_or_func'},
props={'active_socket_set', 'amount'},
run_on_init=True, run_on_init=True,
# Loaded
props={'active_socket_set', 'concatenate_first', 'value_or_func'},
) )
def on_inputs_changed(self, props): def on_inputs_changed(self, props) -> None:
if props['active_socket_set'] == 'Maxwell Sources': """Always create one extra loose input socket."""
if ( active_socket_set = props['active_socket_set']
not self.loose_input_sockets
or not next(iter(self.loose_input_sockets)).startswith('Source') # Deduce SocketDef
or len(self.loose_input_sockets) != props['amount'] ## -> Cheat by retrieving the class from the output sockets.
): SocketDef = self.output_socket_sets[active_socket_set][
self.loose_input_sockets = { active_socket_set
f'Source #{i}': sockets.MaxwellSourceSocketDef() ].__class__
for i in range(props['amount'])
} # Deduce Current "Filled"
## -> The first linked socket from the end bounds the "filled" region.
## -> The length of that region, plus one, will be the new amount.
reverse_linked_idxs = [
i
for i, bl_socket in enumerate(reversed(self.inputs.values()))
if bl_socket.is_linked
]
current_filled = len(self.inputs) - (
reverse_linked_idxs[0] if reverse_linked_idxs else len(self.inputs)
)
new_amount = current_filled + 1
# Deduce SocketDef | Current Amount
concatenate_first = props['concatenate_first']
flow_kind = props['value_or_func']
elif props['active_socket_set'] == 'Maxwell Structures':
if (
not self.loose_input_sockets
or not next(iter(self.loose_input_sockets)).startswith('Structure')
or len(self.loose_input_sockets) != props['amount']
):
self.loose_input_sockets = { self.loose_input_sockets = {
f'Structure #{i}': sockets.MaxwellStructureSocketDef() '#0': SocketDef(
for i in range(props['amount']) active_kind=flow_kind
} if flow_kind is ct.FlowKind.Func or not concatenate_first
elif props['active_socket_set'] == 'Maxwell Monitors': else ct.FlowKind.Array
if ( )
not self.loose_input_sockets } | {f'#{i}': SocketDef(active_kind=flow_kind) for i in range(1, new_amount)}
or not next(iter(self.loose_input_sockets)).startswith('Monitor')
or len(self.loose_input_sockets) != props['amount']
):
self.loose_input_sockets = {
f'Monitor #{i}': sockets.MaxwellMonitorSocketDef()
for i in range(props['amount'])
}
elif self.loose_input_sockets:
self.loose_input_sockets = {}
#################### ####################
# - Output Socket Computation # - FlowKind.Array|Func
####################
def compute_combined(
self,
loose_input_sockets,
input_flow_kind: typ.Literal[ct.FlowKind.Value, ct.FlowKind.Func],
output_flow_kind: typ.Literal[ct.FlowKind.Array, ct.FlowKind.Func],
) -> list[typ.Any] | ct.FuncFlow | ct.FlowSignal:
"""Correctly compute the combined loose input sockets, given a valid combination of input and output `FlowKind`s.
If there is no output, or the flows aren't compatible, return `FlowPending`.
"""
match (input_flow_kind, output_flow_kind):
case (ct.FlowKind.Value, ct.FlowKind.Array):
value_flows = [
inp
for inp in loose_input_sockets.values()
if not ct.FlowSignal.check(inp)
]
if value_flows:
return value_flows
return ct.FlowSignal.FlowPending
case (ct.FlowKind.Func, ct.FlowKind.Func):
func_flows = [
inp
for inp in loose_input_sockets.values()
if not ct.FlowSignal.check(inp)
]
if func_flows:
return functools.reduce(
lambda a, b: a | b,
func_flows,
)
return ct.FlowSignal.FlowPending
return ct.FlowSignal.FlowPending
####################
# - Output: Sources
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Sources', 'Sources',
kind=ct.FlowKind.Array, kind=ct.FlowKind.Array,
all_loose_input_sockets=True, all_loose_input_sockets=True,
props={'amount'}, props={'value_or_func'},
)
def compute_sources_array(
self, props, loose_input_sockets
) -> list[typ.Any] | ct.FlowSignal:
"""Compute sources."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
) )
def compute_sources(self, loose_input_sockets, props) -> sp.Expr:
return [loose_input_sockets[f'Source #{i}'] for i in range(props['amount'])]
@events.computes_output_socket(
'Sources',
kind=ct.FlowKind.Func,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_sources_func(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) sources."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
)
####################
# - Output: Structures
####################
@events.computes_output_socket( @events.computes_output_socket(
'Structures', 'Structures',
kind=ct.FlowKind.Array, kind=ct.FlowKind.Array,
all_loose_input_sockets=True, all_loose_input_sockets=True,
props={'amount'}, props={'value_or_func'},
)
def compute_structures_array(self, props, loose_input_sockets) -> sp.Expr:
"""Compute structures."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
) )
def compute_structures(self, loose_input_sockets, props) -> sp.Expr:
return [loose_input_sockets[f'Structure #{i}'] for i in range(props['amount'])]
@events.computes_output_socket(
'Structures',
kind=ct.FlowKind.Func,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_structures_func(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) structures."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
)
####################
# - Output: Monitors
####################
@events.computes_output_socket( @events.computes_output_socket(
'Monitors', 'Monitors',
kind=ct.FlowKind.Array, kind=ct.FlowKind.Array,
all_loose_input_sockets=True, all_loose_input_sockets=True,
props={'amount'}, props={'value_or_func'},
)
def compute_monitors_array(self, props, loose_input_sockets) -> sp.Expr:
"""Compute monitors."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
)
@events.computes_output_socket(
'Monitors',
kind=ct.FlowKind.Func,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_monitors_func(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) monitors."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
) )
def compute_monitors(self, loose_input_sockets, props) -> sp.Expr:
return [loose_input_sockets[f'Monitor #{i}'] for i in range(props['amount'])]
#################### ####################

View File

@ -14,17 +14,26 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `FDTDSimNode`."""
import typing as typ import typing as typ
import sympy as sp import bpy
import tidy3d as td import tidy3d as td
import tidy3d.plugins.adjoint as tdadj
from blender_maxwell.utils import bl_cache, logger
from ... import contracts as ct from ... import contracts as ct
from ... import sockets from ... import sockets
from .. import base, events from .. import base, events
log = logger.get(__name__)
class FDTDSimNode(base.MaxwellSimNode): class FDTDSimNode(base.MaxwellSimNode):
"""Definition of a complete FDTD simulation, including boundary conditions, domain, sources, structures, monitors, and other configuration."""
node_type = ct.NodeType.FDTDSim node_type = ct.NodeType.FDTDSim
bl_label = 'FDTD Simulation' bl_label = 'FDTD Simulation'
@ -35,51 +44,255 @@ class FDTDSimNode(base.MaxwellSimNode):
'BCs': sockets.MaxwellBoundCondsSocketDef(), 'BCs': sockets.MaxwellBoundCondsSocketDef(),
'Domain': sockets.MaxwellSimDomainSocketDef(), 'Domain': sockets.MaxwellSimDomainSocketDef(),
'Sources': sockets.MaxwellSourceSocketDef( 'Sources': sockets.MaxwellSourceSocketDef(
is_list=True, active_kind=ct.FlowKind.Array,
), ),
'Structures': sockets.MaxwellStructureSocketDef( 'Structures': sockets.MaxwellStructureSocketDef(
is_list=True, active_kind=ct.FlowKind.Array,
), ),
'Monitors': sockets.MaxwellMonitorSocketDef( 'Monitors': sockets.MaxwellMonitorSocketDef(
is_list=True, active_kind=ct.FlowKind.Array,
), ),
} }
output_sockets: typ.ClassVar = { output_socket_sets: typ.ClassVar = {
'Sim': sockets.MaxwellFDTDSimSocketDef(), 'Single': {
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Value),
},
'Batch': {
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Array),
},
'Lazy': {
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Func),
},
} }
#################### ####################
# - Output Socket Computation # - Properties
####################
differentiable: bool = bl_cache.BLField(False)
####################
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
layout.prop(
self,
self.blfields['differentiable'],
text='Differentiable',
toggle=True,
)
####################
# - Events
####################
@events.on_value_changed(
# Trigger
socket_name={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
run_on_init=True,
# Loaded
props={'active_socket_set'},
output_sockets={'Sim'},
output_socket_kinds={'Sim': ct.FlowKind.Params},
)
def on_any_changed(self, props, output_sockets) -> None:
"""Create loose input sockets."""
params = output_sockets['Sim']
has_params = not ct.FlowSignal.check(params)
# Declare Loose Sockets that Realize Symbols
## -> This happens if Params contains not-yet-realized symbols.
active_socket_set = props['active_socket_set']
if active_socket_set in ['Value', 'Batch'] and has_params and params.symbols:
if set(self.loose_input_sockets) != {sym.name for sym in params.symbols}:
self.loose_input_sockets = {
sym.name: sockets.ExprSocketDef(
**(
expr_info
| {
'active_kind': ct.FlowKind.Value,
'use_value_range_swapper': (
active_socket_set == 'Value'
),
}
)
)
for sym, expr_info in params.sym_expr_infos.items()
}
elif self.loose_input_sockets:
self.loose_input_sockets = {}
####################
# - FlowKind.Value
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Sim', 'Sim',
kind=ct.FlowKind.Value, kind=ct.FlowKind.Value,
# Loaded
props={'differentiable'},
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'}, input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
input_socket_kinds={ input_socket_kinds={
'Sources': ct.FlowKind.Array, 'Sources': ct.FlowKind.Array,
'Structures': ct.FlowKind.Array, 'Structures': ct.FlowKind.Array,
'Domain': ct.FlowKind.Value,
'BCs': ct.FlowKind.Value,
'Monitors': ct.FlowKind.Array, 'Monitors': ct.FlowKind.Array,
}, },
output_sockets={'Sim'},
output_socket_kinds={'Sim': ct.FlowKind.Params},
) )
def compute_fdtd_sim(self, input_sockets: dict) -> sp.Expr: def compute_fdtd_sim_value(
if any(ct.FlowSignal.check(inp) for inp in input_sockets): self, props, input_sockets, output_sockets
return ct.FlowSignal.FlowPending ) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
"""Compute a single FDTD simulation definition, so long as the inputs are neither symbolic or differentiable."""
sim_domain = input_sockets['Domain'] sim_domain = input_sockets['Domain']
sources = input_sockets['Sources'] sources = input_sockets['Sources']
structures = input_sockets['Structures'] structures = input_sockets['Structures']
bounds = input_sockets['BCs'] bounds = input_sockets['BCs']
monitors = input_sockets['Monitors'] monitors = input_sockets['Monitors']
output_params = output_sockets['Sim']
has_sim_domain = not ct.FlowSignal.check(sim_domain)
has_sources = not ct.FlowSignal.check(sources)
has_structures = not ct.FlowSignal.check(structures)
has_bounds = not ct.FlowSignal.check(bounds)
has_monitors = not ct.FlowSignal.check(monitors)
has_output_params = not ct.FlowSignal.check(output_params)
differentiable = props['differentiable']
if (
has_sim_domain
and has_sources
and has_structures
and has_bounds
and has_monitors
and has_output_params
and not differentiable
):
return td.Simulation( return td.Simulation(
**sim_domain, **sim_domain,
structures=structures,
sources=sources, sources=sources,
monitors=monitors, structures=structures,
boundary_spec=bounds, boundary_spec=bounds,
monitors=monitors,
) )
## TODO: Visualize the boundary conditions on top of the sim domain return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Sim',
kind=ct.FlowKind.Func,
# Loaded
props={'differentiable'},
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
input_socket_kinds={
'Sources': ct.FlowKind.Func,
'Structures': ct.FlowKind.Func,
'Monitors': ct.FlowKind.Func,
},
output_sockets={'Sim'},
output_socket_kinds={'Sim': ct.FlowKind.Params},
)
def compute_fdtd_sim_func(
self, props, input_sockets, output_sockets
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
"""Compute a single simulation, given that all inputs are non-symbolic."""
sim_domain = input_sockets['Domain']
sources = input_sockets['Sources']
structures = input_sockets['Structures']
bounds = input_sockets['BCs']
monitors = input_sockets['Monitors']
output_params = output_sockets['Sim']
has_sim_domain = not ct.FlowSignal.check(sim_domain)
has_sources = not ct.FlowSignal.check(sources)
has_structures = not ct.FlowSignal.check(structures)
has_bounds = not ct.FlowSignal.check(bounds)
has_monitors = not ct.FlowSignal.check(monitors)
has_output_params = not ct.FlowSignal.check(output_params)
if (
has_sim_domain
and has_sources
and has_structures
and has_bounds
and has_monitors
and has_output_params
):
differentiable = props['differentiable']
if differentiable:
return (
sim_domain | sources | structures | bounds | monitors
).compose_within(
enclosing_func=lambda els: tdadj.JaxSimulation(
**els[0],
sources=els[1],
structures=els[2]['static'],
input_structures=els[2]['differentiable'],
boundary_spec=els[3],
monitors=els[4]['static'],
output_monitors=els[4]['differentiable'],
),
supports_jax=True,
)
return (
sim_domain | sources | structures | bounds | monitors
).compose_within(
enclosing_func=lambda els: td.Simulation(
**els[0],
sources=els[1],
structures=els[2],
boundary_spec=els[3],
monitors=els[4],
),
supports_jax=False,
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Sim',
kind=ct.FlowKind.Params,
# Loaded
props={'differentiable'},
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
input_socket_kinds={
'Sources': ct.FlowKind.Params,
'Structures': ct.FlowKind.Params,
'Monitors': ct.FlowKind.Params,
},
)
def compute_fdtd_sim_params(
self, props, input_sockets
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
"""Compute a single simulation, given that all inputs are non-symbolic."""
sim_domain = input_sockets['Domain']
sources = input_sockets['Sources']
structures = input_sockets['Structures']
bounds = input_sockets['BCs']
monitors = input_sockets['Monitors']
has_sim_domain = not ct.FlowSignal.check(sim_domain)
has_sources = not ct.FlowSignal.check(sources)
has_structures = not ct.FlowSignal.check(structures)
has_bounds = not ct.FlowSignal.check(bounds)
has_monitors = not ct.FlowSignal.check(monitors)
if (
has_sim_domain
and has_sources
and has_structures
and has_bounds
and has_monitors
):
# Determine Differentiable Match
## -> 'structures' is diff when **any** are diff.
## -> 'monitors' is also diff when **any** are diff.
## -> Only parameters through diff structs can be diff'ed by.
## -> Similarly, only diff monitors will have gradients computed.
return sim_domain | sources | structures | bounds | monitors
return ct.FlowSignal.FlowPending
#################### ####################

View File

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `SimDomainNode`."""
import typing as typ import typing as typ
import sympy as sp import sympy as sp
@ -31,6 +33,8 @@ log = logger.get(__name__)
class SimDomainNode(base.MaxwellSimNode): class SimDomainNode(base.MaxwellSimNode):
"""The domain of a simulation in space and time, including bounds, discretization strategy, and the ambient medium."""
node_type = ct.NodeType.SimDomain node_type = ct.NodeType.SimDomain
bl_label = 'Sim Domain' bl_label = 'Sim Domain'
use_sim_node_name = True use_sim_node_name = True
@ -69,26 +73,109 @@ class SimDomainNode(base.MaxwellSimNode):
} }
#################### ####################
# - Outputs # - FlowKind.Value
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Domain', 'Domain',
kind=ct.FlowKind.Value,
# Loaded
output_sockets={'Domain'},
output_socket_kinds={'Domain': {ct.FlowKind.Func, ct.FlowKind.Params}},
)
def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
output_func = output_sockets['Domain'][ct.FlowKind.Func]
output_params = output_sockets['Domain'][ct.FlowKind.Params]
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Domain',
kind=ct.FlowKind.Func,
# Loaded
input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'}, input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'},
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D}, input_socket_kinds={
scale_input_sockets={ 'Duration': ct.FlowKind.Func,
'Duration': 'Tidy3DUnits', 'Center': ct.FlowKind.Func,
'Center': 'Tidy3DUnits', 'Size': ct.FlowKind.Func,
'Size': 'Tidy3DUnits', 'Grid': ct.FlowKind.Func,
'Ambient Medium': ct.FlowKind.Func,
}, },
) )
def compute_domain(self, input_sockets, unit_systems) -> sp.Expr: def compute_domain_func(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
return { """Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
'run_time': input_sockets['Duration'], duration = input_sockets['Duration']
'center': input_sockets['Center'], center = input_sockets['Center']
'size': input_sockets['Size'], size = input_sockets['Size']
'grid_spec': input_sockets['Grid'], grid = input_sockets['Grid']
'medium': input_sockets['Ambient Medium'], medium = input_sockets['Ambient Medium']
}
has_duration = not ct.FlowSignal.check(duration)
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_grid = not ct.FlowSignal.check(grid)
has_medium = not ct.FlowSignal.check(medium)
if has_duration and has_center and has_size and has_grid and has_medium:
return (
duration.scale_to_unit_system(ct.UNITS_TIDY3D)
| center.scale_to_unit_system(ct.UNITS_TIDY3D)
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
| grid
| medium
).compose_within(
enclosing_func=lambda els: {
'run_time': els[0],
'center': tuple(els[1].flatten()),
'size': tuple(els[2].flatten()),
'grid_spec': els[3],
'medium': els[4],
},
supports_jax=False,
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Domain',
kind=ct.FlowKind.Params,
# Loaded
input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'},
input_socket_kinds={
'Duration': ct.FlowKind.Params,
'Center': ct.FlowKind.Params,
'Size': ct.FlowKind.Params,
'Grid': ct.FlowKind.Params,
'Ambient Medium': ct.FlowKind.Params,
},
)
def compute_domain_params(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
"""Compute the output `ParamsFlow` of the simulation domain from strictly non-symbolic inputs."""
duration = input_sockets['Duration']
center = input_sockets['Center']
size = input_sockets['Size']
grid = input_sockets['Grid']
medium = input_sockets['Ambient Medium']
has_duration = not ct.FlowSignal.check(duration)
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_grid = not ct.FlowSignal.check(grid)
has_medium = not ct.FlowSignal.check(medium)
if has_duration and has_center and has_size and has_grid and has_medium:
return duration | center | size | grid | medium
return ct.FlowSignal.FlowPending
#################### ####################
# - Preview # - Preview
@ -100,37 +187,39 @@ class SimDomainNode(base.MaxwellSimNode):
props={'sim_node_name'}, props={'sim_node_name'},
) )
def compute_previews(self, props): def compute_previews(self, props):
"""Mark the managed preview object for preview when `Domain` is linked to a viewer."""
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']}) return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed( @events.on_value_changed(
## Trigger # Trigger
socket_name={'Center', 'Size'}, socket_name={'Center', 'Size'},
run_on_init=True, run_on_init=True,
# Loaded # Loaded
input_sockets={'Center', 'Size'}, input_sockets={'Center', 'Size'},
managed_objs={'modifier'}, managed_objs={'modifier'},
unit_systems={'BlenderUnits': ct.UNITS_BLENDER}, output_sockets={'Domain'},
scale_input_sockets={ output_socket_kinds={'Domain': ct.FlowKind.Params},
'Center': 'BlenderUnits',
},
) )
def on_input_changed( def on_input_changed(self, managed_objs, input_sockets, output_sockets) -> None:
self, """Preview the simulation domain based on input parameters, so long as they are not dependent on unrealized symbols."""
managed_objs, output_params = output_sockets['Domain']
input_sockets, center = input_sockets['Center']
unit_systems,
): has_output_params = not ct.FlowSignal.check(output_params)
has_center = not ct.FlowSignal.check(center)
if has_center and has_output_params and not output_params.symbols:
# Push Loose Input Values to GeoNodes Modifier # Push Loose Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier( managed_objs['modifier'].bl_modifier(
'NODES', 'NODES',
{ {
'node_group': import_geonodes(GeoNodes.SimulationSimDomain), 'node_group': import_geonodes(GeoNodes.SimulationSimDomain),
'unit_system': unit_systems['BlenderUnits'], 'unit_system': ct.UNITS_BLENDER,
'inputs': { 'inputs': {
'Size': input_sockets['Size'], 'Size': input_sockets['Size'],
}, },
}, },
location=input_sockets['Center'], location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
) )

View File

@ -71,35 +71,129 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
layout.prop(self, self.blfields['pol_axis'], expand=True) layout.prop(self, self.blfields['pol_axis'], expand=True)
#################### ####################
# - Outputs # - FlowKind.Value
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Source', 'Source',
input_sockets={'Temporal Shape', 'Center', 'Interpolate'}, # Loaded
props={'pol_axis'}, props={'pol_axis'},
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D}, input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
scale_input_sockets={ output_sockets={'Source'},
'Center': 'Tidy3DUnits', output_socket_kinds={'Source': ct.FlowKind.Params},
},
) )
def compute_source( def compute_source_value(
self, self, input_sockets, props, output_sockets
input_sockets: dict[str, typ.Any], ) -> td.PointDipole | ct.FlowSignal:
props: dict[str, typ.Any], """Compute the point dipole source, given that all inputs are non-symbolic."""
unit_systems: dict, temporal_shape = input_sockets['Temporal Shape']
) -> td.PointDipole: center = input_sockets['Center']
interpolate = input_sockets['Interpolate']
output_params = output_sockets['Source']
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
has_center = not ct.FlowSignal.check(center)
has_interpolate = not ct.FlowSignal.check(interpolate)
has_output_params = not ct.FlowSignal.check(output_params)
if (
has_temporal_shape
and has_center
and has_interpolate
and has_output_params
and not output_params.symbols
):
pol_axis = { pol_axis = {
ct.SimSpaceAxis.X: 'Ex', ct.SimSpaceAxis.X: 'Ex',
ct.SimSpaceAxis.Y: 'Ey', ct.SimSpaceAxis.Y: 'Ey',
ct.SimSpaceAxis.Z: 'Ez', ct.SimSpaceAxis.Z: 'Ez',
}[props['pol_axis']] }[props['pol_axis']]
## TODO: Need Hx, Hy, Hz too?
return td.PointDipole( return td.PointDipole(
center=input_sockets['Center'], center=spux.convert_to_unit_system(center, ct.UNITS_TIDY3D),
source_time=input_sockets['Temporal Shape'], source_time=temporal_shape,
interpolate=input_sockets['Interpolate'], interpolate=interpolate,
polarization=pol_axis, polarization=pol_axis,
) )
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Source',
kind=ct.FlowKind.Func,
# Loaded
props={'pol_axis'},
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
input_socket_kinds={
'Temporal Shape': ct.FlowKind.Func,
'Center': ct.FlowKind.Func,
'Interpolate': ct.FlowKind.Func,
},
output_sockets={'Source'},
output_socket_kinds={'Source': ct.FlowKind.Params},
)
def compute_source_func(self, props, input_sockets, output_sockets) -> td.Box:
"""Compute a lazy function for the point dipole source."""
center = input_sockets['Center']
temporal_shape = input_sockets['Temporal Shape']
interpolate = input_sockets['Interpolate']
output_params = output_sockets['Source']
has_center = not ct.FlowSignal.check(center)
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
has_interpolate = not ct.FlowSignal.check(interpolate)
has_output_params = not ct.FlowSignal.check(output_params)
if has_temporal_shape and has_center and has_interpolate and has_output_params:
pol_axis = {
ct.SimSpaceAxis.X: 'Ex',
ct.SimSpaceAxis.Y: 'Ey',
ct.SimSpaceAxis.Z: 'Ez',
}[props['pol_axis']]
## TODO: Need Hx, Hy, Hz too?
return (center | temporal_shape | interpolate).compose_within(
enclosing_func=lambda els: td.PointDipole(
center=els[0],
source_time=els[1],
interpolate=els[2],
polarization=pol_axis,
)
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Source',
kind=ct.FlowKind.Params,
# Loaded
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
input_socket_kinds={
'Temporal Shape': ct.FlowKind.Params,
'Center': ct.FlowKind.Params,
'Interpolate': ct.FlowKind.Params,
},
)
def compute_params(
self,
input_sockets,
) -> td.PointDipole | ct.FlowSignal:
"""Compute the point dipole source, given that all inputs are non-symbolic."""
temporal_shape = input_sockets['Temporal Shape']
center = input_sockets['Center']
interpolate = input_sockets['Interpolate']
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
has_center = not ct.FlowSignal.check(center)
has_interpolate = not ct.FlowSignal.check(interpolate)
if has_temporal_shape and has_center and has_interpolate:
return temporal_shape | center | interpolate
return ct.FlowSignal.FlowPending
#################### ####################
# - Preview # - Preview

View File

@ -16,15 +16,19 @@
"""Implements the `TemporalShapeNode`.""" """Implements the `TemporalShapeNode`."""
import enum
import typing as typ import typing as typ
import bpy import bpy
import numpy as np
import sympy as sp import sympy as sp
import sympy.physics.units as spu import sympy.physics.units as spu
import tidy3d as td import tidy3d as td
from tidy3d.components.data.data_array import TimeDataArray as td_TimeDataArray
from tidy3d.components.data.dataset import TimeDataset as td_TimeDataset
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger, sim_symbols
from ... import contracts as ct from ... import contracts as ct
from ... import managed_objs, sockets from ... import managed_objs, sockets
@ -33,14 +37,10 @@ from .. import base, events
log = logger.get(__name__) log = logger.get(__name__)
_max_e_socket_def = sockets.ExprSocketDef( # Select Default Time Unit for Envelope
mathtype=spux.MathType.Complex, ## -> Chosen to align with the default envelope_time_unit.
physical_type=spux.PhysicalType.EField, ## -> This causes it to be correct from the start.
default_value=1 + 0j, t_def = sim_symbols.t(spux.PhysicalType.Time.valid_units[0])
)
_offset_socket_def = sockets.ExprSocketDef(default_value=5, abs_min=2.5)
t_ps = sim_symbols.t(spu.picosecond)
class TemporalShapeNode(base.MaxwellSimNode): class TemporalShapeNode(base.MaxwellSimNode):
@ -63,17 +63,18 @@ class TemporalShapeNode(base.MaxwellSimNode):
default_unit=spux.THz, default_unit=spux.THz,
default_value=200, default_value=200,
), ),
'max E': sockets.ExprSocketDef(
mathtype=spux.MathType.Complex,
physical_type=spux.PhysicalType.EField,
default_value=1 + 0j,
),
'Offset Time': sockets.ExprSocketDef(default_value=5, abs_min=2.5),
} }
input_socket_sets: typ.ClassVar = { input_socket_sets: typ.ClassVar = {
'Pulse': { 'Pulse': {
'max E': _max_e_socket_def,
'Offset Time': _offset_socket_def,
'Remove DC': sockets.BoolSocketDef(default_value=True), 'Remove DC': sockets.BoolSocketDef(default_value=True),
}, },
'Constant': { 'Constant': {},
'max E': _max_e_socket_def,
'Offset Time': _offset_socket_def,
},
'Symbolic': { 'Symbolic': {
't Range': sockets.ExprSocketDef( 't Range': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Range, active_kind=ct.FlowKind.Range,
@ -84,8 +85,8 @@ class TemporalShapeNode(base.MaxwellSimNode):
default_steps=100, default_steps=100,
), ),
'Envelope': sockets.ExprSocketDef( 'Envelope': sockets.ExprSocketDef(
default_symbols=[t_ps], default_symbols=[t_def],
default_value=10 * t_ps.sp_symbol, default_value=10 * t_def.sp_symbol,
), ),
}, },
} }
@ -98,6 +99,55 @@ class TemporalShapeNode(base.MaxwellSimNode):
'plot': managed_objs.ManagedBLImage, 'plot': managed_objs.ManagedBLImage,
} }
####################
# - Properties
####################
active_envelope_time_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_time_units(),
)
def search_time_units(self) -> list[ct.BLEnumElement]:
"""Compute all valid time units."""
return [
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
for i, unit in enumerate(spux.PhysicalType.Time.valid_units)
]
@bl_cache.cached_bl_property(depends_on={'active_envelope_time_unit'})
def envelope_time_unit(self) -> spux.Unit | None:
"""Gets the current active unit for the envelope time symbol.
Returns:
The current active `sympy` unit.
If the socket expression is unitless, this returns `None`.
"""
if self.active_envelope_time_unit is not None:
return spux.unit_str_to_unit(self.active_envelope_time_unit)
return None
####################
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
if (
self.active_socket_set == 'Symbolic'
and self.inputs.get('Envelope')
and not self.inputs['Envelope'].is_linked
):
row = layout.row()
row.alignment = 'CENTER'
row.label(text='Envelope Time Unit')
row = layout.row()
row.prop(
self,
self.blfields['active_envelope_time_unit'],
text='',
toggle=True,
)
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
if self.active_socket_set != 'Symbolic': if self.active_socket_set != 'Symbolic':
box = layout.box() box = layout.box()
@ -118,10 +168,53 @@ class TemporalShapeNode(base.MaxwellSimNode):
col.label(text='1 / 2π·σ(𝑓)') col.label(text='1 / 2π·σ(𝑓)')
#################### ####################
# - FlowKind: Value # - Events
####################
@events.on_value_changed(
# Trigger
prop_name={'active_socket_set', 'envelope_time_unit'},
# Loaded
props={'active_socket_set', 'envelope_time_unit'},
)
def on_envelope_time_unit_changed(self, props) -> None:
"""Ensure the envelope expression's time symbol has the time unit defined by the node."""
active_socket_set = props['active_socket_set']
envelope_time_unit = props['envelope_time_unit']
if active_socket_set == 'Symbolic':
bl_socket = self.inputs['Envelope']
wanted_t_sym = sim_symbols.t(envelope_time_unit)
if not bl_socket.symbols or bl_socket.symbols[0] != wanted_t_sym:
bl_socket.symbols = [wanted_t_sym]
####################
# - FlowKind.Value
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Temporal Shape', 'Temporal Shape',
kind=ct.FlowKind.Value,
# Loaded
output_sockets={'Temporal Shape'},
output_socket_kinds={'Temporal Shape': {ct.FlowKind.Func, ct.FlowKind.Params}},
)
def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
"""Compute a single temporal shape."""
output_func = output_sockets['Temporal Shape'][ct.FlowKind.Func]
output_params = output_sockets['Temporal Shape'][ct.FlowKind.Params]
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params)
return ct.FlowSignal.FlowPending
####################
# - FlowKind: Func
####################
@events.computes_output_socket(
'Temporal Shape',
kind=ct.FlowKind.Func,
# Loaded # Loaded
props={'active_socket_set'}, props={'active_socket_set'},
input_sockets={ input_sockets={
@ -134,60 +227,178 @@ class TemporalShapeNode(base.MaxwellSimNode):
'Envelope', 'Envelope',
}, },
input_socket_kinds={ input_socket_kinds={
't Range': ct.FlowKind.Range, 'max E': ct.FlowKind.Func,
'Envelope': ct.FlowKind.Func, 'μ Freq': ct.FlowKind.Func,
}, 'σ Freq': ct.FlowKind.Func,
input_sockets_optional={ 'Offset Time': ct.FlowKind.Func,
'max E': True, 'Remove DC': ct.FlowKind.Value,
'Offset Time': True, 't Range': ct.FlowKind.Func,
'Remove DC': True, 'Envelope': {ct.FlowKind.Func, ct.FlowKind.Params},
't Range': True,
'Envelope': True,
},
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
scale_input_sockets={
'max E': 'Tidy3DUnits',
'μ Freq': 'Tidy3DUnits',
'σ Freq': 'Tidy3DUnits',
't Range': 'Tidy3DUnits',
'Offset Time': 'Tidy3DUnits',
}, },
) )
def compute_temporal_shape( def compute_temporal_shape_func(
self, props, input_sockets, unit_systems self,
props,
input_sockets,
) -> td.GaussianPulse: ) -> td.GaussianPulse:
"""Compute a single temporal shape from non-parameterized inputs."""
mean_freq = input_sockets['μ Freq']
std_freq = input_sockets['σ Freq']
max_e = input_sockets['max E']
offset = input_sockets['Offset Time']
has_mean_freq = not ct.FlowSignal.check(mean_freq)
has_std_freq = not ct.FlowSignal.check(std_freq)
has_max_e = not ct.FlowSignal.check(max_e)
has_offset = not ct.FlowSignal.check(offset)
if has_mean_freq and has_std_freq and has_max_e and has_offset:
common_func = (
max_e.scale_to_unit_system(ct.UNITS_TIDY3D)
| mean_freq.scale_to_unit_system(ct.UNITS_TIDY3D)
| std_freq.scale_to_unit_system(ct.UNITS_TIDY3D)
| offset ## Already unitless
)
match props['active_socket_set']: match props['active_socket_set']:
case 'Pulse': case 'Pulse':
return td.GaussianPulse( remove_dc = input_sockets['Remove DC']
amplitude=sp.re(input_sockets['max E']),
phase=sp.im(input_sockets['max E']), has_remove_dc = not ct.FlowSignal.check(remove_dc)
freq0=input_sockets['μ Freq'],
fwidth=input_sockets['σ Freq'], if has_remove_dc:
offset=input_sockets['Offset Time'], return common_func.compose_within(
remove_dc_component=input_sockets['Remove DC'], lambda els: td.GaussianPulse(
amplitude=complex(els[0]).real,
phase=complex(els[0]).imag,
freq0=els[1],
fwidth=els[2],
offset=els[3],
remove_dc_component=remove_dc,
),
) )
case 'Constant': case 'Constant':
return td.ContinuousWave( return common_func.compose_within(
amplitude=sp.re(input_sockets['max E']), lambda els: td.GaussianPulse(
phase=sp.im(input_sockets['max E']), amplitude=complex(els[0]).real,
freq0=input_sockets['μ Freq'], phase=complex(els[0]).imag,
fwidth=input_sockets['σ Freq'], freq0=els[1],
offset=input_sockets['Offset Time'], fwidth=els[2],
offset=els[3],
),
) )
case 'Symbolic': case 'Symbolic':
lzrange = input_sockets['t Range'] t_range = input_sockets['t Range']
envelope_ps = input_sockets['Envelope'].func_jax envelope = input_sockets['Envelope'][ct.FlowKind.Func]
envelope_params = input_sockets['Envelope'][ct.FlowKind.Params]
return td.CustomSourceTime.from_values( has_t_range = not ct.FlowSignal.check(t_range)
freq0=input_sockets['μ Freq'], has_envelope = not ct.FlowSignal.check(envelope)
fwidth=input_sockets['σ Freq'], has_envelope_params = not ct.FlowSignal.check(envelope_params)
values=envelope_ps(
lzrange.rescale_to_unit(spu.ps).realize_array.values if (
), has_t_range
dt=input_sockets['t Range'].realize_step_size(), and has_envelope
and has_envelope_params
and len(envelope_params.symbols) == 1
## TODO: Allow unrealized envelope symbols
and any(
sym.physical_type is spux.PhysicalType.Time
for sym in envelope_params.symbols
) )
):
envelope_time_unit = next(
sym.unit
for sym in envelope_params.symbols
if sym.physical_type is spux.PhysicalType.Time
)
# Deduce Partially Realized Envelope Function
## -> We need a pure-numerical function w/pre-realized stuff baked in.
## -> 'realize_partial' does this for us.
envelope_realizer = envelope.realize_partial(envelope_params)
# Compose w/Envelope Function
## -> First, the numerical time values must be converted.
## -> This ensures that the raw array is compatible w/the envelope.
## -> Then, we can compose w/the purely numerical 'envelope_realizer'.
## -> Because of the checks, we've guaranteed that all this is correct.
return (
common_func ## 1 | freq0, 2 | fwidth, 3 | offset
| t_range.scale_to_unit_system(ct.UNITS_TIDY3D) ## 4
| t_range.scale_to_unit(envelope_time_unit).compose_within(
lambda t: envelope_realizer(t)
) ## 5
).compose_within(
lambda els: td.CustomSourceTime(
amplitude=complex(els[0]).real,
phase=complex(els[0]).imag,
freq0=els[1],
fwidth=els[2],
offset=els[3],
source_time_dataset=td_TimeDataset(
values=td_TimeDataArray(
els[5], coords={'t': np.array(els[4])}
)
),
)
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind: Params
####################
@events.computes_output_socket(
'Temporal Shape',
kind=ct.FlowKind.Params,
# Loaded
props={'active_socket_set', 'envelope_time_unit'},
input_sockets={
'max E',
'μ Freq',
'σ Freq',
'Offset Time',
't Range',
},
input_socket_kinds={
'max E': ct.FlowKind.Params,
'μ Freq': ct.FlowKind.Params,
'σ Freq': ct.FlowKind.Params,
'Offset Time': ct.FlowKind.Params,
't Range': ct.FlowKind.Params,
},
)
def compute_temporal_shape_params(
self,
props,
input_sockets,
) -> td.GaussianPulse:
"""Compute a single temporal shape from non-parameterized inputs."""
mean_freq = input_sockets['μ Freq']
std_freq = input_sockets['σ Freq']
max_e = input_sockets['max E']
offset = input_sockets['Offset Time']
has_mean_freq = not ct.FlowSignal.check(mean_freq)
has_std_freq = not ct.FlowSignal.check(std_freq)
has_max_e = not ct.FlowSignal.check(max_e)
has_offset = not ct.FlowSignal.check(offset)
if has_mean_freq and has_std_freq and has_max_e and has_offset:
common_params = max_e | mean_freq | std_freq | offset
match props['active_socket_set']:
case 'Pulse' | 'Constant':
return common_params
case 'Symbolic':
t_range = input_sockets['t Range']
has_t_range = not ct.FlowSignal.check(t_range)
if has_t_range:
return common_params | t_range | t_range
return ct.FlowSignal.FlowPending
#################### ####################

View File

@ -88,28 +88,27 @@ class BoxStructureNode(base.MaxwellSimNode):
'Structure', 'Structure',
kind=ct.FlowKind.Value, kind=ct.FlowKind.Value,
# Loaded # Loaded
props={'differentiable'},
input_sockets={'Medium', 'Center', 'Size'}, input_sockets={'Medium', 'Center', 'Size'},
output_sockets={'Structure'}, output_sockets={'Structure'},
output_socket_kinds={'Structure': ct.FlowKind.Params}, output_socket_kinds={'Structure': ct.FlowKind.Params},
) )
def compute_value(self, props, input_sockets, output_sockets) -> td.Box: def compute_value(self, input_sockets, output_sockets) -> td.Box:
output_params = output_sockets['Structure'] """Compute a single box structure object, given that all inputs are non-symbolic."""
center = input_sockets['Center'] center = input_sockets['Center']
size = input_sockets['Size'] size = input_sockets['Size']
medium = input_sockets['Medium'] medium = input_sockets['Medium']
output_params = output_sockets['Structure']
has_output_params = not ct.FlowSignal.check(output_params)
has_center = not ct.FlowSignal.check(center) has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size) has_size = not ct.FlowSignal.check(size)
has_medium = not ct.FlowSignal.check(medium) has_medium = not ct.FlowSignal.check(medium)
has_output_params = not ct.FlowSignal.check(output_params)
if ( if (
has_center has_center
and has_size and has_size
and has_medium and has_medium
and has_output_params and has_output_params
and not props['differentiable']
and not output_params.symbols and not output_params.symbols
): ):
return td.Structure( return td.Structure(
@ -138,7 +137,8 @@ class BoxStructureNode(base.MaxwellSimNode):
output_sockets={'Structure'}, output_sockets={'Structure'},
output_socket_kinds={'Structure': ct.FlowKind.Params}, output_socket_kinds={'Structure': ct.FlowKind.Params},
) )
def compute_lazy_structure(self, props, input_sockets, output_sockets) -> td.Box: def compute_structure_func(self, props, input_sockets, output_sockets) -> td.Box:
"""Compute a possibly-differentiable function, producing a box structure from the input parameters."""
output_params = output_sockets['Structure'] output_params = output_sockets['Structure']
center = input_sockets['Center'] center = input_sockets['Center']
size = input_sockets['Size'] size = input_sockets['Size']
@ -149,14 +149,8 @@ class BoxStructureNode(base.MaxwellSimNode):
has_size = not ct.FlowSignal.check(size) has_size = not ct.FlowSignal.check(size)
has_medium = not ct.FlowSignal.check(medium) has_medium = not ct.FlowSignal.check(medium)
if has_output_params and has_center and has_size and has_medium:
differentiable = props['differentiable'] differentiable = props['differentiable']
if (
has_output_params
and has_center
and has_size
and has_medium
and differentiable == output_params.is_differentiable
):
if differentiable: if differentiable:
return (center | size | medium).compose_within( return (center | size | medium).compose_within(
enclosing_func=lambda els: tdadj.JaxStructure( enclosing_func=lambda els: tdadj.JaxStructure(
@ -169,6 +163,12 @@ class BoxStructureNode(base.MaxwellSimNode):
supports_jax=True, supports_jax=True,
) )
return (center | size | medium).compose_within( return (center | size | medium).compose_within(
## TODO: Unit conversion within the composed function??
## -- We do need Tidy3D to be given ex. micrometers in particular.
## -- But the previous numerical output might not be micrometers.
## -- There must be a way to add a conversion in, without strangeness.
## -- Ex. can compose_within() take a unit system?
## -- This would require
enclosing_func=lambda els: td.Structure( enclosing_func=lambda els: td.Structure(
geometry=td.Box( geometry=td.Box(
center=tuple(els[0].flatten()), center=tuple(els[0].flatten()),
@ -205,14 +205,8 @@ class BoxStructureNode(base.MaxwellSimNode):
has_medium = not ct.FlowSignal.check(medium) has_medium = not ct.FlowSignal.check(medium)
if has_center and has_size and has_medium: if has_center and has_size and has_medium:
if props['differentiable'] == (
center.is_differentiable
and size.is_differentiable
and medium.is_differentiable
):
return center | size | medium return center | size | medium
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
return ct.FlowSignal.FlowPending
#################### ####################
# - Events: Preview # - Events: Preview
@ -226,6 +220,7 @@ class BoxStructureNode(base.MaxwellSimNode):
output_socket_kinds={'Structure': ct.FlowKind.Params}, output_socket_kinds={'Structure': ct.FlowKind.Params},
) )
def compute_previews(self, props, output_sockets): def compute_previews(self, props, output_sockets):
"""Mark the managed preview object when recursively linked to a viewer."""
output_params = output_sockets['Structure'] output_params = output_sockets['Structure']
has_output_params = not ct.FlowSignal.check(output_params) has_output_params = not ct.FlowSignal.check(output_params)
@ -245,10 +240,14 @@ class BoxStructureNode(base.MaxwellSimNode):
) )
def on_inputs_changed(self, managed_objs, input_sockets, output_sockets): def on_inputs_changed(self, managed_objs, input_sockets, output_sockets):
output_params = output_sockets['Structure'] output_params = output_sockets['Structure']
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_params and not output_params.symbols:
# Push Loose Input Values to GeoNodes Modifier
center = input_sockets['Center'] center = input_sockets['Center']
has_output_params = not ct.FlowSignal.check(output_params)
has_center = not ct.FlowSignal.check(center)
if has_center and has_output_params and not output_params.symbols:
## TODO: There are strategies for handling examples of symbol values.
# Push Loose Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier( managed_objs['modifier'].bl_modifier(
'NODES', 'NODES',
{ {

View File

@ -43,17 +43,28 @@ class SocketDef(pyd.BaseModel, abc.ABC):
""" """
socket_type: ct.SocketType socket_type: ct.SocketType
active_kind: typ.Literal[
ct.FlowKind.Value,
ct.FlowKind.Array,
ct.FlowKind.Range,
ct.FlowKind.Func,
] = ct.FlowKind.Value
####################
# - Socket Interaction
####################
def preinit(self, bl_socket: bpy.types.NodeSocket) -> None: def preinit(self, bl_socket: bpy.types.NodeSocket) -> None:
"""Pre-initialize a real Blender node socket from this socket definition. """Pre-initialize a real Blender node socket from this socket definition.
Parameters: Parameters:
bl_socket: The Blender node socket to alter using data from this SocketDef. bl_socket: The Blender node socket to alter using data from this SocketDef.
""" """
log.debug('%s: Start Socket Preinit', bl_socket.bl_label) # log.debug('%s: Start Socket Preinit', bl_socket.bl_label)
bl_socket.reset_instance_id() bl_socket.reset_instance_id()
bl_socket.regenerate_dynamic_field_persistance() bl_socket.regenerate_dynamic_field_persistance()
log.debug('%s: End Socket Preinit', bl_socket.bl_label)
bl_socket.active_kind = self.active_kind
# log.debug('%s: End Socket Preinit', bl_socket.bl_label)
def postinit(self, bl_socket: bpy.types.NodeSocket) -> None: def postinit(self, bl_socket: bpy.types.NodeSocket) -> None:
"""Pre-initialize a real Blender node socket from this socket definition. """Pre-initialize a real Blender node socket from this socket definition.
@ -61,12 +72,12 @@ class SocketDef(pyd.BaseModel, abc.ABC):
Parameters: Parameters:
bl_socket: The Blender node socket to alter using data from this SocketDef. bl_socket: The Blender node socket to alter using data from this SocketDef.
""" """
log.debug('%s: Start Socket Postinit', bl_socket.bl_label) # log.debug('%s: Start Socket Postinit', bl_socket.bl_label)
bl_socket.is_initializing = False bl_socket.is_initializing = False
bl_socket.on_active_kind_changed() bl_socket.on_active_kind_changed()
bl_socket.on_socket_props_changed(set(bl_socket.blfields)) bl_socket.on_socket_props_changed(set(bl_socket.blfields))
bl_socket.on_data_changed(set(ct.FlowKind)) bl_socket.on_data_changed(set(ct.FlowKind))
log.debug('%s: End Socket Postinit', bl_socket.bl_label) # log.debug('%s: End Socket Postinit', bl_socket.bl_label)
@abc.abstractmethod @abc.abstractmethod
def init(self, bl_socket: bpy.types.NodeSocket) -> None: def init(self, bl_socket: bpy.types.NodeSocket) -> None:
@ -76,6 +87,43 @@ class SocketDef(pyd.BaseModel, abc.ABC):
bl_socket: The Blender node socket to alter using data from this SocketDef. bl_socket: The Blender node socket to alter using data from this SocketDef.
""" """
####################
# - Comparison
####################
def compare(self, bl_socket: bpy.types.NodeSocket) -> bool:
"""Whether this `SocketDef` can be considered to uniquely define the given `bl_socket`.
The general criteria for "uniquely defines" is whether **the same `bl_socket`** could be created using this `SocketDef`.
The extent to which user-altered properties are considered in this regard is a matter of taste, encapsulated entirely within `self.local_compare()`.
Notes:
Used when determining whether to replace sockets with newer variants when synchronizing changes.
**NOTE**: Removing/replacing loose input sockets
Parameters:
bl_socket: The Blender node socket to alter using data from this SocketDef.
"""
return (
bl_socket.socket_type is self.socket_type
and bl_socket.active_kind is self.active_kind
and self.local_compare(bl_socket)
)
def local_compare(self, bl_socket: bpy.types.NodeSocket) -> None:
"""Compare this `SocketDef` to an established `bl_socket` in a manner specific to the node.
Notes:
Run by `self.compare()`.
Optionally overriden by individual sockets.
When not overridden, it will always return `False`, indicating that the socket is _never_ uniquely defined by this `SocketDef`.
Parameters:
bl_socket: The Blender node socket to alter using data from this SocketDef.
"""
return False
#################### ####################
# - Serialization # - Serialization
#################### ####################
@ -426,8 +474,34 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Parameters: Parameters:
socket_kinds: The altered `ct.FlowKind`s flowing through. socket_kinds: The altered `ct.FlowKind`s flowing through.
""" """
# Run Socket Callbacks
self.on_socket_data_changed(socket_kinds) self.on_socket_data_changed(socket_kinds)
# Mark Active FlowKind Links as Invalid
## -> Mark link as invalid (very red) if a FlowSignal is traveling.
## -> This helps explain why whatever isn't working isn't working.
## -> TODO: We need a different approach.
# log.debug(
# '[%s] Checking FlowKind Validity (socket_kinds=%s)',
# self.name,
# str(socket_kinds),
# )
# if self.is_linked and not self.is_output:
# link = self.links[0]
# linked_flow = self.compute_data(kind=self.active_kind)
# if (
# link.is_valid
# and self.active_kind in socket_kinds
# and ct.FlowSignal.check_single(linked_flow, ct.FlowSignal.FlowPending)
# ):
# node_tree = self.id_data
# node_tree.report_link_validity(link, False)
# elif not link.is_valid:
# node_tree = self.id_data
# node_tree.report_link_validity(link, True)
def on_socket_data_changed(self, socket_kinds: set[ct.FlowKind]) -> None: def on_socket_data_changed(self, socket_kinds: set[ct.FlowKind]) -> None:
"""Called when `ct.FlowEvent.DataChanged` flows through this socket. """Called when `ct.FlowEvent.DataChanged` flows through this socket.
@ -479,7 +553,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
The value of `ct.FlowEvent.flow_direction[event]` (`input` or `output`) determines the direction that an event flows. The value of `ct.FlowEvent.flow_direction[event]` (`input` or `output`) determines the direction that an event flows.
""" """
# log.debug( # log.debug(
# '[%s] [%s] Triggered (socket_kinds=%s)', # '[%s] [%s] Socket Triggered (socket_kinds=%s)',
# self.name, # self.name,
# event, # event,
# str(socket_kinds), # str(socket_kinds),
@ -757,7 +831,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
linked_values = [link.from_socket.compute_data(kind) for link in self.links] linked_values = [link.from_socket.compute_data(kind) for link in self.links]
# Return Single Value / List of Values # Return Single Value / List of Values
## -> Multi-input sockets are not yet supported. ## -> Multi-input sockets are not (yet) supported.
if linked_values: if linked_values:
return linked_values[0] return linked_values[0]
@ -891,10 +965,14 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
# FlowKind Draw Row # FlowKind Draw Row
col = row.column(align=True) col = row.column(align=True)
{ {
ct.FlowKind.Capabilities: lambda *_: None,
ct.FlowKind.Previews: lambda *_: None,
ct.FlowKind.Value: self.draw_value, ct.FlowKind.Value: self.draw_value,
ct.FlowKind.Array: self.draw_array, ct.FlowKind.Array: self.draw_array,
ct.FlowKind.Range: self.draw_lazy_range, ct.FlowKind.Range: self.draw_lazy_range,
ct.FlowKind.Func: self.draw_lazy_func, ct.FlowKind.Func: self.draw_lazy_func,
ct.FlowKind.Params: lambda *_: None,
ct.FlowKind.Info: lambda *_: None,
}[self.active_kind](col) }[self.active_kind](col)
# Info Drawing # Info Drawing

View File

@ -51,6 +51,16 @@ class BoolBLSocket(base.MaxwellSimSocket):
def value(self, value: bool) -> None: def value(self, value: bool) -> None:
self.raw_value = value self.raw_value = value
@bl_cache.cached_bl_property(depends_on={'value'})
def lazy_func(self) -> ct.FuncFlow:
return ct.FuncFlow(
func=lambda: self.value,
)
@bl_cache.cached_bl_property()
def params(self) -> ct.FuncFlow:
return ct.ParamsFlow()
#################### ####################
# - Socket Configuration # - Socket Configuration

View File

@ -130,6 +130,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
'physical_type', 'physical_type',
'unit', 'unit',
'size', 'size',
'value',
} }
) )
def output_sym(self) -> sim_symbols.SimSymbol | None: def output_sym(self) -> sim_symbols.SimSymbol | None:
@ -140,13 +141,29 @@ class ExprBLSocket(base.MaxwellSimSocket):
Raises: Raises:
NotImplementedError: When `active_kind` is neither `Value`, `Func`, or `Range`. NotImplementedError: When `active_kind` is neither `Value`, `Func`, or `Range`.
""" """
if self.symbols: match self.active_kind:
if self.active_kind in [ct.FlowKind.Value, ct.FlowKind.Func]: case ct.FlowKind.Value | ct.FlowKind.Func if self.symbols:
return self._parse_expr_symbol( return self._parse_expr_symbol(
self._parse_expr_str(self.raw_value_spstr) self._parse_expr_str(self.raw_value_spstr)
) )
if self.active_kind is ct.FlowKind.Range: case ct.FlowKind.Value | ct.FlowKind.Func if not self.symbols:
return sim_symbols.SimSymbol(
sym_name=self.output_name,
mathtype=self.mathtype,
physical_type=self.physical_type,
unit=self.unit,
rows=self.size.rows,
cols=self.size.cols,
exclude_zero=(
not self.value.is_zero
if self.value.is_zero is not None
else False
),
## TODO: Does this work for matrix elements?
)
case ct.FlowKind.Range if self.symbols:
## TODO: Support RangeFlow ## TODO: Support RangeFlow
## -- It's hard; we need a min-span set over bound domains. ## -- It's hard; we need a min-span set over bound domains.
## -- We... Don't use this anywhere. Yet? ## -- We... Don't use this anywhere. Yet?
@ -159,20 +176,37 @@ class ExprBLSocket(base.MaxwellSimSocket):
msg = 'RangeFlow support not yet implemented for when self.symbols is not empty' msg = 'RangeFlow support not yet implemented for when self.symbols is not empty'
raise NotImplementedError(msg) raise NotImplementedError(msg)
raise NotImplementedError case ct.FlowKind.Range if not self.symbols:
return sim_symbols.SimSymbol( return sim_symbols.SimSymbol(
sym_name=self.output_name, sym_name=self.output_name,
mathtype=self.mathtype, mathtype=self.mathtype,
physical_type=self.physical_type, physical_type=self.physical_type,
unit=self.unit, unit=self.unit,
rows=self.size.rows, rows=self.lazy_range.steps,
cols=self.size.cols, cols=1,
exclude_zero=not self.lazy_range.is_always_nonzero,
) )
####################
# - Value|Range Swapper
####################
use_value_range_swapper: bool = bl_cache.BLField(False)
selected_value_range: ct.FlowKind = bl_cache.BLField(
enum_cb=lambda self, _: self._value_or_range(),
)
def _value_or_range(self):
return [
flow_kind.bl_enum_element(i)
for i, flow_kind in enumerate([ct.FlowKind.Value, ct.FlowKind.Range])
]
#################### ####################
# - Symbols # - Symbols
#################### ####################
lazy_range_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.Expr
)
output_name: sim_symbols.SimSymbolName = bl_cache.BLField( output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.Expr sim_symbols.SimSymbolName.Expr
) )
@ -343,7 +377,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
See `MaxwellSimTree` for more detail on the link callbacks. See `MaxwellSimTree` for more detail on the link callbacks.
""" """
## NODE: Depends on suppressed on_prop_changed ## NOTE: Depends on suppressed on_prop_changed
if ct.FlowKind.Info in socket_kinds: if ct.FlowKind.Info in socket_kinds:
info = self.compute_data(kind=ct.FlowKind.Info) info = self.compute_data(kind=ct.FlowKind.Info)
@ -371,7 +405,10 @@ class ExprBLSocket(base.MaxwellSimSocket):
See `MaxwellSimTree` for more detail on the link callbacks. See `MaxwellSimTree` for more detail on the link callbacks.
""" """
## NODE: Depends on suppressed on_prop_changed ## NOTE: Depends on suppressed on_prop_changed
if ('selected_value_range', 'invalidate') in cleared_blfields:
self.active_kind = self.selected_value_range
self.on_active_kind_changed()
# Conditional Unit-Conversion # Conditional Unit-Conversion
## -> This is niche functionality, but the only way to convert units. ## -> This is niche functionality, but the only way to convert units.
@ -757,7 +794,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
@bl_cache.cached_bl_property( @bl_cache.cached_bl_property(
depends_on={ depends_on={
'value', 'value',
'symbols',
'sorted_sp_symbols', 'sorted_sp_symbols',
'sorted_symbols', 'sorted_symbols',
'output_sym', 'output_sym',
@ -769,83 +805,88 @@ class ExprBLSocket(base.MaxwellSimSocket):
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be the arguments of the `FuncFlow`. If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be the arguments of the `FuncFlow`.
Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`. Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`.
""" """
# Symbolic if self.output_sym is not None:
## -> `self.value` is guaranteed to be an expression with unknowns. match self.active_kind:
## -> The function computes `self.value` with unknowns as arguments. case ct.FlowKind.Value | ct.FlowKind.Func if (
if self.symbols: self.sorted_symbols and not ct.FlowSignal.check(self.value)
value = self.value ):
has_value = not ct.FlowSignal.check(value)
output_sym = self.output_sym
if output_sym is not None and has_value:
return ct.FuncFlow( return ct.FuncFlow(
func=sp.lambdify( func=sp.lambdify(
self.sorted_sp_symbols, self.sorted_sp_symbols,
output_sym.conform(value, strip_unit=True), self.output_sym.conform(self.value, strip_unit=True),
'jax', 'jax',
), ),
func_args=list(self.sorted_symbols), func_args=list(self.sorted_symbols),
func_output=self.output_sym,
supports_jax=True, supports_jax=True,
) )
return ct.FlowSignal.FlowPending
# Constant case ct.FlowKind.Value | ct.FlowKind.Func if not self.sorted_symbols:
## -> When a `self.value` has no unknowns, use a dummy function.
## -> ("Dummy" as in returns the same argument that it takes).
## -> This is an excuse to let `ParamsFlow` pass `self.value` verbatim.
## -> Generally only useful for operations with other expressions.
return ct.FuncFlow( return ct.FuncFlow(
func=lambda v: v, func=lambda v: v,
func_args=[self.output_sym], func_args=[self.output_sym],
func_output=self.output_sym,
supports_jax=True, supports_jax=True,
) )
@bl_cache.cached_bl_property(depends_on={'sorted_symbols'}) case ct.FlowKind.Range if self.sorted_symbols:
def is_differentiable(self) -> bool: msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty'
"""Whether all symbols are differentiable. raise NotImplementedError(msg)
If there are no symbols, then there is nothing to differentiate, and thus the expression is differentiable. case ct.FlowKind.Range if (
""" not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
if not self.sorted_symbols: ):
return True return ct.FuncFlow(
func=lambda v: v,
return all( func_args=[self.output_sym],
sym.mathtype in [spux.MathType.Real, spux.MathType.Complex] func_output=self.output_sym,
for sym in self.sorted_symbols supports_jax=True,
) )
@bl_cache.cached_bl_property(depends_on={'sorted_symbols', 'output_sym', 'value'}) return ct.FlowSignal.FlowPending
@bl_cache.cached_bl_property(
depends_on={'sorted_symbols', 'output_sym', 'value', 'lazy_range'}
)
def params(self) -> ct.ParamsFlow: def params(self) -> ct.ParamsFlow:
"""Returns parameter symbols/values to accompany `self.lazy_func`. """Returns parameter symbols/values to accompany `self.lazy_func`.
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use). If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use).
Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`. Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`.
""" """
# Symbolic
## -> The Expr socket does not declare actual values for the symbols.
## -> They should be realized later, ex. in a Viz node.
## -> Therefore, we just dump the symbols. Easy!
## -> NOTE: func_args must have the same symbol order as was lambdified.
if self.sorted_symbols:
output_sym = self.output_sym output_sym = self.output_sym
if output_sym is not None: if output_sym is not None:
match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
return ct.ParamsFlow( return ct.ParamsFlow(
arg_targets=list(self.sorted_symbols), arg_targets=list(self.sorted_symbols),
func_args=[sym.sp_symbol for sym in self.sorted_symbols], func_args=[sym.sp_symbol for sym in self.sorted_symbols],
symbols=self.sorted_symbols, symbols=set(self.sorted_symbols),
is_differentiable=self.is_differentiable,
) )
return ct.FlowSignal.FlowPending
# Constant case ct.FlowKind.Value | ct.FlowKind.Func if (
## -> Simply pass self.value verbatim as a function argument. not self.sorted_symbols and not ct.FlowSignal.check(self.value)
## -> Easy dice, easy life! ):
return ct.ParamsFlow( return ct.ParamsFlow(
arg_targets=[self.output_sym], arg_targets=[self.output_sym],
func_args=[self.value], func_args=[self.value],
is_differentiable=self.is_differentiable,
) )
case ct.FlowKind.Range if self.sorted_symbols:
msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty'
raise NotImplementedError(msg)
case ct.FlowKind.Range if (
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.ParamsFlow(
arg_targets=[self.output_sym],
func_args=[self.output_sym.sp_symbol_matsym],
symbols={self.output_sym},
).realize_partial({self.output_sym: self.lazy_range})
return ct.FlowSignal.FlowPending
@bl_cache.cached_bl_property(depends_on={'sorted_symbols', 'output_sym'}) @bl_cache.cached_bl_property(depends_on={'sorted_symbols', 'output_sym'})
def info(self) -> ct.InfoFlow: def info(self) -> ct.InfoFlow:
r"""Returns parameter symbols/values to accompany `self.lazy_func`. r"""Returns parameter symbols/values to accompany `self.lazy_func`.
@ -858,22 +899,34 @@ class ExprBLSocket(base.MaxwellSimSocket):
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along. Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
""" """
# Constant
## -> The input SimSymbols become continuous dimensional indices.
## -> All domain validity information is defined on the SimSymbol keys.
if self.sorted_symbols:
output_sym = self.output_sym output_sym = self.output_sym
if output_sym is not None: if output_sym is not None:
match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
return ct.InfoFlow( return ct.InfoFlow(
dims={sym: None for sym in self.sorted_symbols}, dims={sym: None for sym in self.sorted_symbols},
output=self.output_sym, output=self.output_sym,
) )
return ct.FlowSignal.FlowPending
# Constant case ct.FlowKind.Value | ct.FlowKind.Func if (
## -> We only need the output symbol to describe the raw data. not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.InfoFlow(output=self.output_sym) return ct.InfoFlow(output=self.output_sym)
case ct.FlowKind.Range if self.sorted_symbols:
msg = 'InfoFlow support not yet implemented for when self.sorted_symbols is not empty'
raise NotImplementedError(msg)
case ct.FlowKind.Range if (
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.InfoFlow(
dims={self.output_sym: self.lazy_range},
output=self.output_sym.update(rows=1),
)
return ct.FlowSignal.FlowPending
#################### ####################
# - FlowKind: Capabilities # - FlowKind: Capabilities
#################### ####################
@ -1039,6 +1092,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
However, `draw_value` may also be called by the `draw_*` methods of other `FlowKinds`, who may choose to layer more flexibility around this base UI. However, `draw_value` may also be called by the `draw_*` methods of other `FlowKinds`, who may choose to layer more flexibility around this base UI.
""" """
if self.use_value_range_swapper:
col.prop(self, self.blfields['selected_value_range'], text='')
if self.symbols: if self.symbols:
col.prop(self, self.blfields['raw_value_spstr'], text='') col.prop(self, self.blfields['raw_value_spstr'], text='')
@ -1097,6 +1153,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
If `self.steps == 0`, then the `Range` is considered to have a to-be-determined number of steps. If `self.steps == 0`, then the `Range` is considered to have a to-be-determined number of steps.
As such, `self.steps` won't be exposed in the UI. As such, `self.steps` won't be exposed in the UI.
""" """
if self.use_value_range_swapper:
col.prop(self, self.blfields['selected_value_range'], text='')
if self.symbols: if self.symbols:
col.prop(self, self.blfields['raw_min_spstr'], text='') col.prop(self, self.blfields['raw_min_spstr'], text='')
col.prop(self, self.blfields['raw_max_spstr'], text='') col.prop(self, self.blfields['raw_max_spstr'], text='')
@ -1198,13 +1257,11 @@ class ExprBLSocket(base.MaxwellSimSocket):
# - Socket Configuration # - Socket Configuration
#################### ####################
class ExprSocketDef(base.SocketDef): class ExprSocketDef(base.SocketDef):
"""Interface for defining an `ExprSocket`."""
socket_type: ct.SocketType = ct.SocketType.Expr socket_type: ct.SocketType = ct.SocketType.Expr
active_kind: typ.Literal[
ct.FlowKind.Value,
ct.FlowKind.Range,
ct.FlowKind.Func,
] = ct.FlowKind.Value
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName.Expr output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName.Expr
use_value_range_swapper: bool = False
# Socket Interface # Socket Interface
size: spux.NumberSize1D = spux.NumberSize1D.Scalar size: spux.NumberSize1D = spux.NumberSize1D.Scalar
@ -1458,7 +1515,7 @@ class ExprSocketDef(base.SocketDef):
# Check ActiveKind and Size # Check ActiveKind and Size
## -> NOTE: This doesn't protect against dynamic changes to either. ## -> NOTE: This doesn't protect against dynamic changes to either.
if ( if (
self.active_kind == ct.FlowKind.Range self.active_kind is ct.FlowKind.Range
and self.size is not spux.NumberSize1D.Scalar and self.size is not spux.NumberSize1D.Scalar
): ):
msg = "Can't have a non-Scalar size when Range is set as the active kind." msg = "Can't have a non-Scalar size when Range is set as the active kind."
@ -1504,9 +1561,9 @@ class ExprSocketDef(base.SocketDef):
# - Initialization # - Initialization
#################### ####################
def init(self, bl_socket: ExprBLSocket) -> None: def init(self, bl_socket: ExprBLSocket) -> None:
bl_socket.active_kind = self.active_kind
bl_socket.output_name = self.output_name bl_socket.output_name = self.output_name
bl_socket.use_linked_capabilities = True bl_socket.use_linked_capabilities = True
bl_socket.use_value_range_swapper = self.use_value_range_swapper
# Socket Interface # Socket Interface
## -> Recall that auto-updates are turned off during init() ## -> Recall that auto-updates are turned off during init()
@ -1543,6 +1600,25 @@ class ExprSocketDef(base.SocketDef):
# Info Draw # Info Draw
bl_socket.use_info_draw = True bl_socket.use_info_draw = True
def local_compare(self, bl_socket: ExprBLSocket) -> None:
"""Determine whether an updateable socket should be re-initialized from this `SocketDef`."""
def cmp(attr: str):
return getattr(bl_socket, attr) == getattr(self, attr)
return (
bl_socket.use_linked_capabilities
and cmp('output_name')
and cmp('use_value_range_swapper')
and cmp('size')
and cmp('mathtype')
and cmp('physical_type')
and cmp('show_func_ui')
and cmp('show_info_columns')
and cmp('show_name_selector')
and bl_socket.use_info_draw
)
#################### ####################
# - Blender Registration # - Blender Registration

View File

@ -117,6 +117,16 @@ class MaxwellBoundCondsBLSocket(base.MaxwellSimSocket):
), ),
) )
@bl_cache.cached_bl_property(depends_on={'value'})
def lazy_func(self) -> ct.FuncFlow:
return ct.FuncFlow(
func=lambda: self.value,
)
@bl_cache.cached_bl_property()
def params(self) -> ct.FuncFlow:
return ct.ParamsFlow()
#################### ####################
# - Socket Configuration # - Socket Configuration

View File

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import typing as typ
from ... import contracts as ct from ... import contracts as ct
from .. import base from .. import base
@ -32,6 +34,9 @@ class MaxwellFDTDSimSocketDef(base.SocketDef):
def init(self, bl_socket: MaxwellFDTDSimBLSocket) -> None: def init(self, bl_socket: MaxwellFDTDSimBLSocket) -> None:
pass pass
def local_compare(self, _: MaxwellFDTDSimBLSocket) -> None:
return True
#################### ####################
# - Blender Registration # - Blender Registration

View File

@ -73,7 +73,7 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket):
def value(self, eps_rel: tuple[float, float]) -> None: def value(self, eps_rel: tuple[float, float]) -> None:
self.eps_rel = eps_rel self.eps_rel = eps_rel
@bl_cache.cached_bl_property(depends_on={'value', 'differentiable'}) @bl_cache.cached_bl_property(depends_on={'value'})
def lazy_func(self) -> ct.FuncFlow: def lazy_func(self) -> ct.FuncFlow:
return ct.FuncFlow( return ct.FuncFlow(
func=lambda: self.value, func=lambda: self.value,
@ -82,7 +82,7 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket):
@bl_cache.cached_bl_property(depends_on={'differentiable'}) @bl_cache.cached_bl_property(depends_on={'differentiable'})
def params(self) -> ct.FuncFlow: def params(self) -> ct.FuncFlow:
return ct.ParamsFlow(is_differentiable=self.differentiable) return ct.ParamsFlow()
#################### ####################
# - UI # - UI

View File

@ -29,11 +29,11 @@ class MaxwellMonitorBLSocket(base.MaxwellSimSocket):
class MaxwellMonitorSocketDef(base.SocketDef): class MaxwellMonitorSocketDef(base.SocketDef):
socket_type: ct.SocketType = ct.SocketType.MaxwellMonitor socket_type: ct.SocketType = ct.SocketType.MaxwellMonitor
is_list: bool = False
def init(self, bl_socket: MaxwellMonitorBLSocket) -> None: def init(self, bl_socket: MaxwellMonitorBLSocket) -> None:
if self.is_list: pass
bl_socket.active_kind = ct.FlowKind.Array
def local_compare(self, _: MaxwellMonitorBLSocket) -> None:
return True
#################### ####################

View File

@ -55,6 +55,16 @@ class MaxwellSimGridBLSocket(base.MaxwellSimSocket):
min_steps_per_wvl=self.min_steps_per_wl, min_steps_per_wvl=self.min_steps_per_wl,
) )
@bl_cache.cached_bl_property(depends_on={'value'})
def lazy_func(self) -> ct.FuncFlow:
return ct.FuncFlow(
func=lambda: self.value,
)
@bl_cache.cached_bl_property()
def params(self) -> ct.FuncFlow:
return ct.ParamsFlow()
#################### ####################
# - Socket Configuration # - Socket Configuration

View File

@ -29,11 +29,11 @@ class MaxwellSourceBLSocket(base.MaxwellSimSocket):
class MaxwellSourceSocketDef(base.SocketDef): class MaxwellSourceSocketDef(base.SocketDef):
socket_type: ct.SocketType = ct.SocketType.MaxwellSource socket_type: ct.SocketType = ct.SocketType.MaxwellSource
is_list: bool = False
def init(self, bl_socket: MaxwellSourceBLSocket) -> None: def init(self, bl_socket: MaxwellSourceBLSocket) -> None:
if self.is_list: pass
bl_socket.active_kind = ct.FlowKind.Array
def local_compare(self, _: MaxwellSourceBLSocket) -> None:
return True
#################### ####################

View File

@ -29,11 +29,11 @@ class MaxwellStructureBLSocket(base.MaxwellSimSocket):
class MaxwellStructureSocketDef(base.SocketDef): class MaxwellStructureSocketDef(base.SocketDef):
socket_type: ct.SocketType = ct.SocketType.MaxwellStructure socket_type: ct.SocketType = ct.SocketType.MaxwellStructure
is_list: bool = False
def init(self, bl_socket: MaxwellStructureBLSocket) -> None: def init(self, bl_socket: MaxwellStructureBLSocket) -> None:
if self.is_list: pass
bl_socket.active_kind = ct.FlowKind.Array
def local_compare(self, _: MaxwellStructureBLSocket) -> None:
return True
#################### ####################

View File

@ -16,10 +16,10 @@
"""Package providing various tools to handle cached data on Blender objects, especially nodes and node socket classes.""" """Package providing various tools to handle cached data on Blender objects, especially nodes and node socket classes."""
from ..keyed_cache import KeyedCache, keyed_cache
from .bl_field import BLField from .bl_field import BLField
from .bl_prop import BLProp, BLPropType from .bl_prop import BLProp, BLPropType
from .cached_bl_property import CachedBLProperty, cached_bl_property from .cached_bl_property import CachedBLProperty, cached_bl_property
from .keyed_cache import KeyedCache, keyed_cache
from .managed_cache import invalidate_nonpersist_instance_id from .managed_cache import invalidate_nonpersist_instance_id
from .signal import Signal from .signal import Signal

View File

@ -21,6 +21,7 @@ from types import MappingProxyType
import bpy import bpy
from blender_maxwell.utils import bl_cache, logger from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils.keyed_cache import keyed_cache
InstanceID: typ.TypeAlias = str ## Stringified UUID4 InstanceID: typ.TypeAlias = str ## Stringified UUID4
@ -220,11 +221,14 @@ class BLInstance:
for str_search_prop_name in self.blfields_str_search: for str_search_prop_name in self.blfields_str_search:
setattr(self, str_search_prop_name, bl_cache.Signal.ResetStrSearch) setattr(self, str_search_prop_name, bl_cache.Signal.ResetStrSearch)
@keyed_cache(
exclude={'self'}, ## No dynamic elements of 'self' can be used.
)
def trace_blfields_to_clear( def trace_blfields_to_clear(
self, self,
prop_name: str, prop_name: str,
prev_blfields_to_clear: list[ prev_blfields_to_clear: tuple[
tuple[str, typ.Literal['invalidate', 'reset_enum', 'reset_strsearch']] tuple[str, typ.Literal['invalidate', 'reset_enum', 'reset_strsearch']], ...
] = (), ] = (),
) -> list[str]: ) -> list[str]:
"""Invalidates all properties that depend on `prop_name`. """Invalidates all properties that depend on `prop_name`.
@ -239,7 +243,7 @@ class BLInstance:
All of these are filled when creating the `BLInstance` subclass, using `self.declare_blfield_dep()`, generally via the `BLField` descriptor (which internally uses `BLProp`). All of these are filled when creating the `BLInstance` subclass, using `self.declare_blfield_dep()`, generally via the `BLField` descriptor (which internally uses `BLProp`).
""" """
if prev_blfields_to_clear: if prev_blfields_to_clear:
blfields_to_clear = prev_blfields_to_clear.copy() blfields_to_clear = list(prev_blfields_to_clear)
else: else:
blfields_to_clear = [] blfields_to_clear = []
@ -268,7 +272,7 @@ class BLInstance:
if dst_prop_name in self.blfields: if dst_prop_name in self.blfields:
blfields_to_clear += self.trace_blfields_to_clear( blfields_to_clear += self.trace_blfields_to_clear(
dst_prop_name, dst_prop_name,
prev_blfields_to_clear=blfields_to_clear, prev_blfields_to_clear=tuple(blfields_to_clear),
) )
match (bool(prev_blfields_to_clear), bool(blfields_to_clear)): match (bool(prev_blfields_to_clear), bool(blfields_to_clear)):
@ -297,7 +301,7 @@ class BLInstance:
## -> As such, deduplication would not be wrong, just extraneous. ## -> As such, deduplication would not be wrong, just extraneous.
## -> Since invalidation is in a hot-loop, don't do such things. ## -> Since invalidation is in a hot-loop, don't do such things.
case (True, True): case (True, True):
return blfields_to_clear return list(reversed(dict.fromkeys(reversed(blfields_to_clear))))
def clear_blfields_after(self, prop_name: str) -> list[str]: def clear_blfields_after(self, prop_name: str) -> list[str]:
"""Clear (invalidate) all `BLField`s that have become invalid as a result of a change to `prop_name`. """Clear (invalidate) all `BLField`s that have become invalid as a result of a change to `prop_name`.

View File

@ -17,6 +17,7 @@
"""Useful image processing operations for use in the addon.""" """Useful image processing operations for use in the addon."""
import enum import enum
import functools
import typing as typ import typing as typ
import jax import jax
@ -26,7 +27,6 @@ import matplotlib
import matplotlib.axis as mpl_ax import matplotlib.axis as mpl_ax
import matplotlib.backends.backend_agg import matplotlib.backends.backend_agg
import matplotlib.figure import matplotlib.figure
import numpy as np
import seaborn as sns import seaborn as sns
from blender_maxwell import contracts as ct from blender_maxwell import contracts as ct
@ -138,7 +138,7 @@ def rgba_image_from_2d_map(
#################### ####################
# - MPL Helpers # - MPL Helpers
#################### ####################
# @functools.lru_cache(maxsize=16) @functools.lru_cache(maxsize=4)
def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int): def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
fig = matplotlib.figure.Figure( fig = matplotlib.figure.Figure(
figsize=[width_inches, height_inches], dpi=dpi, layout='tight' figsize=[width_inches, height_inches], dpi=dpi, layout='tight'

View File

@ -18,11 +18,15 @@ import functools
import inspect import inspect
import typing as typ import typing as typ
from blender_maxwell.utils import bl_instance, logger, serialize from blender_maxwell.utils import logger, serialize
log = logger.get(__name__) log = logger.get(__name__)
class BLInstance(typ.Protocol):
instance_id: str
class KeyedCache: class KeyedCache:
def __init__( def __init__(
self, self,
@ -75,8 +79,8 @@ class KeyedCache:
def __get__( def __get__(
self, self,
bl_instance: bl_instance.BLInstance | None, bl_instance: BLInstance | None,
owner: type[bl_instance.BLInstance], owner: type[BLInstance],
) -> typ.Callable: ) -> typ.Callable:
_func = functools.partial(self, bl_instance) _func = functools.partial(self, bl_instance)
_func.invalidate = functools.partial( _func.invalidate = functools.partial(
@ -110,7 +114,7 @@ class KeyedCache:
def invalidate( def invalidate(
self, self,
bl_instance: bl_instance.BLInstance | None, bl_instance: BLInstance | None,
**arguments: dict[str, typ.Any], **arguments: dict[str, typ.Any],
) -> dict[str, typ.Any]: ) -> dict[str, typ.Any]:
# Determine Wildcard Arguments # Determine Wildcard Arguments

View File

@ -264,13 +264,16 @@ class SimSymbol(pyd.BaseModel):
interval_closed_im: tuple[bool, bool] = (False, False) interval_closed_im: tuple[bool, bool] = (False, False)
#################### ####################
# - Labels # - Core
#################### ####################
@functools.cached_property @functools.cached_property
def name(self) -> str: def name(self) -> str:
"""Usable name for the symbol.""" """Usable name for the symbol."""
return self.sym_name.name return self.sym_name.name
####################
# - Labels
####################
@functools.cached_property @functools.cached_property
def name_pretty(self) -> str: def name_pretty(self) -> str:
"""Pretty (possibly unicode) name for the thing.""" """Pretty (possibly unicode) name for the thing."""
@ -307,6 +310,8 @@ class SimSymbol(pyd.BaseModel):
@functools.cached_property @functools.cached_property
def plot_label(self) -> str: def plot_label(self) -> str:
"""Pretty plot-oriented label.""" """Pretty plot-oriented label."""
if self.unit is None:
return self.name_pretty
return f'{self.name_pretty} ({self.unit_label})' return f'{self.name_pretty} ({self.unit_label})'
#################### ####################
@ -420,6 +425,11 @@ class SimSymbol(pyd.BaseModel):
@functools.cached_property @functools.cached_property
def is_nonzero(self) -> bool: def is_nonzero(self) -> bool:
"""Whether or not the value of this symbol can ever be $0$.
Notes:
Most notably, this symbol cannot be used as the right hand side of a division operation when this property is `False`.
"""
if self.exclude_zero: if self.exclude_zero:
return True return True
@ -441,6 +451,18 @@ class SimSymbol(pyd.BaseModel):
) )
return check_real_domain(self.domain) return check_real_domain(self.domain)
@functools.cached_property
def can_diff(self) -> bool:
"""Whether this symbol can be used as the input / output variable when differentiating."""
# Check Constants
## -> Constants (w/pinned values) are never differentiable.
if self.is_constant:
return False
# TODO: Discontinuities (especially across 0)?
return self.mathtype in [spux.MathType.Real, spux.MathType.Complex]
#################### ####################
# - Properties # - Properties
#################### ####################
@ -664,8 +686,10 @@ class SimSymbol(pyd.BaseModel):
res = spux.strip_unit_system(sp_obj) res = spux.strip_unit_system(sp_obj)
# Broadcast Expansion # Broadcast Expansion
if self.rows > 1 or self.cols > 1 and not isinstance(res, spux.MatrixBase): if (self.rows > 1 or self.cols > 1) and not isinstance(
res = sp_obj * sp.ImmutableMatrix.ones(self.rows, self.cols) res, sp.MatrixBase | sp.MatrixSymbol
):
res = res * sp.ImmutableMatrix.ones(self.rows, self.cols)
return res return res
@ -753,7 +777,9 @@ class SimSymbol(pyd.BaseModel):
unit = None unit = None
# Rows/Cols from Expr (if Matrix) # Rows/Cols from Expr (if Matrix)
rows, cols = expr.shape if isinstance(expr, sp.MatrixBase) else (1, 1) rows, cols = (
expr.shape if isinstance(expr, sp.MatrixBase | sp.MatrixSymbol) else (1, 1)
)
return SimSymbol( return SimSymbol(
sym_name=sym_name, sym_name=sym_name,