feat: various sym-flow modifications
parent
830b316e01
commit
38e70a60d3
|
@ -14,12 +14,12 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import typing as typ
|
||||
|
||||
import jaxtyping as jtyp
|
||||
import numpy as np
|
||||
import pydantic as pyd
|
||||
import sympy as sp
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
@ -29,8 +29,7 @@ log = logger.get(__name__)
|
|||
|
||||
|
||||
# TODO: Our handling of 'is_sorted' is sloppy and probably wrong.
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class ArrayFlow:
|
||||
class ArrayFlow(pyd.BaseModel):
|
||||
"""A homogeneous, realized array of numerical values with an optionally-attached unit and sort-tracking.
|
||||
|
||||
While the principle is simple, arrays-with-units ends up being a powerful basis for derived and computed features/methods/processing.
|
||||
|
@ -41,7 +40,9 @@ class ArrayFlow:
|
|||
None if unitless.
|
||||
"""
|
||||
|
||||
values: jtyp.Shaped[jtyp.Array, '...']
|
||||
model_config = pyd.ConfigDict(frozen=True, arbitrary_types_allowed=True)
|
||||
|
||||
values: jtyp.Inexact[jtyp.Array, '...'] ## TODO: Custom field type
|
||||
unit: spux.Unit | None = None
|
||||
|
||||
is_sorted: bool = False
|
||||
|
|
|
@ -18,6 +18,7 @@ import enum
|
|||
import functools
|
||||
import typing as typ
|
||||
|
||||
from blender_maxwell.contracts import BLEnumElement
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger
|
||||
from blender_maxwell.utils.staticproperty import staticproperty
|
||||
|
@ -99,6 +100,17 @@ class FlowKind(enum.StrEnum):
|
|||
def to_icon(_: typ.Self) -> str:
|
||||
return ''
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return FlowKind.to_name(self)
|
||||
|
||||
@property
|
||||
def icon(self) -> str:
|
||||
return FlowKind.to_icon(self)
|
||||
|
||||
def bl_enum_element(self, i) -> BLEnumElement:
|
||||
return (str(self), self.name, self.name, self.icon, i)
|
||||
|
||||
####################
|
||||
# - Static Properties
|
||||
####################
|
||||
|
@ -162,7 +174,7 @@ class FlowKind(enum.StrEnum):
|
|||
def socket_shape(self) -> str:
|
||||
"""Return the socket shape associated with this `FlowKind`.
|
||||
|
||||
**ONLY** valid for `FlowKind`s that can be considered "active".
|
||||
Should generally only be used with `active_kinds`.
|
||||
|
||||
Raises:
|
||||
ValueError: If this `FlowKind` cannot ever be considered "active".
|
||||
|
@ -172,7 +184,7 @@ class FlowKind(enum.StrEnum):
|
|||
FlowKind.Array: 'SQUARE',
|
||||
FlowKind.Range: 'SQUARE',
|
||||
FlowKind.Func: 'DIAMOND',
|
||||
}[self]
|
||||
}.get(self, 'CIRCLE')
|
||||
|
||||
####################
|
||||
# - Class Methods
|
||||
|
|
|
@ -14,40 +14,17 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import typing as typ
|
||||
from types import MappingProxyType
|
||||
|
||||
import jax
|
||||
import jaxtyping as jtyp
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger, sim_symbols
|
||||
|
||||
from .array import ArrayFlow
|
||||
from .info import InfoFlow
|
||||
from .lazy_range import RangeFlow
|
||||
from .params import ParamsFlow
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class FuncFlow:
|
||||
r"""Defines a flow of data as incremental function composition.
|
||||
|
||||
For specific math system usage instructions, please consult the documentation of relevant nodes.
|
||||
r"""Implements the core of the math system via `FuncFlow`, which allows high-performance, fully-expressive workflows with data that can be "very large", and/or whose input parameters are not yet fully known.
|
||||
|
||||
# Introduction
|
||||
When using nodes to do math, it becomes immediately obvious to express **flows of data as composed function chains**.
|
||||
Doing so has several advantages:
|
||||
|
||||
- **Interactive**: Since no large-array math is being done, the UI can be designed to feel fast and snappy.
|
||||
- **Symbolic**: Since no numerical math is being done yet, we can choose to keep our input parameters as symbolic variables with no performance impact.
|
||||
- **Performant**: Since no operations are happening, the UI feels fast and snappy.
|
||||
- **Interactive**: Since no large-array math is being done, the UI can be designed to feel fast and snappy, greatly boosting the will to experiment and ultimate productivity.
|
||||
- **Symbolic**: Since no numerical math is being done yet, we can inject symbolic variables at-will, enabling effortless ex. parameter sweeping, band-structure generation, differentiable can choose to keep our input parameters as symbolic variables with no performance impact.
|
||||
- **Performant**: Since the data pipeline is built as a single function w/o side effects, that function can be often be JIT-optimized for highly-optimized execution on the instruction sets used by modern massively-parallel devices, like modern CPUs (SSE/AVX), GPUs (ex. PTX), and HPC clusters (network-sharding to ARM/x86).
|
||||
|
||||
The result is a math system optimized for the analysis typically needed in electrodynamic contexts, prioritizing clarity and flexibility at soft-real-time, even with gigabytes of data on relatively weak hardware.
|
||||
|
||||
## Strongly Related FlowKinds
|
||||
For doing math, `Func` relies on two other `FlowKind`s, which must run in parallel:
|
||||
|
@ -238,6 +215,35 @@ class FuncFlow:
|
|||
This enables our humble interface to implement fully functional **inverse design** of parameterized structures, using only nodes.
|
||||
|
||||
But above all, we hope that this math system is fun, practical, and maybe even interesting.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import typing as typ
|
||||
from types import MappingProxyType
|
||||
|
||||
import jax
|
||||
import jaxtyping as jtyp
|
||||
import pydantic as pyd
|
||||
import sympy as sp
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger, sim_symbols
|
||||
|
||||
from .array import ArrayFlow
|
||||
from .info import InfoFlow
|
||||
from .lazy_range import RangeFlow
|
||||
from .params import ParamsFlow
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
|
||||
|
||||
|
||||
class FuncFlow(pyd.BaseModel):
|
||||
r"""Defines a flow of data as incremental function composition.
|
||||
|
||||
For theoretical information, please see the documentation of this module.
|
||||
For specific math system usage instructions, please consult the documentation of relevant nodes.
|
||||
|
||||
Attributes:
|
||||
func: The function that generates the represented value.
|
||||
|
@ -247,14 +253,16 @@ class FuncFlow:
|
|||
See the documentation of `self.func_jax()`.
|
||||
"""
|
||||
|
||||
model_config = pyd.ConfigDict(frozen=True)
|
||||
|
||||
func: LazyFunction
|
||||
func_args: list[sim_symbols.SimSymbol] = dataclasses.field(default_factory=list)
|
||||
func_kwargs: dict[str, sim_symbols.SimSymbol] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
func_args: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
|
||||
func_kwargs: dict[str, sim_symbols.SimSymbol] = pyd.Field(default_factory=dict)
|
||||
func_output: sim_symbols.SimSymbol | None = None
|
||||
|
||||
supports_jax: bool = False
|
||||
|
||||
concatenated: bool = False
|
||||
is_concatenated: bool = False
|
||||
|
||||
####################
|
||||
# - Functions
|
||||
|
@ -318,6 +326,7 @@ class FuncFlow:
|
|||
{}
|
||||
),
|
||||
) -> typ.Self:
|
||||
"""Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols."""
|
||||
if self.supports_jax:
|
||||
return self.func_jax(
|
||||
*params.scaled_func_args(symbol_values),
|
||||
|
@ -371,14 +380,55 @@ class FuncFlow:
|
|||
|
||||
return data | {info.output: self.realize(params, symbol_values=symbol_values)}
|
||||
|
||||
def realize_partial(
|
||||
self, params: ParamsFlow
|
||||
) -> typ.Callable[
|
||||
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
|
||||
jtyp.Inexact[jtyp.Array, '...'],
|
||||
]:
|
||||
"""Create a purely-numerical function, which takes only numerical.
|
||||
|
||||
The units/types/shape/etc. of the returned numerical type conforms to the `SimSymbol` specification of relevant `self.func_args` entries and `self.func_output`.
|
||||
|
||||
This function should be used whenever the unrealized result of a `FuncFlow` needs to be used as the argument to another `FuncFlow`.
|
||||
By using `realize_partial()`, two things are ensured:
|
||||
|
||||
- Since the function defined in `.compose_within()` must be purely numerical, the usual `.realize()` mechanism can't be used to sweep away the pre-realized symbol values.
|
||||
- Since this `FuncFlow` is completely consumed, with no symbols / arguments / etc. explicitly surviving, its impact on the data flow can be considered to have been effectively terminated after using this function.
|
||||
|
||||
Notes:
|
||||
Be **very careful about units**.
|
||||
Ideally, the bottom function should use `.scale_to_unit()` before invoking `.compose_within()` with the output of this function.
|
||||
"""
|
||||
pre_realized_syms = list(
|
||||
params.realize_symbols(params.realized_symbols, allow_partial=True).values()
|
||||
)
|
||||
|
||||
def realizer(
|
||||
*sym_args: int | float | complex | jtyp.Inexact[jtyp.Array, '...'],
|
||||
) -> jtyp.Inexact[jtyp.Array, '...']:
|
||||
return self.func(
|
||||
*[
|
||||
func_arg_n(*sym_args, *pre_realized_syms)
|
||||
for func_arg_n in params.func_args_n
|
||||
],
|
||||
**{
|
||||
func_arg_name: func_kwarg_n(*sym_args, *pre_realized_syms)
|
||||
for func_arg_name, func_kwarg_n in params.func_kwargs_n.items()
|
||||
},
|
||||
)
|
||||
|
||||
return realizer
|
||||
|
||||
####################
|
||||
# - Composition Operations
|
||||
# - Operations
|
||||
####################
|
||||
def compose_within(
|
||||
self,
|
||||
enclosing_func: LazyFunction,
|
||||
enclosing_func_args: list[type] = (),
|
||||
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}),
|
||||
enclosing_func_args: list[sim_symbols.SimSymbol] = (),
|
||||
enclosing_func_kwargs: dict[str, sim_symbols.SimSymbol] = MappingProxyType({}),
|
||||
enclosing_func_output: sim_symbols.SimSymbol | None = None,
|
||||
supports_jax: bool = False,
|
||||
) -> typ.Self:
|
||||
"""Compose `self.func` within the given enclosing function, which itself takes arguments, and create a new `FuncFlow` to contain it.
|
||||
|
@ -415,6 +465,10 @@ class FuncFlow:
|
|||
Returns:
|
||||
A lazy function that takes both the enclosed and enclosing arguments, and returns the value of the enclosing function (whose first argument is the output value of the enclosed function).
|
||||
"""
|
||||
## TODO: Support unit system conversion at the point of composition.
|
||||
## -- This may require us to track the units of the function output.
|
||||
## TODO: Support JAX-evaluation when jax support changes from True to False.
|
||||
## -- This would allow big data flows to compose performantly as arguments into non-JAX functions.
|
||||
return FuncFlow(
|
||||
func=lambda *args, **kwargs: enclosing_func(
|
||||
self.func(
|
||||
|
@ -426,6 +480,7 @@ class FuncFlow:
|
|||
),
|
||||
func_args=self.func_args + list(enclosing_func_args),
|
||||
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
|
||||
func_output=enclosing_func_output,
|
||||
supports_jax=self.supports_jax and supports_jax,
|
||||
)
|
||||
|
||||
|
@ -472,7 +527,7 @@ class FuncFlow:
|
|||
*list(args[: len(self.func_args)]),
|
||||
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
|
||||
)
|
||||
if not self.concatenated:
|
||||
if not self.is_concatenated:
|
||||
return (ret,)
|
||||
return ret
|
||||
|
||||
|
@ -487,5 +542,83 @@ class FuncFlow:
|
|||
func_args=self.func_args + other.func_args,
|
||||
func_kwargs=self.func_kwargs | other.func_kwargs,
|
||||
supports_jax=self.supports_jax and other.supports_jax,
|
||||
concatenated=True,
|
||||
is_concatenated=True,
|
||||
)
|
||||
|
||||
def scale_to_unit(self, unit: spux.Unit | None = None) -> typ.Self:
|
||||
"""Encloses this function in a unit-converting function, whose output is a converted, unitless scalar.
|
||||
|
||||
`unit` must be manually guaranteed to be compatible with `self.unit`.
|
||||
"""
|
||||
if self.func_output is not None:
|
||||
# Retrieve Output Unit
|
||||
output_unit = self.func_output.unit
|
||||
|
||||
# Compile Efficient Unit-Conversion Function
|
||||
a = self.func_output.mathtype.sp_symbol_a
|
||||
unit_convert_expr = (
|
||||
spux.scale_to_unit(a * output_unit, unit)
|
||||
if self.func_output.unit is not None
|
||||
else a
|
||||
)
|
||||
unit_convert_func = sp.lambdify(a, unit_convert_expr.n(), 'jax')
|
||||
|
||||
# Compose Unit-Converted FuncFlow
|
||||
return self.compose_within(
|
||||
enclosing_func=unit_convert_func,
|
||||
supports_jax=True,
|
||||
enclosing_func_output=self.func_output.update(unit=unit),
|
||||
)
|
||||
|
||||
msg = f'Tried to scale a FuncFlow to a unit system, but it has no tracked output SimSymbol. ({self})'
|
||||
raise ValueError(msg)
|
||||
|
||||
def scale_to_unit_system(
|
||||
self, unit_system: spux.UnitSystem | None = None
|
||||
) -> typ.Self:
|
||||
"""Encloses this function in a unit-converting function, whose output is a converted, unitless scalar.
|
||||
|
||||
Using `self.output_symbol`, which tracks the units of the output, we can determine a scaling factor to multiply the (numerical) function output by in order to conform it to the given unit system.
|
||||
|
||||
In general, **don't use this**.
|
||||
Any superfluous numerical operations in a data pipeline can enhance instabilities and interfere with JIT-optimization (floating-point arithmetic isn't commutative, for example).
|
||||
However, occasionally, we need to "intercept" a lazy data flow, for example when realizing a `FlowKind.Value` that doesn't understand symbols or units - but which only accepts a float/complex scalar/array with pre-determined unit convention.
|
||||
|
||||
For this purpose alone, this method is provided to pre-scale a `FuncFlow`, just before using `realize()` / `__or__` and then `realize()`.
|
||||
**To encourage proper usage** (and ease implementation), the output unit in `self.func_output` of the output will be reset to `None` - indicating that the output can only be handled as a unitless scalar w/semantic meaning tracked elsewhere.
|
||||
|
||||
Notes:
|
||||
**ONLY** use with output types that support meaningful arbitrary multiplication.
|
||||
|
||||
A scale-only sympy expression will be used to produce an optimized JAX function of a single variable, which will then be composed onto the existing `FuncFlow`.
|
||||
|
||||
Parameters:
|
||||
unit_system: The unit system to conform the function output to.
|
||||
|
||||
Returns:
|
||||
A new `FuncFlow` that conforms to the new unit, but is itself now considered unitless.
|
||||
"""
|
||||
if self.func_output is not None:
|
||||
# Retrieve Output Unit
|
||||
output_unit = self.func_output.unit
|
||||
|
||||
# Compile Efficient Unit-Conversion Function
|
||||
a = self.func_output.mathtype.sp_symbol_a
|
||||
unit_convert_expr = (
|
||||
spux.strip_unit_system(
|
||||
spux.convert_to_unit_system(a * output_unit, unit_system)
|
||||
)
|
||||
if self.func_output.unit is not None
|
||||
else a
|
||||
)
|
||||
unit_convert_func = sp.lambdify(a, unit_convert_expr.n(), 'jax')
|
||||
|
||||
# Compose Unit-Converted FuncFlow
|
||||
return self.compose_within(
|
||||
enclosing_func=unit_convert_func,
|
||||
supports_jax=True,
|
||||
enclosing_func_output=self.func_output.update(unit=None),
|
||||
)
|
||||
|
||||
msg = f'Tried to scale a FuncFlow to a unit system, but it has no tracked output SimSymbol. ({self})'
|
||||
raise ValueError(msg)
|
||||
|
|
|
@ -14,17 +14,15 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import typing as typ
|
||||
from fractions import Fraction
|
||||
from types import MappingProxyType
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jaxtyping as jtyp
|
||||
import pydantic as pyd
|
||||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger, sim_symbols
|
||||
|
@ -61,8 +59,7 @@ class ScalingMode(enum.StrEnum):
|
|||
return ''
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class RangeFlow:
|
||||
class RangeFlow(pyd.BaseModel):
|
||||
r"""Represents a finite spaced array using symbolic boundary expressions.
|
||||
|
||||
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
|
||||
|
@ -92,8 +89,10 @@ class RangeFlow:
|
|||
symbols: Set of variables from which `start` and/or `stop` are determined.
|
||||
"""
|
||||
|
||||
start: spux.ScalarUnitlessComplexExpr
|
||||
stop: spux.ScalarUnitlessComplexExpr
|
||||
model_config = pyd.ConfigDict(frozen=True)
|
||||
|
||||
start: spux.ScalarUnitlessRealExpr
|
||||
stop: spux.ScalarUnitlessRealExpr
|
||||
steps: int = 0
|
||||
scaling: ScalingMode = ScalingMode.Lin
|
||||
|
||||
|
@ -102,7 +101,7 @@ class RangeFlow:
|
|||
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
|
||||
|
||||
# Helper Attributes
|
||||
pre_fourier_ideal_midpoint: spux.ScalarUnitlessComplexExpr | None = None
|
||||
pre_fourier_ideal_midpoint: spux.ScalarUnitlessRealExpr | None = None
|
||||
|
||||
####################
|
||||
# - SimSymbol Interop
|
||||
|
@ -218,14 +217,26 @@ class RangeFlow:
|
|||
)
|
||||
return combined_mathtype
|
||||
|
||||
@property
|
||||
@functools.cached_property
|
||||
def ideal_midpoint(self) -> spux.SympyExpr:
|
||||
return (self.stop + self.start) / 2
|
||||
|
||||
@property
|
||||
@functools.cached_property
|
||||
def ideal_range(self) -> spux.SympyExpr:
|
||||
return self.stop - self.start
|
||||
|
||||
@functools.cached_property
|
||||
def ideal_step_size(self) -> spux.SympyExpr:
|
||||
return self.ideal_range / (self.steps - 1)
|
||||
|
||||
@functools.cached_property
|
||||
def is_always_nonzero(self) -> spux.SympyExpr:
|
||||
if self.start > 0 or self.stop < 0:
|
||||
return True
|
||||
|
||||
is_zero = (self.start % self.ideal_step_size).is_zero
|
||||
return is_zero if is_zero is not None else False
|
||||
|
||||
####################
|
||||
# - Methods
|
||||
####################
|
||||
|
@ -452,7 +463,7 @@ class RangeFlow:
|
|||
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||
{}
|
||||
),
|
||||
) -> dict[sp.Symbol, spux.ScalarUnitlessComplexExpr]:
|
||||
) -> dict[sp.Symbol, spux.ScalarUnitlessRealExpr]:
|
||||
"""Realize **all** input symbols to the `RangeFlow`.
|
||||
|
||||
Parameters:
|
||||
|
@ -480,7 +491,7 @@ class RangeFlow:
|
|||
raise NotImplementedError(msg)
|
||||
|
||||
realized_syms |= {sym: v}
|
||||
|
||||
return realized_syms
|
||||
msg = f'RangeFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
|
||||
raise ValueError(msg)
|
||||
|
||||
|
|
|
@ -14,13 +14,13 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import typing as typ
|
||||
from fractions import Fraction
|
||||
from types import MappingProxyType
|
||||
|
||||
import jaxtyping as jtyp
|
||||
import pydantic as pyd
|
||||
import sympy as sp
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
@ -28,31 +28,35 @@ from blender_maxwell.utils import logger, sim_symbols
|
|||
|
||||
from .array import ArrayFlow
|
||||
from .expr_info import ExprInfo
|
||||
from .flow_kinds import FlowKind
|
||||
from .lazy_range import RangeFlow
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class ParamsFlow:
|
||||
class ParamsFlow(pyd.BaseModel):
|
||||
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||
|
||||
Returns:
|
||||
All symbols valid for use in the expression.
|
||||
"""
|
||||
|
||||
arg_targets: list[sim_symbols.SimSymbol] = dataclasses.field(default_factory=list)
|
||||
kwarg_targets: list[str, sim_symbols.SimSymbol] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
model_config = pyd.ConfigDict(frozen=True)
|
||||
|
||||
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
|
||||
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
|
||||
arg_targets: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
|
||||
kwarg_targets: dict[str, sim_symbols.SimSymbol] = pyd.Field(default_factory=dict)
|
||||
|
||||
func_args: list[spux.SympyExpr] = pyd.Field(default_factory=list)
|
||||
func_kwargs: dict[str, spux.SympyExpr] = pyd.Field(default_factory=dict)
|
||||
|
||||
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
|
||||
realized_symbols: dict[
|
||||
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
|
||||
] = pyd.Field(default_factory=dict)
|
||||
|
||||
is_differentiable: bool = False
|
||||
@functools.cached_property
|
||||
def diff_symbols(self) -> set[sim_symbols.SimSymbol]:
|
||||
"""Set of all unrealized `SimSymbol`s that can act as inputs when differentiating the function for which this `ParamsFlow` tracks arguments."""
|
||||
return {sym for sym in self.symbols if sym.can_diff}
|
||||
|
||||
####################
|
||||
# - Symbols
|
||||
|
@ -78,6 +82,27 @@ class ParamsFlow:
|
|||
"""
|
||||
return [sym.sp_symbol_matsym for sym in self.sorted_symbols]
|
||||
|
||||
@functools.cached_property
|
||||
def all_sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
|
||||
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||
|
||||
Returns:
|
||||
All symbols valid for use in the expression.
|
||||
"""
|
||||
key_func = lambda sym: sym.name # noqa: E731
|
||||
return sorted(self.symbols, key=key_func) + sorted(
|
||||
self.realized_symbols.keys(), key=key_func
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def all_sorted_sp_symbols(self) -> list[sim_symbols.SimSymbol]:
|
||||
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||
|
||||
Returns:
|
||||
All symbols valid for use in the expression.
|
||||
"""
|
||||
return [sym.sp_symbol_matsym for sym in self.all_sorted_symbols]
|
||||
|
||||
####################
|
||||
# - JIT'ed Callables for Numerical Function Arguments
|
||||
####################
|
||||
|
@ -101,7 +126,7 @@ class ParamsFlow:
|
|||
"""
|
||||
return [
|
||||
sp.lambdify(
|
||||
self.sorted_sp_symbols,
|
||||
self.all_sorted_sp_symbols,
|
||||
target_sym.conform(func_arg, strip_unit=True),
|
||||
'jax',
|
||||
)
|
||||
|
@ -127,7 +152,7 @@ class ParamsFlow:
|
|||
"""
|
||||
return {
|
||||
key: sp.lambdify(
|
||||
self.sorted_sp_symbols,
|
||||
self.all_sorted_sp_symbols,
|
||||
self.kwarg_targets[key].conform(func_arg, strip_unit=True),
|
||||
'jax',
|
||||
)
|
||||
|
@ -142,8 +167,9 @@ class ParamsFlow:
|
|||
symbol_values: dict[
|
||||
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
|
||||
] = MappingProxyType({}),
|
||||
allow_partial: bool = False,
|
||||
) -> dict[
|
||||
sp.Symbol,
|
||||
sim_symbols.SimSymbol,
|
||||
int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...'] :,
|
||||
]:
|
||||
"""Fully realize all symbols by assigning them a value.
|
||||
|
@ -160,10 +186,12 @@ class ParamsFlow:
|
|||
Returns:
|
||||
A dictionary almost with `.subs()`, other than `jax` arrays.
|
||||
"""
|
||||
if set(self.symbols) == set(symbol_values.keys()):
|
||||
if allow_partial or set(self.all_sorted_symbols) == set(symbol_values.keys()):
|
||||
realized_syms = {}
|
||||
for sym in self.sorted_symbols:
|
||||
sym_value = symbol_values[sym]
|
||||
for sym in self.all_sorted_symbols:
|
||||
sym_value = symbol_values.get(sym)
|
||||
if sym_value is None and allow_partial:
|
||||
continue
|
||||
|
||||
if isinstance(sym_value, spux.SympyType):
|
||||
v = sym.scale(sym_value)
|
||||
|
@ -214,7 +242,9 @@ class ParamsFlow:
|
|||
Parameters:
|
||||
symbol_values: Particular values for all symbols in `self.symbols`, which will be conformed and used to compute the function arguments (before they are conformed to `self.target_syms`).
|
||||
"""
|
||||
realized_symbols = list(self.realize_symbols(symbol_values).values())
|
||||
realized_symbols = list(
|
||||
self.realize_symbols(symbol_values | self.realized_symbols).values()
|
||||
)
|
||||
return [func_arg_n(*realized_symbols) for func_arg_n in self.func_args_n]
|
||||
|
||||
def scaled_func_kwargs(
|
||||
|
@ -227,10 +257,11 @@ class ParamsFlow:
|
|||
|
||||
Other than the `dict[str, ...]` key, the semantics are identical to `self.scaled_func_args()`.
|
||||
"""
|
||||
realized_symbols = self.realize_symbols(symbol_values)
|
||||
realized_symbols = self.realize_symbols(symbol_values | self.realized_symbols)
|
||||
|
||||
return {
|
||||
func_arg_name: func_arg_n(**realized_symbols)
|
||||
for func_arg_name, func_arg_n in self.func_kwargs_n.items()
|
||||
func_arg_name: func_kwarg_n(**realized_symbols)
|
||||
for func_arg_name, func_kwarg_n in self.func_kwargs_n.items()
|
||||
}
|
||||
|
||||
####################
|
||||
|
@ -251,7 +282,7 @@ class ParamsFlow:
|
|||
func_args=self.func_args + other.func_args,
|
||||
func_kwargs=self.func_kwargs | other.func_kwargs,
|
||||
symbols=self.symbols | other.symbols,
|
||||
is_differentiable=self.is_differentiable and other.is_differentiable,
|
||||
realized_symbols=self.realized_symbols | other.realized_symbols,
|
||||
)
|
||||
|
||||
def compose_within(
|
||||
|
@ -261,7 +292,6 @@ class ParamsFlow:
|
|||
enclosing_func_args: list[spux.SympyExpr] = (),
|
||||
enclosing_func_kwargs: dict[str, spux.SympyExpr] = MappingProxyType({}),
|
||||
enclosing_symbols: frozenset[sim_symbols.SimSymbol] = frozenset(),
|
||||
enclosing_is_differentiable: bool = False,
|
||||
) -> typ.Self:
|
||||
return ParamsFlow(
|
||||
arg_targets=self.arg_targets + list(enclosing_arg_targets),
|
||||
|
@ -269,13 +299,41 @@ class ParamsFlow:
|
|||
func_args=self.func_args + list(enclosing_func_args),
|
||||
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
|
||||
symbols=self.symbols | enclosing_symbols,
|
||||
is_differentiable=(
|
||||
self.is_differentiable
|
||||
if not enclosing_symbols
|
||||
else (self.is_differentiable & enclosing_is_differentiable)
|
||||
),
|
||||
realized_symbols=self.realized_symbols,
|
||||
)
|
||||
|
||||
def realize_partial(
|
||||
self,
|
||||
symbol_values: dict[
|
||||
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
|
||||
],
|
||||
) -> typ.Self:
|
||||
"""Provide a particular expression/range/array to realize some symbols.
|
||||
|
||||
Essentially removes symbols from `self.symbols`, and adds the symbol w/value to `self.realized_symbols`.
|
||||
As a result, only the still-unrealized symbols need to be passed at the time of realization (using ex. `self.scaled_func_args()`).
|
||||
|
||||
Parameters:
|
||||
symbol_values: The value to realize for each `SimSymbol`.
|
||||
**All keys** must be identically matched to a single element of `self.symbol`.
|
||||
Can be empty, in which case an identical new `ParamsFlow` will be returned.
|
||||
|
||||
Raises:
|
||||
ValueError: If any symbol in `symbol_values`
|
||||
"""
|
||||
syms = set(symbol_values.keys())
|
||||
if syms.issubset(self.symbols) or not syms:
|
||||
return ParamsFlow(
|
||||
arg_targets=self.arg_targets,
|
||||
kwarg_targets=self.kwarg_targets,
|
||||
func_args=self.func_args,
|
||||
func_kwargs=self.func_kwargs,
|
||||
symbols=self.symbols - syms,
|
||||
realized_symbols=self.realized_symbols | symbol_values,
|
||||
)
|
||||
msg = f'ParamsFlow: Not all partially realized symbols are defined on the ParamsFlow (symbols={self.symbols}, symbol_values={symbol_values})'
|
||||
raise ValueError(msg)
|
||||
|
||||
####################
|
||||
# - Generate ExprSocketDef
|
||||
####################
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
"""Declares `ManagedBLImage`."""
|
||||
|
||||
# import time
|
||||
import time
|
||||
import typing as typ
|
||||
|
||||
import bpy
|
||||
|
@ -261,7 +261,7 @@ class ManagedBLImage(base.ManagedObj):
|
|||
dpi: int | None = None,
|
||||
bl_select: bool = False,
|
||||
):
|
||||
# times = [time.perf_counter()]
|
||||
times = ['START', time.perf_counter()]
|
||||
|
||||
# Compute Plot Dimensions
|
||||
# aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = (
|
||||
|
@ -277,22 +277,22 @@ class ManagedBLImage(base.ManagedObj):
|
|||
# _width_inches, _height_inches, _dpi
|
||||
# )
|
||||
fig, canvas, ax = image_ops.mpl_fig_canvas_ax(width_inches, height_inches, dpi)
|
||||
# times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]])
|
||||
times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]])
|
||||
|
||||
# fig.clear()
|
||||
ax.clear()
|
||||
# times.append(['Clear Axis', time.perf_counter() - times[0]])
|
||||
times.append(['Clear Axis', time.perf_counter() - times[0]])
|
||||
|
||||
# Plot w/User Parameter
|
||||
func_plotter(ax)
|
||||
# times.append(['Plot!', time.perf_counter() - times[0]])
|
||||
times.append(['Plot!', time.perf_counter() - times[0]])
|
||||
|
||||
# Save Figure to BytesIO
|
||||
canvas.draw()
|
||||
# times.append(['Draw Pixels', time.perf_counter() - times[0]])
|
||||
times.append(['Draw Pixels', time.perf_counter() - times[0]])
|
||||
|
||||
canvas_width_px, canvas_height_px = fig.canvas.get_width_height()
|
||||
# times.append(['Get Canvas Dims', time.perf_counter() - times[0]])
|
||||
times.append(['Get Canvas Dims', time.perf_counter() - times[0]])
|
||||
image_data = (
|
||||
np.float32(
|
||||
np.flipud(
|
||||
|
@ -303,7 +303,7 @@ class ManagedBLImage(base.ManagedObj):
|
|||
)
|
||||
/ 255
|
||||
)
|
||||
# times.append(['Load Data from Canvas', time.perf_counter() - times[0]])
|
||||
times.append(['Load Data from Canvas', time.perf_counter() - times[0]])
|
||||
|
||||
# Optimized Write to Blender Image
|
||||
bl_image = self.bl_image(canvas_width_px, canvas_height_px, 'RGBA', 'uint8')
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import queue
|
||||
import typing as typ
|
||||
|
||||
import bpy
|
||||
|
@ -26,6 +28,14 @@ from .managed_objs.managed_bl_image import ManagedBLImage
|
|||
|
||||
log = logger.get(__name__)
|
||||
|
||||
link_action_queue = queue.Queue()
|
||||
|
||||
|
||||
def set_link_validity(link: bpy.types.NodeLink, validity: bool) -> None:
|
||||
log.critical('Set %s validity to %s', str(link), str(validity))
|
||||
link.is_valid = validity
|
||||
|
||||
|
||||
####################
|
||||
# - Cache Management
|
||||
####################
|
||||
|
@ -45,58 +55,93 @@ class DeltaNodeLinkCache(typ.TypedDict):
|
|||
|
||||
|
||||
class NodeLinkCache:
|
||||
"""A pointer-based cache of node links in a node tree.
|
||||
"""A volatile pointer-based cache of node links in a node tree.
|
||||
|
||||
Warnings:
|
||||
Everything here is **extremely** unsafe.
|
||||
Even a single mistake **will** cause a use-after-free crash of Blender.
|
||||
|
||||
Used perfectly, it allows for powerful features; anything less, and it's an epic liability.
|
||||
|
||||
Attributes:
|
||||
_node_tree: Reference to the owning node tree.
|
||||
link_ptrs_as_links:
|
||||
link_ptrs: Pointers (as in integer memory adresses) to `NodeLink`s.
|
||||
link_ptrs_as_links: Map from pointers to actual `NodeLink`s.
|
||||
link_ptrs_from_sockets: Map from pointers to `NodeSocket`s, representing the source of each `NodeLink`.
|
||||
link_ptrs_from_sockets: Map from pointers to `NodeSocket`s, representing the destination of each `NodeLink`.
|
||||
_node_tree: Reference to the node tree for which this cache is valid.
|
||||
link_ptrs: Memory-address identifiers for all node links that currently exist in `_node_tree`.
|
||||
link_ptrs_as_links: Mapping from pointers (integers) to actual `NodeLink` objects.
|
||||
**WARNING**: If the pointer-referenced object no longer exists, then Blender **will crash immediately** upon attempting to use it. There is no way to mitigate this.
|
||||
socket_ptrs: Memory-address identifiers for all sockets that currently exist in `_node_tree`.
|
||||
socket_ptrs_as_sockets: Mapping from pointers (integers) to actual `NodeSocket` objects.
|
||||
**WARNING**: If the pointer-referenced object no longer exists, then Blender **will crash immediately** upon attempting to use it. There is no way to mitigate this.
|
||||
socket_ptr_refcount: The amount of links currently connected to a given socket pointer.
|
||||
Used to drive the deletion of socket pointers using only knowledge about `link_ptr` removal.
|
||||
link_ptrs_as_from_socket_ptrs: The pointer of the source socket, defined for every node link pointer.
|
||||
link_ptrs_as_to_socket_ptrs: The pointer of the destination socket, defined for every node link pointer.
|
||||
"""
|
||||
|
||||
def __init__(self, node_tree: bpy.types.NodeTree):
|
||||
"""Initialize the cache from a node tree.
|
||||
|
||||
Parameters:
|
||||
node_tree: The Blender node tree whose `NodeLink`s will be cached.
|
||||
"""
|
||||
"""Defines and fills the cache from a live node tree."""
|
||||
self._node_tree = node_tree
|
||||
|
||||
# Link PTR and PTR->REF
|
||||
self.link_ptrs: set[MemAddr] = set()
|
||||
self.link_ptrs_as_links: dict[MemAddr, bpy.types.NodeLink] = {}
|
||||
|
||||
# Socket PTR and PTR->REF
|
||||
self.socket_ptrs: set[MemAddr] = set()
|
||||
self.socket_ptrs_as_sockets: dict[MemAddr, bpy.types.NodeSocket] = {}
|
||||
self.socket_ptr_refcount: dict[MemAddr, int] = {}
|
||||
|
||||
# Link PTR -> Socket PTR
|
||||
self.link_ptrs_as_from_socket_ptrs: dict[MemAddr, MemAddr] = {}
|
||||
self.link_ptrs_as_to_socket_ptrs: dict[MemAddr, MemAddr] = {}
|
||||
|
||||
self.link_ptrs_invalid: set[MemAddr] = set()
|
||||
|
||||
# Fill Cache
|
||||
self.regenerate()
|
||||
|
||||
def remove_link(self, link_ptr: MemAddr) -> None:
|
||||
"""Removes a link pointer from the cache, indicating that the link doesn't exist anymore.
|
||||
"""Reports a link as removed, causing it to be removed from the cache.
|
||||
|
||||
This **must** be run whenever a node link is deleted.
|
||||
**Failure to to so WILL result in segmentation fault** at an unknown future time.
|
||||
|
||||
In particular, the following actions are taken:
|
||||
- The entry in `self.link_ptrs_as_links` is deleted.
|
||||
- Any entry in `self.link_ptrs_invalid` is deleted (if exists).
|
||||
|
||||
Notes:
|
||||
- **DOES NOT** remove PTR->REF dictionary entries
|
||||
- Invoking this method directly causes the removed node links to not be reported as "removed" by `NodeLinkCache.regenerate()`.
|
||||
- This **must** be done whenever a node link is deleted.
|
||||
- Failure to do so may result in a segmentation fault at arbitrary future time.
|
||||
Invoking this method directly causes the removed node links to not be reported as "removed" by `NodeLinkCache.regenerate()`.
|
||||
In some cases, this may be desirable, ex. for internal methods that shouldn't trip a `DataChanged` flow event.
|
||||
|
||||
Parameters:
|
||||
link_ptr: Pointer to remove from the cache.
|
||||
link_ptr: The pointer (integer) to remove from the cache.
|
||||
|
||||
Raises:
|
||||
KeyError: If `link_ptr` is not a member of either `self.link_ptrs`, or of `self.link_ptrs_as_links`.
|
||||
"""
|
||||
self.link_ptrs.remove(link_ptr)
|
||||
self.link_ptrs_as_links.pop(link_ptr)
|
||||
|
||||
if link_ptr in self.link_ptrs_invalid:
|
||||
self.link_ptrs_invalid.remove(link_ptr)
|
||||
|
||||
def remove_sockets_by_link_ptr(self, link_ptr: MemAddr) -> None:
|
||||
"""Removes a single pointer's reference to its from/to sockets."""
|
||||
"""Deassociate from all sockets referenced by a link, respecting the socket pointer reference-count.
|
||||
|
||||
The `NodeLinkCache` stores references to all socket pointers referenced by any link.
|
||||
Since several links can be associated with each socket, we must keep a "reference count" per-socket.
|
||||
When the "reference count" drops to zero, then there are no longer any `NodeLink`s that refer to it, and therefore it should be removed from the `NodeLinkCache`.
|
||||
|
||||
This method facilitates that process by:
|
||||
- Extracting (with removal) the from / to socket pointers associated with `link_ptr`.
|
||||
- If the socket pointer has a reference count of `1`, then it is **completely removed**.
|
||||
- If the socket pointer has a reference count of `>1`, then the reference count is decremented by `1`.
|
||||
|
||||
Notes:
|
||||
In general, this should be called together with `remove_link`.
|
||||
However, in certain cases, this process also needs to happen by itself.
|
||||
|
||||
Parameters:
|
||||
link_ptr: The pointer (integer) to remove from the cache.
|
||||
"""
|
||||
# Remove Socket Pointers
|
||||
from_socket_ptr = self.link_ptrs_as_from_socket_ptrs.pop(link_ptr, None)
|
||||
to_socket_ptr = self.link_ptrs_as_to_socket_ptrs.pop(link_ptr, None)
|
||||
|
||||
|
@ -113,31 +158,40 @@ class NodeLinkCache:
|
|||
self.socket_ptr_refcount[socket_ptr] -= 1
|
||||
|
||||
def regenerate(self) -> DeltaNodeLinkCache:
|
||||
"""Regenerates the cache from the internally-linked node tree.
|
||||
"""Efficiently scans the internally referenced node tree to thoroughly update all attributes of this `NodeLinkCache`.
|
||||
|
||||
Notes:
|
||||
- This is designed to run within the `update()` invocation of the node tree.
|
||||
- This should be a very fast function, since it is called so much.
|
||||
This runs in a **very** hot loop, within the `update()` function of the node tree.
|
||||
Anytime anything happens in the node tree, `update()` (and therefore this method) is called.
|
||||
|
||||
Thus, performance is of the utmost importance.
|
||||
Just a few microseconds too much may be amplified dozens of times over in practice, causing big stutters.
|
||||
"""
|
||||
# Compute All NodeLink Pointers
|
||||
## -> It can be very inefficient to do any full-scan of the node tree.
|
||||
## -> However, simply extracting the pointer: link ends up being fast.
|
||||
## -> This pattern seems to be the best we can do, efficiency-wise.
|
||||
all_link_ptrs_as_links = {
|
||||
link.as_pointer(): link for link in self._node_tree.links
|
||||
}
|
||||
all_link_ptrs = set(all_link_ptrs_as_links.keys())
|
||||
|
||||
# Compute Added/Removed Links
|
||||
## -> In essence, we've created a 'diff' here.
|
||||
## -> Set operations are fast, and expressive!
|
||||
added_link_ptrs = all_link_ptrs - self.link_ptrs
|
||||
removed_link_ptrs = self.link_ptrs - all_link_ptrs
|
||||
|
||||
# Edge Case: 'from_socket' Reassignment
|
||||
## (Reverse engineered) When all:
|
||||
## - Created a new link between the same two nodes.
|
||||
## - Matching 'to_socket'.
|
||||
## - Non-matching 'from_socket' on the same node.
|
||||
## -> THEN the link_ptr will not change, but the from_socket ptr should.
|
||||
if len(added_link_ptrs) == 0 and len(removed_link_ptrs) == 0:
|
||||
## (Reverse Engineered) When all are true:
|
||||
## - Created a new link between the same nodes as previous link.
|
||||
## - Matching 'to_socket' as the previous link.
|
||||
## - Non-matching 'from_socket', but on the same node.
|
||||
## -> THEN the link_ptr will not change, but the from_socket ptr does.
|
||||
if not added_link_ptrs and not removed_link_ptrs:
|
||||
# Find the Link w/Reassigned 'from_socket' PTR
|
||||
## A bit of a performance hit from the search, but it's an edge case.
|
||||
## -> This isn't very fast, but the edge case isn't so common.
|
||||
## -> Comprehensions are still quite optimized.
|
||||
_link_ptr_as_from_socket_ptrs = {
|
||||
link_ptr: (
|
||||
from_socket_ptr,
|
||||
|
@ -149,9 +203,9 @@ class NodeLinkCache:
|
|||
}
|
||||
|
||||
# Completely Remove the Old Link (w/Reassigned 'from_socket')
|
||||
## This effectively reclassifies the edge case as a normal 're-add'.
|
||||
## -> Casts the edge case to look like a typical 're-add'.
|
||||
for link_ptr in _link_ptr_as_from_socket_ptrs:
|
||||
log.info(
|
||||
log.debug(
|
||||
'Edge-Case - "from_socket" Reassigned in NodeLink w/o New NodeLink Pointer: %s',
|
||||
link_ptr,
|
||||
)
|
||||
|
@ -159,21 +213,25 @@ class NodeLinkCache:
|
|||
self.remove_sockets_by_link_ptr(link_ptr)
|
||||
|
||||
# Recompute Added/Removed Links
|
||||
## The algorithm will now detect an "added link".
|
||||
## -> Guide the usual algorithm to detect an "added link".
|
||||
added_link_ptrs = all_link_ptrs - self.link_ptrs
|
||||
removed_link_ptrs = self.link_ptrs - all_link_ptrs
|
||||
|
||||
# Shuffle Cache based on Change in Links
|
||||
## Remove Entries for Removed Pointers
|
||||
# Delete Removed Links
|
||||
## -> NOTE: We leave dangling socket information on purpose.
|
||||
## -> This information will be used to ask for 'removal consent'.
|
||||
## -> To truly remove, must call 'remove_socket_by_link_ptr' later.
|
||||
for removed_link_ptr in removed_link_ptrs:
|
||||
self.remove_link(removed_link_ptr)
|
||||
## User must manually call 'remove_socket_by_link_ptr' later.
|
||||
## For now, leave dangling socket information by-link.
|
||||
|
||||
# Add New Link Pointers
|
||||
# Create Added Links
|
||||
## -> First, simply concatenate the added link pointers.
|
||||
self.link_ptrs |= added_link_ptrs
|
||||
for link_ptr in added_link_ptrs:
|
||||
# Add Link PTR->REF
|
||||
# Create Pointer -> Reference Entry
|
||||
## -> This allows us to efficiently access the link by-pointer.
|
||||
## -> Doing so otherwise requires a full search.
|
||||
## -> **If link is deleted w/o report, access will cause crash**.
|
||||
new_link = all_link_ptrs_as_links[link_ptr]
|
||||
self.link_ptrs_as_links[link_ptr] = new_link
|
||||
|
||||
|
@ -183,34 +241,69 @@ class NodeLinkCache:
|
|||
to_socket = new_link.to_socket
|
||||
to_socket_ptr = to_socket.as_pointer()
|
||||
|
||||
# Add Socket PTR, PTR -> REF
|
||||
# Add Socket Information
|
||||
for socket_ptr, bl_socket in zip( # noqa: B905
|
||||
[from_socket_ptr, to_socket_ptr],
|
||||
[from_socket, to_socket],
|
||||
):
|
||||
# Increment RefCount of Socket PTR
|
||||
# RefCount > 0: Increment RefCount of Socket PTR
|
||||
## This happens if another link also uses the same socket.
|
||||
## 1. An output socket links to several inputs.
|
||||
## 2. A multi-input socket links from several inputs.
|
||||
if socket_ptr in self.socket_ptr_refcount:
|
||||
self.socket_ptr_refcount[socket_ptr] += 1
|
||||
|
||||
# RefCount == 0: Create Socket Pointer w/Reference
|
||||
## -> Also initialize the refcount for the socket pointer.
|
||||
else:
|
||||
## RefCount == 0: Add PTR, PTR -> REF
|
||||
self.socket_ptrs.add(socket_ptr)
|
||||
self.socket_ptrs_as_sockets[socket_ptr] = bl_socket
|
||||
self.socket_ptr_refcount[socket_ptr] = 1
|
||||
|
||||
# Add Link PTR -> Socket PTR
|
||||
# Add Entry from Link Pointer -> Socket Pointer
|
||||
self.link_ptrs_as_from_socket_ptrs[link_ptr] = from_socket_ptr
|
||||
self.link_ptrs_as_to_socket_ptrs[link_ptr] = to_socket_ptr
|
||||
|
||||
return {'added': added_link_ptrs, 'removed': removed_link_ptrs}
|
||||
|
||||
def update_validity(self) -> DeltaNodeLinkCache:
|
||||
"""Query all cached links to determine whether they are valid."""
|
||||
self.link_ptrs_invalid = {
|
||||
link_ptr for link_ptr, link in self.link_ptrs_as_links if not link.is_valid
|
||||
}
|
||||
|
||||
def report_validity(self, link_ptr: MemAddr, validity: bool) -> None:
|
||||
"""Report a link as invalid."""
|
||||
if validity and link_ptr in self.link_ptrs_invalid:
|
||||
self.link_ptrs_invalid.remove(link_ptr)
|
||||
elif not validity and link_ptr not in self.link_ptrs_invalid:
|
||||
self.link_ptrs_invalid.add(link_ptr)
|
||||
|
||||
def set_validities(self) -> None:
|
||||
"""Set the validity of links in the node tree according to the internal cache.
|
||||
|
||||
Validity doesn't need to be removed, as update() automatically cleans up by default.
|
||||
"""
|
||||
for link in [
|
||||
link
|
||||
for link_ptr, link in self.link_ptrs_as_links.items()
|
||||
if link_ptr in self.link_ptrs_invalid
|
||||
]:
|
||||
if link.is_valid:
|
||||
link.is_valid = False
|
||||
|
||||
|
||||
####################
|
||||
# - Node Tree Definition
|
||||
####################
|
||||
class MaxwellSimTree(bpy.types.NodeTree):
|
||||
"""Node tree containing a node-based program for design and analysis of Maxwell PDE simulations.
|
||||
|
||||
Attributes:
|
||||
is_active: Whether the node tree should be considered to be in a usable state, capable of updating Blender data.
|
||||
In general, only one `MaxwellSimTree` should be active at a time.
|
||||
"""
|
||||
|
||||
bl_idname = ct.TreeType.MaxwellSim.value
|
||||
bl_label = 'Maxwell Sim Editor'
|
||||
bl_icon = ct.Icon.SimNodeEditor
|
||||
|
@ -219,63 +312,6 @@ class MaxwellSimTree(bpy.types.NodeTree):
|
|||
default=True,
|
||||
)
|
||||
|
||||
####################
|
||||
# - Lock Methods
|
||||
####################
|
||||
def unlock_all(self) -> None:
|
||||
"""Unlock all nodes in the node tree, making them editable."""
|
||||
log.info('Unlocking All Nodes in NodeTree "%s"', self.bl_label)
|
||||
for node in self.nodes:
|
||||
if node.type in ['REROUTE', 'FRAME']:
|
||||
continue
|
||||
node.locked = False
|
||||
for bl_socket in [*node.inputs, *node.outputs]:
|
||||
bl_socket.locked = False
|
||||
|
||||
@contextlib.contextmanager
|
||||
def replot(self) -> None:
|
||||
self.is_currently_replotting = True
|
||||
self.something_plotted = False
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.is_currently_replotting = False
|
||||
if not self.something_plotted:
|
||||
ManagedBLImage.hide_preview()
|
||||
|
||||
def report_show_plot(self, node: bpy.types.Node) -> None:
|
||||
if hasattr(self, 'is_currently_replotting') and self.is_currently_replotting:
|
||||
self.something_plotted = True
|
||||
|
||||
@contextlib.contextmanager
|
||||
def repreview_all(self) -> None:
|
||||
all_nodes_with_preview_active = {
|
||||
node.instance_id: node
|
||||
for node in self.nodes
|
||||
if node.type not in ['REROUTE', 'FRAME'] and node.preview_active
|
||||
}
|
||||
self.is_currently_repreviewing = True
|
||||
self.newly_previewed_nodes = {}
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.is_currently_repreviewing = False
|
||||
for dangling_previewed_node in [
|
||||
node
|
||||
for node_instance_id, node in all_nodes_with_preview_active.items()
|
||||
if node_instance_id not in self.newly_previewed_nodes
|
||||
]:
|
||||
dangling_previewed_node.preview_active = False
|
||||
|
||||
def report_show_preview(self, node: bpy.types.Node) -> None:
|
||||
if (
|
||||
hasattr(self, 'is_currently_repreviewing')
|
||||
and self.is_currently_repreviewing
|
||||
):
|
||||
self.newly_previewed_nodes[node.instance_id] = node
|
||||
|
||||
####################
|
||||
# - Init Methods
|
||||
####################
|
||||
|
@ -290,7 +326,54 @@ class MaxwellSimTree(bpy.types.NodeTree):
|
|||
self.node_link_cache = NodeLinkCache(self)
|
||||
|
||||
####################
|
||||
# - Update Methods
|
||||
# - Lock Methods
|
||||
####################
|
||||
def unlock_all(self) -> None:
|
||||
"""Unlock all nodes in the node tree, making them editable.
|
||||
|
||||
Notes:
|
||||
All `MaxwellSimNode`s have a `.locked` attribute, which prevents the entire UI from being modified.
|
||||
|
||||
This method simply sets the `locked` attribute to `False` on all nodes.
|
||||
"""
|
||||
log.info('Unlocking All Nodes in NodeTree "%s"', self.bl_label)
|
||||
for node in self.nodes:
|
||||
if node.type in ['REROUTE', 'FRAME']:
|
||||
continue
|
||||
|
||||
# Unlock Node
|
||||
if node.locked:
|
||||
node.locked = False
|
||||
|
||||
# Unlock Node Sockets
|
||||
for bl_socket in [*node.inputs, *node.outputs]:
|
||||
if bl_socket.locked:
|
||||
bl_socket.locked = False
|
||||
|
||||
####################
|
||||
# - Link Update Methods
|
||||
####################
|
||||
def report_link_validity(self, link: bpy.types.NodeLink, validity: bool) -> None:
|
||||
"""Report that a particular `NodeLink` should be considered to be either valid or invalid.
|
||||
|
||||
The `NodeLink.is_valid` attribute is generally (and automatically) used to indicate the detection of cycles in the node tree.
|
||||
However, visually, it causes a very clear "error red" highlight to appear on the node link, which can extremely useful when determining the reasons behind unexpected outout.
|
||||
|
||||
Notes:
|
||||
Run by `MaxwellSimSocket` when a link should be shown to be "invalid".
|
||||
"""
|
||||
## TODO: Doesn't quite work.
|
||||
# log.debug(
|
||||
# 'Reported Link Validity %s (is_valid=%s, from_socket=%s, to_socket=%s)',
|
||||
# validity,
|
||||
# link.is_valid,
|
||||
# link.from_socket,
|
||||
# link.to_socket,
|
||||
# )
|
||||
# self.node_link_cache.report_validity(link.as_pointer(), validity)
|
||||
|
||||
####################
|
||||
# - Node Update Methods
|
||||
####################
|
||||
def on_node_removed(self, node: bpy.types.Node):
|
||||
"""Run by `MaxwellSimNode.free()` when a node is being removed.
|
||||
|
@ -327,32 +410,36 @@ class MaxwellSimTree(bpy.types.NodeTree):
|
|||
self.node_link_cache.remove_link(link_ptr)
|
||||
self.node_link_cache.remove_sockets_by_link_ptr(link_ptr)
|
||||
|
||||
def update(self) -> None:
|
||||
def update(self) -> None: # noqa: PLR0912, C901
|
||||
"""Monitors all changes to the node tree, potentially responding with appropriate callbacks.
|
||||
|
||||
Notes:
|
||||
- Run by Blender when "anything" changes in the node tree.
|
||||
- Responds to node link changes with callbacks, with the help of a performant node link cache.
|
||||
"""
|
||||
# Perform Initial Load
|
||||
## -> Presume update() is run before the first link is altered.
|
||||
## -> Else, the first link of the session will not update caches.
|
||||
## -> We still remain slightly unsure of the exact semantics.
|
||||
## -> Therefore, self.on_load() is also called as a load_post handler.
|
||||
if not hasattr(self, 'node_link_cache'):
|
||||
self.on_load()
|
||||
return
|
||||
|
||||
# Register Validity Updater
|
||||
## -> They will be run after the update() method.
|
||||
## -> Between update() and set_validities, all is_valid=True are cleared.
|
||||
## -> Therefore, 'set_validities' only needs to set all is_valid=False.
|
||||
bpy.app.timers.register(self.node_link_cache.set_validities)
|
||||
|
||||
# Ignore Updates
|
||||
## -> Certain corrective processes require suppressing the next update.
|
||||
## -> Otherwise, link corrections may trigger some nasty recursions.
|
||||
if not hasattr(self, 'ignore_update'):
|
||||
self.ignore_update = False
|
||||
|
||||
if not hasattr(self, 'node_link_cache'):
|
||||
self.on_load()
|
||||
## We presume update() is run before the first link is altered.
|
||||
## - Else, the first link of the session will not update caches.
|
||||
## - We remain slightly unsure of the semantics.
|
||||
## - Therefore, self.on_load() is also called as a load_post handler.
|
||||
return
|
||||
|
||||
# Ignore Update
|
||||
## Manually set to implement link corrections w/o recursion.
|
||||
if self.ignore_update:
|
||||
return
|
||||
|
||||
# Compute Changes to Node Links
|
||||
# Regenerate NodeLinkCache
|
||||
delta_links = self.node_link_cache.regenerate()
|
||||
|
||||
link_corrections = {
|
||||
'to_remove': [],
|
||||
'to_add': [],
|
||||
|
|
|
@ -358,6 +358,11 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
## -> Those string labels explain the integer as ex. Ex, Ey, Hy.
|
||||
idx_labels = valid_monitor_attrs(sim_data, monitor_name)
|
||||
|
||||
# Extract Info
|
||||
## -> We only need the output symbol.
|
||||
## -> All labelled outputs have the same output SimSymbol.
|
||||
info = extract_info(monitor_data, idx_labels[0])
|
||||
|
||||
# Generate FuncFlow Per Index Label
|
||||
## -> We extract each XArray as an attribute of monitor_data.
|
||||
## -> We then bind its values into a unique func_flow.
|
||||
|
@ -377,7 +382,8 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
## -> Then, 'compose_within' lets us stack them along axis=0.
|
||||
## -> The "new" axis=0 is int-indexed axis w/idx_labels labels!
|
||||
return functools.reduce(lambda a, b: a | b, func_flows).compose_within(
|
||||
enclosing_func=lambda data: jnp.stack(data, axis=0)
|
||||
lambda data: jnp.stack(data, axis=0),
|
||||
func_output=info.output,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
|
|
@ -65,12 +65,12 @@ class FilterOperation(enum.StrEnum):
|
|||
FO = FilterOperation
|
||||
return {
|
||||
# Slice
|
||||
FO.Slice: '=a[i:j]',
|
||||
FO.SliceIdx: '≈a[v₁:v₂]',
|
||||
FO.Slice: '≈a[v₁:v₂]',
|
||||
FO.SliceIdx: '=a[i:j]',
|
||||
# Pin
|
||||
FO.PinLen1: 'pinₐ',
|
||||
FO.Pin: 'pinₐ ≈v',
|
||||
FO.PinIdx: 'pinₐ =i',
|
||||
FO.PinLen1: 'a[0] → a',
|
||||
FO.Pin: 'a[v] ⇝ a',
|
||||
FO.PinIdx: 'a[i] → a',
|
||||
# Reinterpret
|
||||
FO.Swap: 'a₁ ↔ a₂',
|
||||
}[value]
|
||||
|
@ -517,6 +517,7 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
return lazy_func.compose_within(
|
||||
operation.jax_func(axis_0, axis_1, slice_tuple=slice_tuple),
|
||||
enclosing_func_args=operation.func_args,
|
||||
enclosing_func_output=info.output,
|
||||
supports_jax=True,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
|
|
@ -547,22 +547,30 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
# Loaded
|
||||
kind=ct.FlowKind.Func,
|
||||
props={'operation'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={
|
||||
'Expr': ct.FlowKind.Func,
|
||||
},
|
||||
output_sockets={'Expr'},
|
||||
output_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||
)
|
||||
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
|
||||
operation = props['operation']
|
||||
def compute_func(
|
||||
self, props, input_sockets, output_sockets
|
||||
) -> ct.FuncFlow | ct.FlowSignal:
|
||||
expr = input_sockets['Expr']
|
||||
output_info = output_sockets['Expr']
|
||||
|
||||
has_expr = not ct.FlowSignal.check(expr)
|
||||
has_output_info = not ct.FlowSignal.check(output_info)
|
||||
|
||||
operation = props['operation']
|
||||
if has_expr and operation is not None:
|
||||
return expr.compose_within(
|
||||
operation.jax_func,
|
||||
enclosing_func_output=output_info.output,
|
||||
supports_jax=True,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
|
|
@ -146,7 +146,6 @@ class BinaryOperation(enum.StrEnum):
|
|||
outl = info_l.output
|
||||
outr = info_r.output
|
||||
match (outl.shape_len, outr.shape_len):
|
||||
# match (ol.shape_len, info_r.output.shape_len):
|
||||
# Number | *
|
||||
## Number | Number
|
||||
case (0, 0):
|
||||
|
@ -154,15 +153,25 @@ class BinaryOperation(enum.StrEnum):
|
|||
BO.Add,
|
||||
BO.Sub,
|
||||
BO.Mul,
|
||||
BO.Div,
|
||||
BO.Pow,
|
||||
]
|
||||
|
||||
# Check Non-Zero Right Hand Side
|
||||
## -> Obviously, we can't ever divide by zero.
|
||||
## -> Sympy's assumptions system must always guarantee rhs != 0.
|
||||
## -> If it can't, then we simply don't expose division.
|
||||
## -> The is_zero assumption must be provided elsewhere.
|
||||
## -> NOTE: This may prevent some valid uses of division.
|
||||
## -> Watch out for "division is missing" bugs.
|
||||
if info_r.output.is_nonzero:
|
||||
ops.append(BO.Div)
|
||||
|
||||
if (
|
||||
info_l.output.physical_type == spux.PhysicalType.Length
|
||||
and info_l.output.unit == info_r.output.unit
|
||||
):
|
||||
ops += [BO.Atan2]
|
||||
return ops
|
||||
|
||||
return [*ops, BO.Pow]
|
||||
|
||||
## Number | Vector
|
||||
case (0, 1):
|
||||
|
@ -336,7 +345,13 @@ class BinaryOperation(enum.StrEnum):
|
|||
# - InfoFlow Transform
|
||||
####################
|
||||
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
|
||||
"""Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression."""
|
||||
"""Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression.
|
||||
|
||||
Warnings:
|
||||
`self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r).
|
||||
|
||||
If not, bad things will happen.
|
||||
"""
|
||||
return info_l.operate_output(
|
||||
info_r,
|
||||
lambda a, b: self.sp_func([a, b]),
|
||||
|
@ -479,29 +494,35 @@ class OperateMathNode(base.MaxwellSimNode):
|
|||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Func,
|
||||
# Loaded
|
||||
props={'operation'},
|
||||
input_sockets={'Expr L', 'Expr R'},
|
||||
input_socket_kinds={
|
||||
'Expr L': ct.FlowKind.Func,
|
||||
'Expr R': ct.FlowKind.Func,
|
||||
},
|
||||
output_sockets={'Expr'},
|
||||
output_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||
)
|
||||
def compose_func(self, props: dict, input_sockets: dict):
|
||||
def compute_func(self, props, input_sockets, output_sockets):
|
||||
operation = props['operation']
|
||||
if operation is None:
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
expr_l = input_sockets['Expr L']
|
||||
expr_r = input_sockets['Expr R']
|
||||
output_info = output_sockets['Expr']
|
||||
|
||||
has_expr_l = not ct.FlowSignal.check(expr_l)
|
||||
has_expr_r = not ct.FlowSignal.check(expr_r)
|
||||
has_output_info = not ct.FlowSignal.check(output_info)
|
||||
|
||||
# Compute Jax Function
|
||||
## -> The operation enum directly provides the appropriate function.
|
||||
if has_expr_l and has_expr_r:
|
||||
if has_expr_l and has_expr_r and has_output_info:
|
||||
return (expr_l | expr_r).compose_within(
|
||||
enclosing_func=operation.jax_func,
|
||||
operation.jax_func,
|
||||
enclosing_func_output=output_info.output,
|
||||
supports_jax=True,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
@ -520,6 +541,8 @@ class OperateMathNode(base.MaxwellSimNode):
|
|||
},
|
||||
)
|
||||
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
|
||||
BO = BinaryOperation
|
||||
|
||||
operation = props['operation']
|
||||
info_l = input_sockets['Expr L']
|
||||
info_r = input_sockets['Expr R']
|
||||
|
@ -533,7 +556,7 @@ class OperateMathNode(base.MaxwellSimNode):
|
|||
has_info_l
|
||||
and has_info_r
|
||||
and operation is not None
|
||||
and operation in BinaryOperation.by_infos(info_l, info_r)
|
||||
and operation in BO.by_infos(info_l, info_r)
|
||||
):
|
||||
return operation.transform_infos(info_l, info_r)
|
||||
|
||||
|
|
|
@ -606,27 +606,32 @@ class TransformMathNode(base.MaxwellSimNode):
|
|||
input_socket_kinds={
|
||||
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info},
|
||||
},
|
||||
output_sockets={'Expr'},
|
||||
output_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||
)
|
||||
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
|
||||
def compute_func(
|
||||
self, props, input_sockets, output_sockets
|
||||
) -> ct.FuncFlow | ct.FlowSignal:
|
||||
"""Transform the input `InfoFlow` depending on the transform operation."""
|
||||
TO = TransformOperation
|
||||
operation = props['operation']
|
||||
|
||||
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
|
||||
info = input_sockets['Expr'][ct.FlowKind.Info]
|
||||
output_info = output_sockets['Expr']
|
||||
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
has_lazy_func = not ct.FlowSignal.check(lazy_func)
|
||||
has_output_info = not ct.FlowSignal.check(output_info)
|
||||
|
||||
if operation is not None and has_lazy_func and has_info:
|
||||
# Retrieve Properties
|
||||
operation = props['operation']
|
||||
if operation is not None and has_lazy_func and has_info and has_output_info:
|
||||
dim = props['dim']
|
||||
|
||||
# Match Pattern by Operation
|
||||
match operation:
|
||||
case TO.FreqToVacWL | TO.VacWLToFreq | TO.FT1D | TO.InvFT1D:
|
||||
if dim is not None and info.has_idx_discrete(dim):
|
||||
return lazy_func.compose_within(
|
||||
operation.jax_func(axis=info.dim_axis(dim)),
|
||||
enclosing_func_output=output_info.output,
|
||||
supports_jax=True,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
@ -634,6 +639,7 @@ class TransformMathNode(base.MaxwellSimNode):
|
|||
case _:
|
||||
return lazy_func.compose_within(
|
||||
operation.jax_func(),
|
||||
enclosing_func_output=output_info.output,
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
|
|
|
@ -406,7 +406,7 @@ class VizNode(base.MaxwellSimNode):
|
|||
},
|
||||
all_loose_input_sockets=True,
|
||||
)
|
||||
def compute_dummy_value(self, props, input_sockets, loose_input_sockets):
|
||||
def compute_previews(self, props, input_sockets, loose_input_sockets):
|
||||
"""Needed for the plot to regenerate in the viewer."""
|
||||
return ct.PreviewsFlow(bl_image_name=props['sim_node_name'])
|
||||
|
||||
|
@ -433,7 +433,7 @@ class VizNode(base.MaxwellSimNode):
|
|||
def on_show_plot(
|
||||
self, managed_objs, props, input_sockets, loose_input_sockets
|
||||
) -> None:
|
||||
log.critical('Show Plot (too many times)')
|
||||
log.debug('Show Plot')
|
||||
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
|
||||
info = input_sockets['Expr'][ct.FlowKind.Info]
|
||||
params = input_sockets['Expr'][ct.FlowKind.Params]
|
||||
|
@ -456,6 +456,7 @@ class VizNode(base.MaxwellSimNode):
|
|||
sym: loose_input_sockets[sym.name] for sym in params.sorted_symbols
|
||||
},
|
||||
)
|
||||
## TODO: CACHE entries that don't change, PLEASEEE
|
||||
|
||||
# Match Viz Type & Perform Visualization
|
||||
## -> Viz Target determines how to plot.
|
||||
|
|
|
@ -207,12 +207,12 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
stop_propagation=True,
|
||||
)
|
||||
def _on_sim_node_name_changed(self, props):
|
||||
log.debug(
|
||||
'Changed Sim Node Name of a "%s" to "%s" (self=%s)',
|
||||
self.bl_idname,
|
||||
props['sim_node_name'],
|
||||
str(self),
|
||||
)
|
||||
# log.debug(
|
||||
# 'Changed Sim Node Name of a "%s" to "%s" (self=%s)',
|
||||
# self.bl_idname,
|
||||
# props['sim_node_name'],
|
||||
# str(self),
|
||||
# )
|
||||
|
||||
# (Re)Construct Managed Objects
|
||||
## -> Due to 'prev_name', the new MObjs will be renamed on construction
|
||||
|
@ -360,27 +360,48 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
####################
|
||||
# - Socket Management
|
||||
####################
|
||||
## TODO: Check for namespace collisions in sockets to prevent silent errors
|
||||
def _prune_inactive_sockets(self):
|
||||
"""Remove all "inactive" sockets from the node.
|
||||
"""Remove all inactive sockets from the node, while only updating sockets that can be non-destructively updated.
|
||||
|
||||
A socket is considered "inactive" when it shouldn't be defined (per `self.active_socket_defs), but is present nonetheless.
|
||||
The first step is easy: We determine, by-name, which sockets should no longer be defined, then remove them correctly.
|
||||
|
||||
The second step is harder: When new sockets have overlapping names, should they be removed, or should they merely have some properties updated?
|
||||
Removing and re-adding the same socket is an accurate, generally robust approach, but it comes with a big caveat: **Existing node links will be cut**, even when it might semantically make sense to simply alter the socket's properties, keeping the links.
|
||||
|
||||
Different `bl_socket.socket_type`s can never be updated - they must be removed.
|
||||
Otherwise, `SocketDef.compare(bl_socket)` allows us to granularly determine whether a particular `bl_socket` has changed with respect to the desired specification.
|
||||
When the comparison is `False`, we can carefully utilize `SocketDef.init()` to re-initialize the socket, guaranteeing that the altered socket is up to the new specification.
|
||||
"""
|
||||
node_tree = self.id_data
|
||||
for direc in ['input', 'output']:
|
||||
all_bl_sockets = self._bl_sockets(direc)
|
||||
active_bl_socket_defs = self.active_socket_defs(direc)
|
||||
bl_sockets = self._bl_sockets(direc)
|
||||
active_socket_defs = self.active_socket_defs(direc)
|
||||
|
||||
# Determine Sockets to Remove
|
||||
## -> Name: If the existing socket name isn't "active".
|
||||
## -> Type: If the existing socket_type != "active" SocketDef.
|
||||
bl_sockets_to_remove = [
|
||||
bl_socket
|
||||
for socket_name, bl_socket in all_bl_sockets.items()
|
||||
if socket_name not in active_bl_socket_defs
|
||||
or socket_name
|
||||
in (
|
||||
self.loose_input_sockets
|
||||
if direc == 'input'
|
||||
else self.loose_output_sockets
|
||||
for socket_name, bl_socket in bl_sockets.items()
|
||||
if (
|
||||
socket_name not in active_socket_defs
|
||||
or bl_socket.socket_type
|
||||
is not active_socket_defs[socket_name].socket_type
|
||||
)
|
||||
]
|
||||
|
||||
# Determine Sockets to Update
|
||||
## -> Name: If the existing socket name is "active".
|
||||
## -> Type: If the existing socket_type == "active" SocketDef.
|
||||
## -> Compare: If the existing socket differs from the SocketDef.
|
||||
bl_sockets_to_update = [
|
||||
bl_socket
|
||||
for socket_name, bl_socket in bl_sockets.items()
|
||||
if (
|
||||
socket_name in active_socket_defs
|
||||
and bl_socket.socket_type
|
||||
is active_socket_defs[socket_name].socket_type
|
||||
and not active_socket_defs[socket_name].compare(bl_socket)
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -392,24 +413,25 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
## -> The NodeLinkCache needs to be adjusted manually.
|
||||
node_tree.on_node_socket_removed(bl_socket)
|
||||
|
||||
# 2. Invalidate the input socket cache across all kinds.
|
||||
# 2. Perform the removal using Blender's API.
|
||||
## -> Actually removes the socket.
|
||||
bl_sockets.remove(bl_socket)
|
||||
|
||||
# 3. Invalidate the input socket cache across all kinds.
|
||||
## -> Prevents phantom values from remaining available.
|
||||
## -> Done after socket removal to protect from race condition.
|
||||
self._compute_input.invalidate(
|
||||
input_socket_name=bl_socket_name,
|
||||
kind=...,
|
||||
unit_system=...,
|
||||
)
|
||||
|
||||
# 3. Perform the removal using Blender's API.
|
||||
## -> Actually removes the socket.
|
||||
all_bl_sockets.remove(bl_socket)
|
||||
|
||||
if direc == 'input':
|
||||
# 4. Run all trigger-only `on_value_changed` callbacks.
|
||||
## -> Runs any event methods that relied on the socket.
|
||||
## -> Only methods that don't **require** the socket.
|
||||
## Trigger-Only: If method loads no socket data, it runs.
|
||||
## `optional`: If method optional-loads socket, it runs.
|
||||
## Only Trigger: If method loads no socket data, it runs.
|
||||
## Optional: If method optional-loads socket, it runs.
|
||||
triggered_event_methods = [
|
||||
event_method
|
||||
for event_method in self.filtered_event_methods_by_event(
|
||||
|
@ -419,32 +441,52 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
not in event_method.callback_info.must_load_sockets
|
||||
]
|
||||
for event_method in triggered_event_methods:
|
||||
log.critical(
|
||||
'%s: Running %s',
|
||||
self.sim_node_name,
|
||||
str(event_method),
|
||||
)
|
||||
event_method(self)
|
||||
|
||||
# Update Sockets
|
||||
for bl_socket in bl_sockets_to_update:
|
||||
bl_socket_name = bl_socket.name
|
||||
socket_def = active_socket_defs[bl_socket_name]
|
||||
|
||||
# 1. Pretend to Initialize for the First Time
|
||||
## -> NOTE: The socket's caches will be completely regenerated.
|
||||
## -> NOTE: A full FlowKind update will occur, but only one.
|
||||
bl_socket.is_initializing = True
|
||||
socket_def.preinit(bl_socket)
|
||||
socket_def.init(bl_socket)
|
||||
socket_def.postinit(bl_socket)
|
||||
|
||||
# 2. Re-Test Socket Capabilities
|
||||
## -> Factors influencing CapabilitiesFlow may have changed.
|
||||
## -> Therefore, we must re-test all link capabilities.
|
||||
bl_socket.remove_invalidated_links()
|
||||
|
||||
# 3. Invalidate the input socket cache across all kinds.
|
||||
## -> Prevents phantom values from remaining available.
|
||||
self._compute_input.invalidate(
|
||||
input_socket_name=bl_socket_name,
|
||||
kind=...,
|
||||
unit_system=...,
|
||||
)
|
||||
|
||||
def _add_new_active_sockets(self):
|
||||
"""Add and initialize all "active" sockets that aren't on the node.
|
||||
|
||||
Existing sockets within the given direction are not re-created.
|
||||
"""
|
||||
for direc in ['input', 'output']:
|
||||
all_bl_sockets = self._bl_sockets(direc)
|
||||
active_bl_socket_defs = self.active_socket_defs(direc)
|
||||
bl_sockets = self._bl_sockets(direc)
|
||||
active_socket_defs = self.active_socket_defs(direc)
|
||||
|
||||
# Define BL Sockets
|
||||
created_sockets = {}
|
||||
for socket_name, socket_def in active_bl_socket_defs.items():
|
||||
for socket_name, socket_def in active_socket_defs.items():
|
||||
# Skip Existing Sockets
|
||||
if socket_name in all_bl_sockets:
|
||||
if socket_name in bl_sockets:
|
||||
continue
|
||||
|
||||
# Create BL Socket from Socket
|
||||
## Set 'display_shape' from 'socket_shape'
|
||||
all_bl_sockets.new(
|
||||
bl_sockets.new(
|
||||
str(socket_def.socket_type.value),
|
||||
socket_name,
|
||||
)
|
||||
|
@ -454,9 +496,9 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
|
||||
# Initialize Just-Created BL Sockets
|
||||
for bl_socket_name, socket_def in created_sockets.items():
|
||||
socket_def.preinit(all_bl_sockets[bl_socket_name])
|
||||
socket_def.init(all_bl_sockets[bl_socket_name])
|
||||
socket_def.postinit(all_bl_sockets[bl_socket_name])
|
||||
socket_def.preinit(bl_sockets[bl_socket_name])
|
||||
socket_def.init(bl_sockets[bl_socket_name])
|
||||
socket_def.postinit(bl_sockets[bl_socket_name])
|
||||
|
||||
# Invalidate Cached NoFlows
|
||||
self._compute_input.invalidate(
|
||||
|
@ -637,9 +679,10 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
lambda a, b: a | b,
|
||||
[
|
||||
self._compute_input(
|
||||
socket, kind=ct.FlowKind.Previews, unit_system=None
|
||||
socket_name,
|
||||
kind=ct.FlowKind.Previews,
|
||||
)
|
||||
for socket in [bl_socket.name for bl_socket in self.inputs]
|
||||
for socket_name in [bl_socket.name for bl_socket in self.inputs]
|
||||
],
|
||||
ct.PreviewsFlow(),
|
||||
)
|
||||
|
@ -897,9 +940,19 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
)
|
||||
altered_socket_kinds[dep_out_sckname].add(dep_out_kind)
|
||||
|
||||
# Clear Output Socket Cache(s)
|
||||
## -> We aggregate it manually, so it needs a special invl.
|
||||
## -> See self.compute_output()
|
||||
if socket_kinds is not None and ct.FlowKind.Previews in socket_kinds:
|
||||
for out_sckname in self.outputs.keys(): # noqa: SIM118
|
||||
self.compute_output.invalidate(
|
||||
output_socket_name=out_sckname,
|
||||
kind=ct.FlowKind.Previews,
|
||||
)
|
||||
altered_socket_kinds[out_sckname].add(ct.FlowKind.Previews)
|
||||
|
||||
# Run Triggered Event Methods
|
||||
## -> A triggered event method may request to stop propagation.
|
||||
## -> A triggered event method may request to stop propagation.
|
||||
stop_propagation = False
|
||||
triggered_event_methods = self.filtered_event_methods_by_event(
|
||||
event, (socket_name, prop_names, None)
|
||||
|
|
|
@ -266,7 +266,7 @@ def event_decorator( # noqa: PLR0913
|
|||
)
|
||||
|
||||
# Loose Sockets
|
||||
## Compute All Loose Input Sockets
|
||||
## -> Determined by the active_kind of each loose input socket.
|
||||
method_kw_args |= (
|
||||
{
|
||||
'loose_input_sockets': {
|
||||
|
|
|
@ -29,6 +29,8 @@ from ... import base, events
|
|||
|
||||
|
||||
class ScientificConstantNode(base.MaxwellSimNode):
|
||||
"""A well-known constant usable as itself, or as a symbol."""
|
||||
|
||||
node_type = ct.NodeType.ScientificConstant
|
||||
bl_label = 'Scientific Constant'
|
||||
|
||||
|
@ -88,6 +90,11 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
|||
####################
|
||||
# - UI
|
||||
####################
|
||||
def draw_label(self):
|
||||
if self.sci_constant_str:
|
||||
return f'Const: {self.sci_constant_str}'
|
||||
return self.bl_label
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
||||
col.prop(self, self.blfields['sci_constant_str'], text='')
|
||||
|
||||
|
@ -156,6 +163,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
|||
props={'sci_constant', 'sci_constant_sym'},
|
||||
)
|
||||
def compute_lazy_func(self, props) -> typ.Any:
|
||||
"""Simple `FuncFlow` that computes the symbol value, with output units tracked correctly."""
|
||||
sci_constant = props['sci_constant']
|
||||
sci_constant_sym = props['sci_constant_sym']
|
||||
|
||||
|
@ -165,6 +173,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
|||
[sci_constant_sym.sp_symbol], sci_constant_sym.sp_symbol, 'jax'
|
||||
),
|
||||
func_args=[sci_constant_sym],
|
||||
func_output=sci_constant_sym,
|
||||
supports_jax=True,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
@ -175,6 +184,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
|||
props={'sci_constant_sym'},
|
||||
)
|
||||
def compute_info(self, props: dict) -> typ.Any:
|
||||
"""Simple `FuncFlow` that computes the symbol value, with output units tracked correctly."""
|
||||
sci_constant_sym = props['sci_constant_sym']
|
||||
|
||||
if sci_constant_sym is not None:
|
||||
|
@ -193,8 +203,12 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
|||
if sci_constant is not None and sci_constant_sym is not None:
|
||||
return ct.ParamsFlow(
|
||||
arg_targets=[sci_constant_sym],
|
||||
func_args=[sci_constant],
|
||||
is_differentiable=True,
|
||||
func_args=[sci_constant_sym.sp_symbol],
|
||||
symbols={sci_constant_sym},
|
||||
).realize_partial(
|
||||
{
|
||||
sci_constant_sym: sci_constant,
|
||||
}
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
|
|
@ -216,10 +216,11 @@ class SymbolConstantNode(base.MaxwellSimNode):
|
|||
props={'symbol'},
|
||||
)
|
||||
def compute_lazy_func(self, props) -> typ.Any:
|
||||
sp_sym = props['symbol'].sp_symbol
|
||||
sym = props['symbol']
|
||||
return ct.FuncFlow(
|
||||
func=sp.lambdify(sp_sym, sp_sym, 'jax'),
|
||||
func_args=[sp_sym],
|
||||
func=sp.lambdify(sym.sp_symbol_matsym, sym.sp_symbol_matsym, 'jax'),
|
||||
func_args=[sym],
|
||||
func_output=sym,
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
|
@ -235,6 +236,7 @@ class SymbolConstantNode(base.MaxwellSimNode):
|
|||
)
|
||||
def compute_info(self, props) -> typ.Any:
|
||||
return ct.InfoFlow(
|
||||
dims={props['symbol']: None},
|
||||
output=props['symbol'],
|
||||
)
|
||||
|
||||
|
@ -251,9 +253,6 @@ class SymbolConstantNode(base.MaxwellSimNode):
|
|||
arg_targets=[sym],
|
||||
func_args=[sym.sp_symbol],
|
||||
symbols={sym},
|
||||
is_differentiable=(
|
||||
sym.mathtype in [spux.MathType.Real, spux.MathType.Complex]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -198,9 +198,10 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
'Expr',
|
||||
kind=ct.FlowKind.Func,
|
||||
# Loaded
|
||||
props={'output_name', 'output_mathtype', 'output_physical_type', 'output_unit'},
|
||||
input_sockets={'File Path'},
|
||||
)
|
||||
def compute_func(self, input_sockets) -> td.Simulation:
|
||||
def compute_func(self, props, input_sockets) -> td.Simulation:
|
||||
"""Declare a lazy, composable function that returns the loaded data.
|
||||
|
||||
Returns:
|
||||
|
@ -209,6 +210,12 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
file_path = input_sockets['File Path']
|
||||
has_file_path = not ct.FlowSignal.check(file_path)
|
||||
|
||||
func_output = sim_symbols.SimSymbol(
|
||||
sym_name=props['output_name'],
|
||||
mathtype=props['output_mathtype'],
|
||||
physical_type=props['output_physical_type'],
|
||||
unit=props['output_unit'],
|
||||
)
|
||||
if has_file_path and file_path is not None:
|
||||
data_file_format = ct.DataFileFormat.from_path(file_path)
|
||||
if data_file_format is not None:
|
||||
|
@ -217,13 +224,18 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
if data_file_format.loader_is_jax_compatible:
|
||||
return ct.FuncFlow(
|
||||
func=lambda: data_file_format.loader(file_path),
|
||||
func_output=func_output,
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
# No Jax Compatibility: Eager Data Loading
|
||||
## -> Load the data now and bind it.
|
||||
data = data_file_format.loader(file_path)
|
||||
return ct.FuncFlow(func=lambda: data, supports_jax=True)
|
||||
return ct.FuncFlow(
|
||||
func=lambda: data,
|
||||
func_output=func_output,
|
||||
supports_jax=True,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
|
|
@ -86,22 +86,45 @@ class ViewerNode(base.MaxwellSimNode):
|
|||
# - Properties: Computed FlowKinds
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
socket_name='Any',
|
||||
# Trigger
|
||||
prop_name='console_print_kind',
|
||||
# Loaded
|
||||
props={'auto_expr', 'console_print_kind'},
|
||||
)
|
||||
def on_input_changed(self) -> None:
|
||||
def on_print_kind_changed(self, props) -> None:
|
||||
self.inputs['Any'].active_kind = props['console_print_kind']
|
||||
|
||||
if props['auto_expr']:
|
||||
setattr(
|
||||
self,
|
||||
'input_' + props['console_print_kind'].property_name,
|
||||
bl_cache.Signal.InvalidateCache,
|
||||
)
|
||||
|
||||
@events.on_value_changed(
|
||||
# Trugger
|
||||
socket_name='Any',
|
||||
# Loaded
|
||||
props={'auto_expr', 'console_print_kind'},
|
||||
)
|
||||
def on_input_changed(self, props) -> None:
|
||||
"""Lightweight invalidator, which invalidates the more specific `cached_bl_property` used to determine when something ex. plot-related has changed.
|
||||
|
||||
Calls `get_flow`, which will be called again when regenerating the `cached_bl_property`s.
|
||||
This **does not** call the flow twice, as `self._compute_input()` will be cached the first time.
|
||||
"""
|
||||
for flow_kind in list(ct.FlowKind):
|
||||
flow = self.get_flow(
|
||||
flow_kind, always_load=flow_kind is ct.FlowKind.Previews
|
||||
)
|
||||
if flow is not None:
|
||||
# Invalidate PreviewsFlow
|
||||
setattr(
|
||||
self,
|
||||
'input_' + flow_kind.property_name,
|
||||
'input_' + ct.FlowKind.Previews.property_name,
|
||||
bl_cache.Signal.InvalidateCache,
|
||||
)
|
||||
|
||||
# Invalidate PreviewsFlow
|
||||
if props['auto_expr']:
|
||||
setattr(
|
||||
self,
|
||||
'input_' + props['console_print_kind'].property_name,
|
||||
bl_cache.Signal.InvalidateCache,
|
||||
)
|
||||
|
||||
|
|
|
@ -14,8 +14,10 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import functools
|
||||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import sympy as sp
|
||||
|
||||
from blender_maxwell.utils import bl_cache
|
||||
|
@ -26,6 +28,8 @@ from .. import base, events
|
|||
|
||||
|
||||
class CombineNode(base.MaxwellSimNode):
|
||||
"""Combine single objects (ex. Source, Monitor, Structure) into a list."""
|
||||
|
||||
node_type = ct.NodeType.Combine
|
||||
bl_label = 'Combine'
|
||||
|
||||
|
@ -33,112 +37,222 @@ class CombineNode(base.MaxwellSimNode):
|
|||
# - Sockets
|
||||
####################
|
||||
input_socket_sets: typ.ClassVar = {
|
||||
'Maxwell Sources': {},
|
||||
'Maxwell Structures': {},
|
||||
'Maxwell Monitors': {},
|
||||
'Sources': {},
|
||||
'Structures': {},
|
||||
'Monitors': {},
|
||||
}
|
||||
output_socket_sets: typ.ClassVar = {
|
||||
'Maxwell Sources': {
|
||||
'Sources': {
|
||||
'Sources': sockets.MaxwellSourceSocketDef(
|
||||
is_list=True,
|
||||
active_kind=ct.FlowKind.Array,
|
||||
),
|
||||
},
|
||||
'Maxwell Structures': {
|
||||
'Structures': {
|
||||
'Structures': sockets.MaxwellStructureSocketDef(
|
||||
is_list=True,
|
||||
active_kind=ct.FlowKind.Array,
|
||||
),
|
||||
},
|
||||
'Maxwell Monitors': {
|
||||
'Monitors': {
|
||||
'Monitors': sockets.MaxwellMonitorSocketDef(
|
||||
is_list=True,
|
||||
active_kind=ct.FlowKind.Array,
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
####################
|
||||
# - Draw
|
||||
# - Properties
|
||||
####################
|
||||
amount: int = bl_cache.BLField(2, abs_min=1, prop_ui=True)
|
||||
concatenate_first: bool = bl_cache.BLField(False)
|
||||
value_or_func: ct.FlowKind = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self._value_or_func(),
|
||||
)
|
||||
|
||||
def _value_or_func(self):
|
||||
return [
|
||||
flow_kind.bl_enum_element(i)
|
||||
for i, flow_kind in enumerate([ct.FlowKind.Value, ct.FlowKind.Func])
|
||||
]
|
||||
|
||||
####################
|
||||
# - Draw
|
||||
####################
|
||||
def draw_props(self, context, layout):
|
||||
layout.prop(self, self.blfields['amount'], text='')
|
||||
def draw_props(self, _, layout: bpy.types.UILayout):
|
||||
layout.prop(self, self.blfields['value_or_func'], text='')
|
||||
|
||||
if self.value_or_func is ct.FlowKind.Value:
|
||||
layout.prop(
|
||||
self,
|
||||
self.blfields['concatenate_first'],
|
||||
text='Concatenate',
|
||||
toggle=True,
|
||||
)
|
||||
|
||||
####################
|
||||
# - Events
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
# Trigger
|
||||
prop_name={'active_socket_set', 'amount'},
|
||||
props={'active_socket_set', 'amount'},
|
||||
any_loose_input_socket=True,
|
||||
prop_name={'active_socket_set', 'concatenate_first', 'value_or_func'},
|
||||
run_on_init=True,
|
||||
# Loaded
|
||||
props={'active_socket_set', 'concatenate_first', 'value_or_func'},
|
||||
)
|
||||
def on_inputs_changed(self, props):
|
||||
if props['active_socket_set'] == 'Maxwell Sources':
|
||||
if (
|
||||
not self.loose_input_sockets
|
||||
or not next(iter(self.loose_input_sockets)).startswith('Source')
|
||||
or len(self.loose_input_sockets) != props['amount']
|
||||
):
|
||||
self.loose_input_sockets = {
|
||||
f'Source #{i}': sockets.MaxwellSourceSocketDef()
|
||||
for i in range(props['amount'])
|
||||
}
|
||||
def on_inputs_changed(self, props) -> None:
|
||||
"""Always create one extra loose input socket."""
|
||||
active_socket_set = props['active_socket_set']
|
||||
|
||||
# Deduce SocketDef
|
||||
## -> Cheat by retrieving the class from the output sockets.
|
||||
SocketDef = self.output_socket_sets[active_socket_set][
|
||||
active_socket_set
|
||||
].__class__
|
||||
|
||||
# Deduce Current "Filled"
|
||||
## -> The first linked socket from the end bounds the "filled" region.
|
||||
## -> The length of that region, plus one, will be the new amount.
|
||||
reverse_linked_idxs = [
|
||||
i
|
||||
for i, bl_socket in enumerate(reversed(self.inputs.values()))
|
||||
if bl_socket.is_linked
|
||||
]
|
||||
current_filled = len(self.inputs) - (
|
||||
reverse_linked_idxs[0] if reverse_linked_idxs else len(self.inputs)
|
||||
)
|
||||
new_amount = current_filled + 1
|
||||
|
||||
# Deduce SocketDef | Current Amount
|
||||
concatenate_first = props['concatenate_first']
|
||||
flow_kind = props['value_or_func']
|
||||
|
||||
elif props['active_socket_set'] == 'Maxwell Structures':
|
||||
if (
|
||||
not self.loose_input_sockets
|
||||
or not next(iter(self.loose_input_sockets)).startswith('Structure')
|
||||
or len(self.loose_input_sockets) != props['amount']
|
||||
):
|
||||
self.loose_input_sockets = {
|
||||
f'Structure #{i}': sockets.MaxwellStructureSocketDef()
|
||||
for i in range(props['amount'])
|
||||
}
|
||||
elif props['active_socket_set'] == 'Maxwell Monitors':
|
||||
if (
|
||||
not self.loose_input_sockets
|
||||
or not next(iter(self.loose_input_sockets)).startswith('Monitor')
|
||||
or len(self.loose_input_sockets) != props['amount']
|
||||
):
|
||||
self.loose_input_sockets = {
|
||||
f'Monitor #{i}': sockets.MaxwellMonitorSocketDef()
|
||||
for i in range(props['amount'])
|
||||
}
|
||||
elif self.loose_input_sockets:
|
||||
self.loose_input_sockets = {}
|
||||
'#0': SocketDef(
|
||||
active_kind=flow_kind
|
||||
if flow_kind is ct.FlowKind.Func or not concatenate_first
|
||||
else ct.FlowKind.Array
|
||||
)
|
||||
} | {f'#{i}': SocketDef(active_kind=flow_kind) for i in range(1, new_amount)}
|
||||
|
||||
####################
|
||||
# - Output Socket Computation
|
||||
# - FlowKind.Array|Func
|
||||
####################
|
||||
def compute_combined(
|
||||
self,
|
||||
loose_input_sockets,
|
||||
input_flow_kind: typ.Literal[ct.FlowKind.Value, ct.FlowKind.Func],
|
||||
output_flow_kind: typ.Literal[ct.FlowKind.Array, ct.FlowKind.Func],
|
||||
) -> list[typ.Any] | ct.FuncFlow | ct.FlowSignal:
|
||||
"""Correctly compute the combined loose input sockets, given a valid combination of input and output `FlowKind`s.
|
||||
|
||||
If there is no output, or the flows aren't compatible, return `FlowPending`.
|
||||
"""
|
||||
match (input_flow_kind, output_flow_kind):
|
||||
case (ct.FlowKind.Value, ct.FlowKind.Array):
|
||||
value_flows = [
|
||||
inp
|
||||
for inp in loose_input_sockets.values()
|
||||
if not ct.FlowSignal.check(inp)
|
||||
]
|
||||
if value_flows:
|
||||
return value_flows
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
case (ct.FlowKind.Func, ct.FlowKind.Func):
|
||||
func_flows = [
|
||||
inp
|
||||
for inp in loose_input_sockets.values()
|
||||
if not ct.FlowSignal.check(inp)
|
||||
]
|
||||
if func_flows:
|
||||
return functools.reduce(
|
||||
lambda a, b: a | b,
|
||||
func_flows,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Output: Sources
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Sources',
|
||||
kind=ct.FlowKind.Array,
|
||||
all_loose_input_sockets=True,
|
||||
props={'amount'},
|
||||
props={'value_or_func'},
|
||||
)
|
||||
def compute_sources_array(
|
||||
self, props, loose_input_sockets
|
||||
) -> list[typ.Any] | ct.FlowSignal:
|
||||
"""Compute sources."""
|
||||
return self.compute_combined(
|
||||
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
|
||||
)
|
||||
def compute_sources(self, loose_input_sockets, props) -> sp.Expr:
|
||||
return [loose_input_sockets[f'Source #{i}'] for i in range(props['amount'])]
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Sources',
|
||||
kind=ct.FlowKind.Func,
|
||||
all_loose_input_sockets=True,
|
||||
props={'value_or_func'},
|
||||
)
|
||||
def compute_sources_func(self, props, loose_input_sockets) -> list[typ.Any]:
|
||||
"""Compute (lazy) sources."""
|
||||
return self.compute_combined(
|
||||
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
|
||||
)
|
||||
|
||||
####################
|
||||
# - Output: Structures
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Structures',
|
||||
kind=ct.FlowKind.Array,
|
||||
all_loose_input_sockets=True,
|
||||
props={'amount'},
|
||||
props={'value_or_func'},
|
||||
)
|
||||
def compute_structures_array(self, props, loose_input_sockets) -> sp.Expr:
|
||||
"""Compute structures."""
|
||||
return self.compute_combined(
|
||||
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
|
||||
)
|
||||
def compute_structures(self, loose_input_sockets, props) -> sp.Expr:
|
||||
return [loose_input_sockets[f'Structure #{i}'] for i in range(props['amount'])]
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Structures',
|
||||
kind=ct.FlowKind.Func,
|
||||
all_loose_input_sockets=True,
|
||||
props={'value_or_func'},
|
||||
)
|
||||
def compute_structures_func(self, props, loose_input_sockets) -> list[typ.Any]:
|
||||
"""Compute (lazy) structures."""
|
||||
return self.compute_combined(
|
||||
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
|
||||
)
|
||||
|
||||
####################
|
||||
# - Output: Monitors
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Monitors',
|
||||
kind=ct.FlowKind.Array,
|
||||
all_loose_input_sockets=True,
|
||||
props={'amount'},
|
||||
props={'value_or_func'},
|
||||
)
|
||||
def compute_monitors_array(self, props, loose_input_sockets) -> sp.Expr:
|
||||
"""Compute monitors."""
|
||||
return self.compute_combined(
|
||||
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
|
||||
)
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Monitors',
|
||||
kind=ct.FlowKind.Func,
|
||||
all_loose_input_sockets=True,
|
||||
props={'value_or_func'},
|
||||
)
|
||||
def compute_monitors_func(self, props, loose_input_sockets) -> list[typ.Any]:
|
||||
"""Compute (lazy) monitors."""
|
||||
return self.compute_combined(
|
||||
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
|
||||
)
|
||||
def compute_monitors(self, loose_input_sockets, props) -> sp.Expr:
|
||||
return [loose_input_sockets[f'Monitor #{i}'] for i in range(props['amount'])]
|
||||
|
||||
|
||||
####################
|
||||
|
|
|
@ -14,17 +14,26 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
"""Implements `FDTDSimNode`."""
|
||||
|
||||
import typing as typ
|
||||
|
||||
import sympy as sp
|
||||
import bpy
|
||||
import tidy3d as td
|
||||
import tidy3d.plugins.adjoint as tdadj
|
||||
|
||||
from blender_maxwell.utils import bl_cache, logger
|
||||
|
||||
from ... import contracts as ct
|
||||
from ... import sockets
|
||||
from .. import base, events
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
||||
class FDTDSimNode(base.MaxwellSimNode):
|
||||
"""Definition of a complete FDTD simulation, including boundary conditions, domain, sources, structures, monitors, and other configuration."""
|
||||
|
||||
node_type = ct.NodeType.FDTDSim
|
||||
bl_label = 'FDTD Simulation'
|
||||
|
||||
|
@ -35,51 +44,255 @@ class FDTDSimNode(base.MaxwellSimNode):
|
|||
'BCs': sockets.MaxwellBoundCondsSocketDef(),
|
||||
'Domain': sockets.MaxwellSimDomainSocketDef(),
|
||||
'Sources': sockets.MaxwellSourceSocketDef(
|
||||
is_list=True,
|
||||
active_kind=ct.FlowKind.Array,
|
||||
),
|
||||
'Structures': sockets.MaxwellStructureSocketDef(
|
||||
is_list=True,
|
||||
active_kind=ct.FlowKind.Array,
|
||||
),
|
||||
'Monitors': sockets.MaxwellMonitorSocketDef(
|
||||
is_list=True,
|
||||
active_kind=ct.FlowKind.Array,
|
||||
),
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Sim': sockets.MaxwellFDTDSimSocketDef(),
|
||||
output_socket_sets: typ.ClassVar = {
|
||||
'Single': {
|
||||
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Value),
|
||||
},
|
||||
'Batch': {
|
||||
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Array),
|
||||
},
|
||||
'Lazy': {
|
||||
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Func),
|
||||
},
|
||||
}
|
||||
|
||||
####################
|
||||
# - Output Socket Computation
|
||||
# - Properties
|
||||
####################
|
||||
differentiable: bool = bl_cache.BLField(False)
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
|
||||
layout.prop(
|
||||
self,
|
||||
self.blfields['differentiable'],
|
||||
text='Differentiable',
|
||||
toggle=True,
|
||||
)
|
||||
|
||||
####################
|
||||
# - Events
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
# Trigger
|
||||
socket_name={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
|
||||
run_on_init=True,
|
||||
# Loaded
|
||||
props={'active_socket_set'},
|
||||
output_sockets={'Sim'},
|
||||
output_socket_kinds={'Sim': ct.FlowKind.Params},
|
||||
)
|
||||
def on_any_changed(self, props, output_sockets) -> None:
|
||||
"""Create loose input sockets."""
|
||||
params = output_sockets['Sim']
|
||||
has_params = not ct.FlowSignal.check(params)
|
||||
|
||||
# Declare Loose Sockets that Realize Symbols
|
||||
## -> This happens if Params contains not-yet-realized symbols.
|
||||
active_socket_set = props['active_socket_set']
|
||||
if active_socket_set in ['Value', 'Batch'] and has_params and params.symbols:
|
||||
if set(self.loose_input_sockets) != {sym.name for sym in params.symbols}:
|
||||
self.loose_input_sockets = {
|
||||
sym.name: sockets.ExprSocketDef(
|
||||
**(
|
||||
expr_info
|
||||
| {
|
||||
'active_kind': ct.FlowKind.Value,
|
||||
'use_value_range_swapper': (
|
||||
active_socket_set == 'Value'
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
for sym, expr_info in params.sym_expr_infos.items()
|
||||
}
|
||||
|
||||
elif self.loose_input_sockets:
|
||||
self.loose_input_sockets = {}
|
||||
|
||||
####################
|
||||
# - FlowKind.Value
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Sim',
|
||||
kind=ct.FlowKind.Value,
|
||||
# Loaded
|
||||
props={'differentiable'},
|
||||
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
|
||||
input_socket_kinds={
|
||||
'Sources': ct.FlowKind.Array,
|
||||
'Structures': ct.FlowKind.Array,
|
||||
'Domain': ct.FlowKind.Value,
|
||||
'BCs': ct.FlowKind.Value,
|
||||
'Monitors': ct.FlowKind.Array,
|
||||
},
|
||||
output_sockets={'Sim'},
|
||||
output_socket_kinds={'Sim': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_fdtd_sim(self, input_sockets: dict) -> sp.Expr:
|
||||
if any(ct.FlowSignal.check(inp) for inp in input_sockets):
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
def compute_fdtd_sim_value(
|
||||
self, props, input_sockets, output_sockets
|
||||
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
|
||||
"""Compute a single FDTD simulation definition, so long as the inputs are neither symbolic or differentiable."""
|
||||
sim_domain = input_sockets['Domain']
|
||||
sources = input_sockets['Sources']
|
||||
structures = input_sockets['Structures']
|
||||
bounds = input_sockets['BCs']
|
||||
monitors = input_sockets['Monitors']
|
||||
output_params = output_sockets['Sim']
|
||||
|
||||
has_sim_domain = not ct.FlowSignal.check(sim_domain)
|
||||
has_sources = not ct.FlowSignal.check(sources)
|
||||
has_structures = not ct.FlowSignal.check(structures)
|
||||
has_bounds = not ct.FlowSignal.check(bounds)
|
||||
has_monitors = not ct.FlowSignal.check(monitors)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
differentiable = props['differentiable']
|
||||
if (
|
||||
has_sim_domain
|
||||
and has_sources
|
||||
and has_structures
|
||||
and has_bounds
|
||||
and has_monitors
|
||||
and has_output_params
|
||||
and not differentiable
|
||||
):
|
||||
return td.Simulation(
|
||||
**sim_domain,
|
||||
structures=structures,
|
||||
sources=sources,
|
||||
monitors=monitors,
|
||||
structures=structures,
|
||||
boundary_spec=bounds,
|
||||
monitors=monitors,
|
||||
)
|
||||
## TODO: Visualize the boundary conditions on top of the sim domain
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Func
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Sim',
|
||||
kind=ct.FlowKind.Func,
|
||||
# Loaded
|
||||
props={'differentiable'},
|
||||
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
|
||||
input_socket_kinds={
|
||||
'Sources': ct.FlowKind.Func,
|
||||
'Structures': ct.FlowKind.Func,
|
||||
'Monitors': ct.FlowKind.Func,
|
||||
},
|
||||
output_sockets={'Sim'},
|
||||
output_socket_kinds={'Sim': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_fdtd_sim_func(
|
||||
self, props, input_sockets, output_sockets
|
||||
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
|
||||
"""Compute a single simulation, given that all inputs are non-symbolic."""
|
||||
sim_domain = input_sockets['Domain']
|
||||
sources = input_sockets['Sources']
|
||||
structures = input_sockets['Structures']
|
||||
bounds = input_sockets['BCs']
|
||||
monitors = input_sockets['Monitors']
|
||||
output_params = output_sockets['Sim']
|
||||
|
||||
has_sim_domain = not ct.FlowSignal.check(sim_domain)
|
||||
has_sources = not ct.FlowSignal.check(sources)
|
||||
has_structures = not ct.FlowSignal.check(structures)
|
||||
has_bounds = not ct.FlowSignal.check(bounds)
|
||||
has_monitors = not ct.FlowSignal.check(monitors)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
if (
|
||||
has_sim_domain
|
||||
and has_sources
|
||||
and has_structures
|
||||
and has_bounds
|
||||
and has_monitors
|
||||
and has_output_params
|
||||
):
|
||||
differentiable = props['differentiable']
|
||||
if differentiable:
|
||||
return (
|
||||
sim_domain | sources | structures | bounds | monitors
|
||||
).compose_within(
|
||||
enclosing_func=lambda els: tdadj.JaxSimulation(
|
||||
**els[0],
|
||||
sources=els[1],
|
||||
structures=els[2]['static'],
|
||||
input_structures=els[2]['differentiable'],
|
||||
boundary_spec=els[3],
|
||||
monitors=els[4]['static'],
|
||||
output_monitors=els[4]['differentiable'],
|
||||
),
|
||||
supports_jax=True,
|
||||
)
|
||||
return (
|
||||
sim_domain | sources | structures | bounds | monitors
|
||||
).compose_within(
|
||||
enclosing_func=lambda els: td.Simulation(
|
||||
**els[0],
|
||||
sources=els[1],
|
||||
structures=els[2],
|
||||
boundary_spec=els[3],
|
||||
monitors=els[4],
|
||||
),
|
||||
supports_jax=False,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Sim',
|
||||
kind=ct.FlowKind.Params,
|
||||
# Loaded
|
||||
props={'differentiable'},
|
||||
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
|
||||
input_socket_kinds={
|
||||
'Sources': ct.FlowKind.Params,
|
||||
'Structures': ct.FlowKind.Params,
|
||||
'Monitors': ct.FlowKind.Params,
|
||||
},
|
||||
)
|
||||
def compute_fdtd_sim_params(
|
||||
self, props, input_sockets
|
||||
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
|
||||
"""Compute a single simulation, given that all inputs are non-symbolic."""
|
||||
sim_domain = input_sockets['Domain']
|
||||
sources = input_sockets['Sources']
|
||||
structures = input_sockets['Structures']
|
||||
bounds = input_sockets['BCs']
|
||||
monitors = input_sockets['Monitors']
|
||||
|
||||
has_sim_domain = not ct.FlowSignal.check(sim_domain)
|
||||
has_sources = not ct.FlowSignal.check(sources)
|
||||
has_structures = not ct.FlowSignal.check(structures)
|
||||
has_bounds = not ct.FlowSignal.check(bounds)
|
||||
has_monitors = not ct.FlowSignal.check(monitors)
|
||||
|
||||
if (
|
||||
has_sim_domain
|
||||
and has_sources
|
||||
and has_structures
|
||||
and has_bounds
|
||||
and has_monitors
|
||||
):
|
||||
# Determine Differentiable Match
|
||||
## -> 'structures' is diff when **any** are diff.
|
||||
## -> 'monitors' is also diff when **any** are diff.
|
||||
## -> Only parameters through diff structs can be diff'ed by.
|
||||
## -> Similarly, only diff monitors will have gradients computed.
|
||||
return sim_domain | sources | structures | bounds | monitors
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
||||
####################
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
"""Implements `SimDomainNode`."""
|
||||
|
||||
import typing as typ
|
||||
|
||||
import sympy as sp
|
||||
|
@ -31,6 +33,8 @@ log = logger.get(__name__)
|
|||
|
||||
|
||||
class SimDomainNode(base.MaxwellSimNode):
|
||||
"""The domain of a simulation in space and time, including bounds, discretization strategy, and the ambient medium."""
|
||||
|
||||
node_type = ct.NodeType.SimDomain
|
||||
bl_label = 'Sim Domain'
|
||||
use_sim_node_name = True
|
||||
|
@ -69,26 +73,109 @@ class SimDomainNode(base.MaxwellSimNode):
|
|||
}
|
||||
|
||||
####################
|
||||
# - Outputs
|
||||
# - FlowKind.Value
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Domain',
|
||||
kind=ct.FlowKind.Value,
|
||||
# Loaded
|
||||
output_sockets={'Domain'},
|
||||
output_socket_kinds={'Domain': {ct.FlowKind.Func, ct.FlowKind.Params}},
|
||||
)
|
||||
def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
|
||||
output_func = output_sockets['Domain'][ct.FlowKind.Func]
|
||||
output_params = output_sockets['Domain'][ct.FlowKind.Params]
|
||||
|
||||
has_output_func = not ct.FlowSignal.check(output_func)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
if has_output_func and has_output_params and not output_params.symbols:
|
||||
return output_func.realize(output_params)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Func
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Domain',
|
||||
kind=ct.FlowKind.Func,
|
||||
# Loaded
|
||||
input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'},
|
||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
||||
scale_input_sockets={
|
||||
'Duration': 'Tidy3DUnits',
|
||||
'Center': 'Tidy3DUnits',
|
||||
'Size': 'Tidy3DUnits',
|
||||
input_socket_kinds={
|
||||
'Duration': ct.FlowKind.Func,
|
||||
'Center': ct.FlowKind.Func,
|
||||
'Size': ct.FlowKind.Func,
|
||||
'Grid': ct.FlowKind.Func,
|
||||
'Ambient Medium': ct.FlowKind.Func,
|
||||
},
|
||||
)
|
||||
def compute_domain(self, input_sockets, unit_systems) -> sp.Expr:
|
||||
return {
|
||||
'run_time': input_sockets['Duration'],
|
||||
'center': input_sockets['Center'],
|
||||
'size': input_sockets['Size'],
|
||||
'grid_spec': input_sockets['Grid'],
|
||||
'medium': input_sockets['Ambient Medium'],
|
||||
}
|
||||
def compute_domain_func(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
|
||||
duration = input_sockets['Duration']
|
||||
center = input_sockets['Center']
|
||||
size = input_sockets['Size']
|
||||
grid = input_sockets['Grid']
|
||||
medium = input_sockets['Ambient Medium']
|
||||
|
||||
has_duration = not ct.FlowSignal.check(duration)
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
has_size = not ct.FlowSignal.check(size)
|
||||
has_grid = not ct.FlowSignal.check(grid)
|
||||
has_medium = not ct.FlowSignal.check(medium)
|
||||
|
||||
if has_duration and has_center and has_size and has_grid and has_medium:
|
||||
return (
|
||||
duration.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| center.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| grid
|
||||
| medium
|
||||
).compose_within(
|
||||
enclosing_func=lambda els: {
|
||||
'run_time': els[0],
|
||||
'center': tuple(els[1].flatten()),
|
||||
'size': tuple(els[2].flatten()),
|
||||
'grid_spec': els[3],
|
||||
'medium': els[4],
|
||||
},
|
||||
supports_jax=False,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Domain',
|
||||
kind=ct.FlowKind.Params,
|
||||
# Loaded
|
||||
input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'},
|
||||
input_socket_kinds={
|
||||
'Duration': ct.FlowKind.Params,
|
||||
'Center': ct.FlowKind.Params,
|
||||
'Size': ct.FlowKind.Params,
|
||||
'Grid': ct.FlowKind.Params,
|
||||
'Ambient Medium': ct.FlowKind.Params,
|
||||
},
|
||||
)
|
||||
def compute_domain_params(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
"""Compute the output `ParamsFlow` of the simulation domain from strictly non-symbolic inputs."""
|
||||
duration = input_sockets['Duration']
|
||||
center = input_sockets['Center']
|
||||
size = input_sockets['Size']
|
||||
grid = input_sockets['Grid']
|
||||
medium = input_sockets['Ambient Medium']
|
||||
|
||||
has_duration = not ct.FlowSignal.check(duration)
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
has_size = not ct.FlowSignal.check(size)
|
||||
has_grid = not ct.FlowSignal.check(grid)
|
||||
has_medium = not ct.FlowSignal.check(medium)
|
||||
|
||||
if has_duration and has_center and has_size and has_grid and has_medium:
|
||||
return duration | center | size | grid | medium
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Preview
|
||||
|
@ -100,37 +187,39 @@ class SimDomainNode(base.MaxwellSimNode):
|
|||
props={'sim_node_name'},
|
||||
)
|
||||
def compute_previews(self, props):
|
||||
"""Mark the managed preview object for preview when `Domain` is linked to a viewer."""
|
||||
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
|
||||
|
||||
@events.on_value_changed(
|
||||
## Trigger
|
||||
# Trigger
|
||||
socket_name={'Center', 'Size'},
|
||||
run_on_init=True,
|
||||
# Loaded
|
||||
input_sockets={'Center', 'Size'},
|
||||
managed_objs={'modifier'},
|
||||
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
|
||||
scale_input_sockets={
|
||||
'Center': 'BlenderUnits',
|
||||
},
|
||||
output_sockets={'Domain'},
|
||||
output_socket_kinds={'Domain': ct.FlowKind.Params},
|
||||
)
|
||||
def on_input_changed(
|
||||
self,
|
||||
managed_objs,
|
||||
input_sockets,
|
||||
unit_systems,
|
||||
):
|
||||
def on_input_changed(self, managed_objs, input_sockets, output_sockets) -> None:
|
||||
"""Preview the simulation domain based on input parameters, so long as they are not dependent on unrealized symbols."""
|
||||
output_params = output_sockets['Domain']
|
||||
center = input_sockets['Center']
|
||||
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
|
||||
if has_center and has_output_params and not output_params.symbols:
|
||||
# Push Loose Input Values to GeoNodes Modifier
|
||||
managed_objs['modifier'].bl_modifier(
|
||||
'NODES',
|
||||
{
|
||||
'node_group': import_geonodes(GeoNodes.SimulationSimDomain),
|
||||
'unit_system': unit_systems['BlenderUnits'],
|
||||
'unit_system': ct.UNITS_BLENDER,
|
||||
'inputs': {
|
||||
'Size': input_sockets['Size'],
|
||||
},
|
||||
},
|
||||
location=input_sockets['Center'],
|
||||
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -71,35 +71,129 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
|
|||
layout.prop(self, self.blfields['pol_axis'], expand=True)
|
||||
|
||||
####################
|
||||
# - Outputs
|
||||
# - FlowKind.Value
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Source',
|
||||
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
|
||||
# Loaded
|
||||
props={'pol_axis'},
|
||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
||||
scale_input_sockets={
|
||||
'Center': 'Tidy3DUnits',
|
||||
},
|
||||
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
|
||||
output_sockets={'Source'},
|
||||
output_socket_kinds={'Source': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_source(
|
||||
self,
|
||||
input_sockets: dict[str, typ.Any],
|
||||
props: dict[str, typ.Any],
|
||||
unit_systems: dict,
|
||||
) -> td.PointDipole:
|
||||
def compute_source_value(
|
||||
self, input_sockets, props, output_sockets
|
||||
) -> td.PointDipole | ct.FlowSignal:
|
||||
"""Compute the point dipole source, given that all inputs are non-symbolic."""
|
||||
temporal_shape = input_sockets['Temporal Shape']
|
||||
center = input_sockets['Center']
|
||||
interpolate = input_sockets['Interpolate']
|
||||
output_params = output_sockets['Source']
|
||||
|
||||
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
has_interpolate = not ct.FlowSignal.check(interpolate)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
if (
|
||||
has_temporal_shape
|
||||
and has_center
|
||||
and has_interpolate
|
||||
and has_output_params
|
||||
and not output_params.symbols
|
||||
):
|
||||
pol_axis = {
|
||||
ct.SimSpaceAxis.X: 'Ex',
|
||||
ct.SimSpaceAxis.Y: 'Ey',
|
||||
ct.SimSpaceAxis.Z: 'Ez',
|
||||
}[props['pol_axis']]
|
||||
## TODO: Need Hx, Hy, Hz too?
|
||||
|
||||
return td.PointDipole(
|
||||
center=input_sockets['Center'],
|
||||
source_time=input_sockets['Temporal Shape'],
|
||||
interpolate=input_sockets['Interpolate'],
|
||||
center=spux.convert_to_unit_system(center, ct.UNITS_TIDY3D),
|
||||
source_time=temporal_shape,
|
||||
interpolate=interpolate,
|
||||
polarization=pol_axis,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Func
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Source',
|
||||
kind=ct.FlowKind.Func,
|
||||
# Loaded
|
||||
props={'pol_axis'},
|
||||
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
|
||||
input_socket_kinds={
|
||||
'Temporal Shape': ct.FlowKind.Func,
|
||||
'Center': ct.FlowKind.Func,
|
||||
'Interpolate': ct.FlowKind.Func,
|
||||
},
|
||||
output_sockets={'Source'},
|
||||
output_socket_kinds={'Source': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_source_func(self, props, input_sockets, output_sockets) -> td.Box:
|
||||
"""Compute a lazy function for the point dipole source."""
|
||||
center = input_sockets['Center']
|
||||
temporal_shape = input_sockets['Temporal Shape']
|
||||
interpolate = input_sockets['Interpolate']
|
||||
output_params = output_sockets['Source']
|
||||
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
|
||||
has_interpolate = not ct.FlowSignal.check(interpolate)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
if has_temporal_shape and has_center and has_interpolate and has_output_params:
|
||||
pol_axis = {
|
||||
ct.SimSpaceAxis.X: 'Ex',
|
||||
ct.SimSpaceAxis.Y: 'Ey',
|
||||
ct.SimSpaceAxis.Z: 'Ez',
|
||||
}[props['pol_axis']]
|
||||
## TODO: Need Hx, Hy, Hz too?
|
||||
|
||||
return (center | temporal_shape | interpolate).compose_within(
|
||||
enclosing_func=lambda els: td.PointDipole(
|
||||
center=els[0],
|
||||
source_time=els[1],
|
||||
interpolate=els[2],
|
||||
polarization=pol_axis,
|
||||
)
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Source',
|
||||
kind=ct.FlowKind.Params,
|
||||
# Loaded
|
||||
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
|
||||
input_socket_kinds={
|
||||
'Temporal Shape': ct.FlowKind.Params,
|
||||
'Center': ct.FlowKind.Params,
|
||||
'Interpolate': ct.FlowKind.Params,
|
||||
},
|
||||
)
|
||||
def compute_params(
|
||||
self,
|
||||
input_sockets,
|
||||
) -> td.PointDipole | ct.FlowSignal:
|
||||
"""Compute the point dipole source, given that all inputs are non-symbolic."""
|
||||
temporal_shape = input_sockets['Temporal Shape']
|
||||
center = input_sockets['Center']
|
||||
interpolate = input_sockets['Interpolate']
|
||||
|
||||
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
has_interpolate = not ct.FlowSignal.check(interpolate)
|
||||
|
||||
if has_temporal_shape and has_center and has_interpolate:
|
||||
return temporal_shape | center | interpolate
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Preview
|
||||
|
|
|
@ -16,15 +16,19 @@
|
|||
|
||||
"""Implements the `TemporalShapeNode`."""
|
||||
|
||||
import enum
|
||||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
import tidy3d as td
|
||||
from tidy3d.components.data.data_array import TimeDataArray as td_TimeDataArray
|
||||
from tidy3d.components.data.dataset import TimeDataset as td_TimeDataset
|
||||
|
||||
from blender_maxwell.utils import bl_cache, logger, sim_symbols
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger, sim_symbols
|
||||
|
||||
from ... import contracts as ct
|
||||
from ... import managed_objs, sockets
|
||||
|
@ -33,14 +37,10 @@ from .. import base, events
|
|||
log = logger.get(__name__)
|
||||
|
||||
|
||||
_max_e_socket_def = sockets.ExprSocketDef(
|
||||
mathtype=spux.MathType.Complex,
|
||||
physical_type=spux.PhysicalType.EField,
|
||||
default_value=1 + 0j,
|
||||
)
|
||||
_offset_socket_def = sockets.ExprSocketDef(default_value=5, abs_min=2.5)
|
||||
|
||||
t_ps = sim_symbols.t(spu.picosecond)
|
||||
# Select Default Time Unit for Envelope
|
||||
## -> Chosen to align with the default envelope_time_unit.
|
||||
## -> This causes it to be correct from the start.
|
||||
t_def = sim_symbols.t(spux.PhysicalType.Time.valid_units[0])
|
||||
|
||||
|
||||
class TemporalShapeNode(base.MaxwellSimNode):
|
||||
|
@ -63,17 +63,18 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
|||
default_unit=spux.THz,
|
||||
default_value=200,
|
||||
),
|
||||
'max E': sockets.ExprSocketDef(
|
||||
mathtype=spux.MathType.Complex,
|
||||
physical_type=spux.PhysicalType.EField,
|
||||
default_value=1 + 0j,
|
||||
),
|
||||
'Offset Time': sockets.ExprSocketDef(default_value=5, abs_min=2.5),
|
||||
}
|
||||
input_socket_sets: typ.ClassVar = {
|
||||
'Pulse': {
|
||||
'max E': _max_e_socket_def,
|
||||
'Offset Time': _offset_socket_def,
|
||||
'Remove DC': sockets.BoolSocketDef(default_value=True),
|
||||
},
|
||||
'Constant': {
|
||||
'max E': _max_e_socket_def,
|
||||
'Offset Time': _offset_socket_def,
|
||||
},
|
||||
'Constant': {},
|
||||
'Symbolic': {
|
||||
't Range': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.Range,
|
||||
|
@ -84,8 +85,8 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
|||
default_steps=100,
|
||||
),
|
||||
'Envelope': sockets.ExprSocketDef(
|
||||
default_symbols=[t_ps],
|
||||
default_value=10 * t_ps.sp_symbol,
|
||||
default_symbols=[t_def],
|
||||
default_value=10 * t_def.sp_symbol,
|
||||
),
|
||||
},
|
||||
}
|
||||
|
@ -98,6 +99,55 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
|||
'plot': managed_objs.ManagedBLImage,
|
||||
}
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
active_envelope_time_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_time_units(),
|
||||
)
|
||||
|
||||
def search_time_units(self) -> list[ct.BLEnumElement]:
|
||||
"""Compute all valid time units."""
|
||||
return [
|
||||
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
|
||||
for i, unit in enumerate(spux.PhysicalType.Time.valid_units)
|
||||
]
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'active_envelope_time_unit'})
|
||||
def envelope_time_unit(self) -> spux.Unit | None:
|
||||
"""Gets the current active unit for the envelope time symbol.
|
||||
|
||||
Returns:
|
||||
The current active `sympy` unit.
|
||||
|
||||
If the socket expression is unitless, this returns `None`.
|
||||
"""
|
||||
if self.active_envelope_time_unit is not None:
|
||||
return spux.unit_str_to_unit(self.active_envelope_time_unit)
|
||||
|
||||
return None
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
|
||||
if (
|
||||
self.active_socket_set == 'Symbolic'
|
||||
and self.inputs.get('Envelope')
|
||||
and not self.inputs['Envelope'].is_linked
|
||||
):
|
||||
row = layout.row()
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text='Envelope Time Unit')
|
||||
|
||||
row = layout.row()
|
||||
row.prop(
|
||||
self,
|
||||
self.blfields['active_envelope_time_unit'],
|
||||
text='',
|
||||
toggle=True,
|
||||
)
|
||||
|
||||
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
if self.active_socket_set != 'Symbolic':
|
||||
box = layout.box()
|
||||
|
@ -118,10 +168,53 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
|||
col.label(text='1 / 2π·σ(𝑓)')
|
||||
|
||||
####################
|
||||
# - FlowKind: Value
|
||||
# - Events
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
# Trigger
|
||||
prop_name={'active_socket_set', 'envelope_time_unit'},
|
||||
# Loaded
|
||||
props={'active_socket_set', 'envelope_time_unit'},
|
||||
)
|
||||
def on_envelope_time_unit_changed(self, props) -> None:
|
||||
"""Ensure the envelope expression's time symbol has the time unit defined by the node."""
|
||||
active_socket_set = props['active_socket_set']
|
||||
envelope_time_unit = props['envelope_time_unit']
|
||||
if active_socket_set == 'Symbolic':
|
||||
bl_socket = self.inputs['Envelope']
|
||||
wanted_t_sym = sim_symbols.t(envelope_time_unit)
|
||||
|
||||
if not bl_socket.symbols or bl_socket.symbols[0] != wanted_t_sym:
|
||||
bl_socket.symbols = [wanted_t_sym]
|
||||
|
||||
####################
|
||||
# - FlowKind.Value
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Temporal Shape',
|
||||
kind=ct.FlowKind.Value,
|
||||
# Loaded
|
||||
output_sockets={'Temporal Shape'},
|
||||
output_socket_kinds={'Temporal Shape': {ct.FlowKind.Func, ct.FlowKind.Params}},
|
||||
)
|
||||
def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
"""Compute a single temporal shape."""
|
||||
output_func = output_sockets['Temporal Shape'][ct.FlowKind.Func]
|
||||
output_params = output_sockets['Temporal Shape'][ct.FlowKind.Params]
|
||||
|
||||
has_output_func = not ct.FlowSignal.check(output_func)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
if has_output_func and has_output_params and not output_params.symbols:
|
||||
return output_func.realize(output_params)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind: Func
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Temporal Shape',
|
||||
kind=ct.FlowKind.Func,
|
||||
# Loaded
|
||||
props={'active_socket_set'},
|
||||
input_sockets={
|
||||
|
@ -134,60 +227,178 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
|||
'Envelope',
|
||||
},
|
||||
input_socket_kinds={
|
||||
't Range': ct.FlowKind.Range,
|
||||
'Envelope': ct.FlowKind.Func,
|
||||
},
|
||||
input_sockets_optional={
|
||||
'max E': True,
|
||||
'Offset Time': True,
|
||||
'Remove DC': True,
|
||||
't Range': True,
|
||||
'Envelope': True,
|
||||
},
|
||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
||||
scale_input_sockets={
|
||||
'max E': 'Tidy3DUnits',
|
||||
'μ Freq': 'Tidy3DUnits',
|
||||
'σ Freq': 'Tidy3DUnits',
|
||||
't Range': 'Tidy3DUnits',
|
||||
'Offset Time': 'Tidy3DUnits',
|
||||
'max E': ct.FlowKind.Func,
|
||||
'μ Freq': ct.FlowKind.Func,
|
||||
'σ Freq': ct.FlowKind.Func,
|
||||
'Offset Time': ct.FlowKind.Func,
|
||||
'Remove DC': ct.FlowKind.Value,
|
||||
't Range': ct.FlowKind.Func,
|
||||
'Envelope': {ct.FlowKind.Func, ct.FlowKind.Params},
|
||||
},
|
||||
)
|
||||
def compute_temporal_shape(
|
||||
self, props, input_sockets, unit_systems
|
||||
def compute_temporal_shape_func(
|
||||
self,
|
||||
props,
|
||||
input_sockets,
|
||||
) -> td.GaussianPulse:
|
||||
"""Compute a single temporal shape from non-parameterized inputs."""
|
||||
mean_freq = input_sockets['μ Freq']
|
||||
std_freq = input_sockets['σ Freq']
|
||||
max_e = input_sockets['max E']
|
||||
offset = input_sockets['Offset Time']
|
||||
|
||||
has_mean_freq = not ct.FlowSignal.check(mean_freq)
|
||||
has_std_freq = not ct.FlowSignal.check(std_freq)
|
||||
has_max_e = not ct.FlowSignal.check(max_e)
|
||||
has_offset = not ct.FlowSignal.check(offset)
|
||||
|
||||
if has_mean_freq and has_std_freq and has_max_e and has_offset:
|
||||
common_func = (
|
||||
max_e.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| mean_freq.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| std_freq.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| offset ## Already unitless
|
||||
)
|
||||
match props['active_socket_set']:
|
||||
case 'Pulse':
|
||||
return td.GaussianPulse(
|
||||
amplitude=sp.re(input_sockets['max E']),
|
||||
phase=sp.im(input_sockets['max E']),
|
||||
freq0=input_sockets['μ Freq'],
|
||||
fwidth=input_sockets['σ Freq'],
|
||||
offset=input_sockets['Offset Time'],
|
||||
remove_dc_component=input_sockets['Remove DC'],
|
||||
remove_dc = input_sockets['Remove DC']
|
||||
|
||||
has_remove_dc = not ct.FlowSignal.check(remove_dc)
|
||||
|
||||
if has_remove_dc:
|
||||
return common_func.compose_within(
|
||||
lambda els: td.GaussianPulse(
|
||||
amplitude=complex(els[0]).real,
|
||||
phase=complex(els[0]).imag,
|
||||
freq0=els[1],
|
||||
fwidth=els[2],
|
||||
offset=els[3],
|
||||
remove_dc_component=remove_dc,
|
||||
),
|
||||
)
|
||||
|
||||
case 'Constant':
|
||||
return td.ContinuousWave(
|
||||
amplitude=sp.re(input_sockets['max E']),
|
||||
phase=sp.im(input_sockets['max E']),
|
||||
freq0=input_sockets['μ Freq'],
|
||||
fwidth=input_sockets['σ Freq'],
|
||||
offset=input_sockets['Offset Time'],
|
||||
return common_func.compose_within(
|
||||
lambda els: td.GaussianPulse(
|
||||
amplitude=complex(els[0]).real,
|
||||
phase=complex(els[0]).imag,
|
||||
freq0=els[1],
|
||||
fwidth=els[2],
|
||||
offset=els[3],
|
||||
),
|
||||
)
|
||||
|
||||
case 'Symbolic':
|
||||
lzrange = input_sockets['t Range']
|
||||
envelope_ps = input_sockets['Envelope'].func_jax
|
||||
t_range = input_sockets['t Range']
|
||||
envelope = input_sockets['Envelope'][ct.FlowKind.Func]
|
||||
envelope_params = input_sockets['Envelope'][ct.FlowKind.Params]
|
||||
|
||||
return td.CustomSourceTime.from_values(
|
||||
freq0=input_sockets['μ Freq'],
|
||||
fwidth=input_sockets['σ Freq'],
|
||||
values=envelope_ps(
|
||||
lzrange.rescale_to_unit(spu.ps).realize_array.values
|
||||
),
|
||||
dt=input_sockets['t Range'].realize_step_size(),
|
||||
has_t_range = not ct.FlowSignal.check(t_range)
|
||||
has_envelope = not ct.FlowSignal.check(envelope)
|
||||
has_envelope_params = not ct.FlowSignal.check(envelope_params)
|
||||
|
||||
if (
|
||||
has_t_range
|
||||
and has_envelope
|
||||
and has_envelope_params
|
||||
and len(envelope_params.symbols) == 1
|
||||
## TODO: Allow unrealized envelope symbols
|
||||
and any(
|
||||
sym.physical_type is spux.PhysicalType.Time
|
||||
for sym in envelope_params.symbols
|
||||
)
|
||||
):
|
||||
envelope_time_unit = next(
|
||||
sym.unit
|
||||
for sym in envelope_params.symbols
|
||||
if sym.physical_type is spux.PhysicalType.Time
|
||||
)
|
||||
|
||||
# Deduce Partially Realized Envelope Function
|
||||
## -> We need a pure-numerical function w/pre-realized stuff baked in.
|
||||
## -> 'realize_partial' does this for us.
|
||||
envelope_realizer = envelope.realize_partial(envelope_params)
|
||||
|
||||
# Compose w/Envelope Function
|
||||
## -> First, the numerical time values must be converted.
|
||||
## -> This ensures that the raw array is compatible w/the envelope.
|
||||
## -> Then, we can compose w/the purely numerical 'envelope_realizer'.
|
||||
## -> Because of the checks, we've guaranteed that all this is correct.
|
||||
return (
|
||||
common_func ## 1 | freq0, 2 | fwidth, 3 | offset
|
||||
| t_range.scale_to_unit_system(ct.UNITS_TIDY3D) ## 4
|
||||
| t_range.scale_to_unit(envelope_time_unit).compose_within(
|
||||
lambda t: envelope_realizer(t)
|
||||
) ## 5
|
||||
).compose_within(
|
||||
lambda els: td.CustomSourceTime(
|
||||
amplitude=complex(els[0]).real,
|
||||
phase=complex(els[0]).imag,
|
||||
freq0=els[1],
|
||||
fwidth=els[2],
|
||||
offset=els[3],
|
||||
source_time_dataset=td_TimeDataset(
|
||||
values=td_TimeDataArray(
|
||||
els[5], coords={'t': np.array(els[4])}
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind: Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Temporal Shape',
|
||||
kind=ct.FlowKind.Params,
|
||||
# Loaded
|
||||
props={'active_socket_set', 'envelope_time_unit'},
|
||||
input_sockets={
|
||||
'max E',
|
||||
'μ Freq',
|
||||
'σ Freq',
|
||||
'Offset Time',
|
||||
't Range',
|
||||
},
|
||||
input_socket_kinds={
|
||||
'max E': ct.FlowKind.Params,
|
||||
'μ Freq': ct.FlowKind.Params,
|
||||
'σ Freq': ct.FlowKind.Params,
|
||||
'Offset Time': ct.FlowKind.Params,
|
||||
't Range': ct.FlowKind.Params,
|
||||
},
|
||||
)
|
||||
def compute_temporal_shape_params(
|
||||
self,
|
||||
props,
|
||||
input_sockets,
|
||||
) -> td.GaussianPulse:
|
||||
"""Compute a single temporal shape from non-parameterized inputs."""
|
||||
mean_freq = input_sockets['μ Freq']
|
||||
std_freq = input_sockets['σ Freq']
|
||||
max_e = input_sockets['max E']
|
||||
offset = input_sockets['Offset Time']
|
||||
|
||||
has_mean_freq = not ct.FlowSignal.check(mean_freq)
|
||||
has_std_freq = not ct.FlowSignal.check(std_freq)
|
||||
has_max_e = not ct.FlowSignal.check(max_e)
|
||||
has_offset = not ct.FlowSignal.check(offset)
|
||||
|
||||
if has_mean_freq and has_std_freq and has_max_e and has_offset:
|
||||
common_params = max_e | mean_freq | std_freq | offset
|
||||
match props['active_socket_set']:
|
||||
case 'Pulse' | 'Constant':
|
||||
return common_params
|
||||
|
||||
case 'Symbolic':
|
||||
t_range = input_sockets['t Range']
|
||||
has_t_range = not ct.FlowSignal.check(t_range)
|
||||
|
||||
if has_t_range:
|
||||
return common_params | t_range | t_range
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
||||
####################
|
||||
|
|
|
@ -88,28 +88,27 @@ class BoxStructureNode(base.MaxwellSimNode):
|
|||
'Structure',
|
||||
kind=ct.FlowKind.Value,
|
||||
# Loaded
|
||||
props={'differentiable'},
|
||||
input_sockets={'Medium', 'Center', 'Size'},
|
||||
output_sockets={'Structure'},
|
||||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_value(self, props, input_sockets, output_sockets) -> td.Box:
|
||||
output_params = output_sockets['Structure']
|
||||
def compute_value(self, input_sockets, output_sockets) -> td.Box:
|
||||
"""Compute a single box structure object, given that all inputs are non-symbolic."""
|
||||
center = input_sockets['Center']
|
||||
size = input_sockets['Size']
|
||||
medium = input_sockets['Medium']
|
||||
output_params = output_sockets['Structure']
|
||||
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
has_size = not ct.FlowSignal.check(size)
|
||||
has_medium = not ct.FlowSignal.check(medium)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
if (
|
||||
has_center
|
||||
and has_size
|
||||
and has_medium
|
||||
and has_output_params
|
||||
and not props['differentiable']
|
||||
and not output_params.symbols
|
||||
):
|
||||
return td.Structure(
|
||||
|
@ -138,7 +137,8 @@ class BoxStructureNode(base.MaxwellSimNode):
|
|||
output_sockets={'Structure'},
|
||||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_lazy_structure(self, props, input_sockets, output_sockets) -> td.Box:
|
||||
def compute_structure_func(self, props, input_sockets, output_sockets) -> td.Box:
|
||||
"""Compute a possibly-differentiable function, producing a box structure from the input parameters."""
|
||||
output_params = output_sockets['Structure']
|
||||
center = input_sockets['Center']
|
||||
size = input_sockets['Size']
|
||||
|
@ -149,14 +149,8 @@ class BoxStructureNode(base.MaxwellSimNode):
|
|||
has_size = not ct.FlowSignal.check(size)
|
||||
has_medium = not ct.FlowSignal.check(medium)
|
||||
|
||||
if has_output_params and has_center and has_size and has_medium:
|
||||
differentiable = props['differentiable']
|
||||
if (
|
||||
has_output_params
|
||||
and has_center
|
||||
and has_size
|
||||
and has_medium
|
||||
and differentiable == output_params.is_differentiable
|
||||
):
|
||||
if differentiable:
|
||||
return (center | size | medium).compose_within(
|
||||
enclosing_func=lambda els: tdadj.JaxStructure(
|
||||
|
@ -169,6 +163,12 @@ class BoxStructureNode(base.MaxwellSimNode):
|
|||
supports_jax=True,
|
||||
)
|
||||
return (center | size | medium).compose_within(
|
||||
## TODO: Unit conversion within the composed function??
|
||||
## -- We do need Tidy3D to be given ex. micrometers in particular.
|
||||
## -- But the previous numerical output might not be micrometers.
|
||||
## -- There must be a way to add a conversion in, without strangeness.
|
||||
## -- Ex. can compose_within() take a unit system?
|
||||
## -- This would require
|
||||
enclosing_func=lambda els: td.Structure(
|
||||
geometry=td.Box(
|
||||
center=tuple(els[0].flatten()),
|
||||
|
@ -205,14 +205,8 @@ class BoxStructureNode(base.MaxwellSimNode):
|
|||
has_medium = not ct.FlowSignal.check(medium)
|
||||
|
||||
if has_center and has_size and has_medium:
|
||||
if props['differentiable'] == (
|
||||
center.is_differentiable
|
||||
and size.is_differentiable
|
||||
and medium.is_differentiable
|
||||
):
|
||||
return center | size | medium
|
||||
return ct.FlowSignal.FlowPending
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Events: Preview
|
||||
|
@ -226,6 +220,7 @@ class BoxStructureNode(base.MaxwellSimNode):
|
|||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_previews(self, props, output_sockets):
|
||||
"""Mark the managed preview object when recursively linked to a viewer."""
|
||||
output_params = output_sockets['Structure']
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
|
@ -245,10 +240,14 @@ class BoxStructureNode(base.MaxwellSimNode):
|
|||
)
|
||||
def on_inputs_changed(self, managed_objs, input_sockets, output_sockets):
|
||||
output_params = output_sockets['Structure']
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
if has_output_params and not output_params.symbols:
|
||||
# Push Loose Input Values to GeoNodes Modifier
|
||||
center = input_sockets['Center']
|
||||
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
if has_center and has_output_params and not output_params.symbols:
|
||||
## TODO: There are strategies for handling examples of symbol values.
|
||||
|
||||
# Push Loose Input Values to GeoNodes Modifier
|
||||
managed_objs['modifier'].bl_modifier(
|
||||
'NODES',
|
||||
{
|
||||
|
|
|
@ -43,17 +43,28 @@ class SocketDef(pyd.BaseModel, abc.ABC):
|
|||
"""
|
||||
|
||||
socket_type: ct.SocketType
|
||||
active_kind: typ.Literal[
|
||||
ct.FlowKind.Value,
|
||||
ct.FlowKind.Array,
|
||||
ct.FlowKind.Range,
|
||||
ct.FlowKind.Func,
|
||||
] = ct.FlowKind.Value
|
||||
|
||||
####################
|
||||
# - Socket Interaction
|
||||
####################
|
||||
def preinit(self, bl_socket: bpy.types.NodeSocket) -> None:
|
||||
"""Pre-initialize a real Blender node socket from this socket definition.
|
||||
|
||||
Parameters:
|
||||
bl_socket: The Blender node socket to alter using data from this SocketDef.
|
||||
"""
|
||||
log.debug('%s: Start Socket Preinit', bl_socket.bl_label)
|
||||
# log.debug('%s: Start Socket Preinit', bl_socket.bl_label)
|
||||
bl_socket.reset_instance_id()
|
||||
bl_socket.regenerate_dynamic_field_persistance()
|
||||
log.debug('%s: End Socket Preinit', bl_socket.bl_label)
|
||||
|
||||
bl_socket.active_kind = self.active_kind
|
||||
# log.debug('%s: End Socket Preinit', bl_socket.bl_label)
|
||||
|
||||
def postinit(self, bl_socket: bpy.types.NodeSocket) -> None:
|
||||
"""Pre-initialize a real Blender node socket from this socket definition.
|
||||
|
@ -61,12 +72,12 @@ class SocketDef(pyd.BaseModel, abc.ABC):
|
|||
Parameters:
|
||||
bl_socket: The Blender node socket to alter using data from this SocketDef.
|
||||
"""
|
||||
log.debug('%s: Start Socket Postinit', bl_socket.bl_label)
|
||||
# log.debug('%s: Start Socket Postinit', bl_socket.bl_label)
|
||||
bl_socket.is_initializing = False
|
||||
bl_socket.on_active_kind_changed()
|
||||
bl_socket.on_socket_props_changed(set(bl_socket.blfields))
|
||||
bl_socket.on_data_changed(set(ct.FlowKind))
|
||||
log.debug('%s: End Socket Postinit', bl_socket.bl_label)
|
||||
# log.debug('%s: End Socket Postinit', bl_socket.bl_label)
|
||||
|
||||
@abc.abstractmethod
|
||||
def init(self, bl_socket: bpy.types.NodeSocket) -> None:
|
||||
|
@ -76,6 +87,43 @@ class SocketDef(pyd.BaseModel, abc.ABC):
|
|||
bl_socket: The Blender node socket to alter using data from this SocketDef.
|
||||
"""
|
||||
|
||||
####################
|
||||
# - Comparison
|
||||
####################
|
||||
def compare(self, bl_socket: bpy.types.NodeSocket) -> bool:
|
||||
"""Whether this `SocketDef` can be considered to uniquely define the given `bl_socket`.
|
||||
|
||||
The general criteria for "uniquely defines" is whether **the same `bl_socket`** could be created using this `SocketDef`.
|
||||
The extent to which user-altered properties are considered in this regard is a matter of taste, encapsulated entirely within `self.local_compare()`.
|
||||
|
||||
Notes:
|
||||
Used when determining whether to replace sockets with newer variants when synchronizing changes.
|
||||
|
||||
**NOTE**: Removing/replacing loose input sockets
|
||||
|
||||
Parameters:
|
||||
bl_socket: The Blender node socket to alter using data from this SocketDef.
|
||||
"""
|
||||
return (
|
||||
bl_socket.socket_type is self.socket_type
|
||||
and bl_socket.active_kind is self.active_kind
|
||||
and self.local_compare(bl_socket)
|
||||
)
|
||||
|
||||
def local_compare(self, bl_socket: bpy.types.NodeSocket) -> None:
|
||||
"""Compare this `SocketDef` to an established `bl_socket` in a manner specific to the node.
|
||||
|
||||
Notes:
|
||||
Run by `self.compare()`.
|
||||
Optionally overriden by individual sockets.
|
||||
|
||||
When not overridden, it will always return `False`, indicating that the socket is _never_ uniquely defined by this `SocketDef`.
|
||||
|
||||
Parameters:
|
||||
bl_socket: The Blender node socket to alter using data from this SocketDef.
|
||||
"""
|
||||
return False
|
||||
|
||||
####################
|
||||
# - Serialization
|
||||
####################
|
||||
|
@ -426,8 +474,34 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
Parameters:
|
||||
socket_kinds: The altered `ct.FlowKind`s flowing through.
|
||||
"""
|
||||
# Run Socket Callbacks
|
||||
self.on_socket_data_changed(socket_kinds)
|
||||
|
||||
# Mark Active FlowKind Links as Invalid
|
||||
## -> Mark link as invalid (very red) if a FlowSignal is traveling.
|
||||
## -> This helps explain why whatever isn't working isn't working.
|
||||
## -> TODO: We need a different approach.
|
||||
# log.debug(
|
||||
# '[%s] Checking FlowKind Validity (socket_kinds=%s)',
|
||||
# self.name,
|
||||
# str(socket_kinds),
|
||||
# )
|
||||
# if self.is_linked and not self.is_output:
|
||||
# link = self.links[0]
|
||||
# linked_flow = self.compute_data(kind=self.active_kind)
|
||||
|
||||
# if (
|
||||
# link.is_valid
|
||||
# and self.active_kind in socket_kinds
|
||||
# and ct.FlowSignal.check_single(linked_flow, ct.FlowSignal.FlowPending)
|
||||
# ):
|
||||
# node_tree = self.id_data
|
||||
# node_tree.report_link_validity(link, False)
|
||||
|
||||
# elif not link.is_valid:
|
||||
# node_tree = self.id_data
|
||||
# node_tree.report_link_validity(link, True)
|
||||
|
||||
def on_socket_data_changed(self, socket_kinds: set[ct.FlowKind]) -> None:
|
||||
"""Called when `ct.FlowEvent.DataChanged` flows through this socket.
|
||||
|
||||
|
@ -479,7 +553,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
The value of `ct.FlowEvent.flow_direction[event]` (`input` or `output`) determines the direction that an event flows.
|
||||
"""
|
||||
# log.debug(
|
||||
# '[%s] [%s] Triggered (socket_kinds=%s)',
|
||||
# '[%s] [%s] Socket Triggered (socket_kinds=%s)',
|
||||
# self.name,
|
||||
# event,
|
||||
# str(socket_kinds),
|
||||
|
@ -757,7 +831,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
linked_values = [link.from_socket.compute_data(kind) for link in self.links]
|
||||
|
||||
# Return Single Value / List of Values
|
||||
## -> Multi-input sockets are not yet supported.
|
||||
## -> Multi-input sockets are not (yet) supported.
|
||||
if linked_values:
|
||||
return linked_values[0]
|
||||
|
||||
|
@ -891,10 +965,14 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
# FlowKind Draw Row
|
||||
col = row.column(align=True)
|
||||
{
|
||||
ct.FlowKind.Capabilities: lambda *_: None,
|
||||
ct.FlowKind.Previews: lambda *_: None,
|
||||
ct.FlowKind.Value: self.draw_value,
|
||||
ct.FlowKind.Array: self.draw_array,
|
||||
ct.FlowKind.Range: self.draw_lazy_range,
|
||||
ct.FlowKind.Func: self.draw_lazy_func,
|
||||
ct.FlowKind.Params: lambda *_: None,
|
||||
ct.FlowKind.Info: lambda *_: None,
|
||||
}[self.active_kind](col)
|
||||
|
||||
# Info Drawing
|
||||
|
|
|
@ -51,6 +51,16 @@ class BoolBLSocket(base.MaxwellSimSocket):
|
|||
def value(self, value: bool) -> None:
|
||||
self.raw_value = value
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'value'})
|
||||
def lazy_func(self) -> ct.FuncFlow:
|
||||
return ct.FuncFlow(
|
||||
func=lambda: self.value,
|
||||
)
|
||||
|
||||
@bl_cache.cached_bl_property()
|
||||
def params(self) -> ct.FuncFlow:
|
||||
return ct.ParamsFlow()
|
||||
|
||||
|
||||
####################
|
||||
# - Socket Configuration
|
||||
|
|
|
@ -130,6 +130,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
'physical_type',
|
||||
'unit',
|
||||
'size',
|
||||
'value',
|
||||
}
|
||||
)
|
||||
def output_sym(self) -> sim_symbols.SimSymbol | None:
|
||||
|
@ -140,13 +141,29 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
Raises:
|
||||
NotImplementedError: When `active_kind` is neither `Value`, `Func`, or `Range`.
|
||||
"""
|
||||
if self.symbols:
|
||||
if self.active_kind in [ct.FlowKind.Value, ct.FlowKind.Func]:
|
||||
match self.active_kind:
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func if self.symbols:
|
||||
return self._parse_expr_symbol(
|
||||
self._parse_expr_str(self.raw_value_spstr)
|
||||
)
|
||||
|
||||
if self.active_kind is ct.FlowKind.Range:
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func if not self.symbols:
|
||||
return sim_symbols.SimSymbol(
|
||||
sym_name=self.output_name,
|
||||
mathtype=self.mathtype,
|
||||
physical_type=self.physical_type,
|
||||
unit=self.unit,
|
||||
rows=self.size.rows,
|
||||
cols=self.size.cols,
|
||||
exclude_zero=(
|
||||
not self.value.is_zero
|
||||
if self.value.is_zero is not None
|
||||
else False
|
||||
),
|
||||
## TODO: Does this work for matrix elements?
|
||||
)
|
||||
|
||||
case ct.FlowKind.Range if self.symbols:
|
||||
## TODO: Support RangeFlow
|
||||
## -- It's hard; we need a min-span set over bound domains.
|
||||
## -- We... Don't use this anywhere. Yet?
|
||||
|
@ -159,20 +176,37 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
msg = 'RangeFlow support not yet implemented for when self.symbols is not empty'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
case ct.FlowKind.Range if not self.symbols:
|
||||
return sim_symbols.SimSymbol(
|
||||
sym_name=self.output_name,
|
||||
mathtype=self.mathtype,
|
||||
physical_type=self.physical_type,
|
||||
unit=self.unit,
|
||||
rows=self.size.rows,
|
||||
cols=self.size.cols,
|
||||
rows=self.lazy_range.steps,
|
||||
cols=1,
|
||||
exclude_zero=not self.lazy_range.is_always_nonzero,
|
||||
)
|
||||
|
||||
####################
|
||||
# - Value|Range Swapper
|
||||
####################
|
||||
use_value_range_swapper: bool = bl_cache.BLField(False)
|
||||
selected_value_range: ct.FlowKind = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self._value_or_range(),
|
||||
)
|
||||
|
||||
def _value_or_range(self):
|
||||
return [
|
||||
flow_kind.bl_enum_element(i)
|
||||
for i, flow_kind in enumerate([ct.FlowKind.Value, ct.FlowKind.Range])
|
||||
]
|
||||
|
||||
####################
|
||||
# - Symbols
|
||||
####################
|
||||
lazy_range_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.Expr
|
||||
)
|
||||
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.Expr
|
||||
)
|
||||
|
@ -343,7 +377,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
|
||||
See `MaxwellSimTree` for more detail on the link callbacks.
|
||||
"""
|
||||
## NODE: Depends on suppressed on_prop_changed
|
||||
## NOTE: Depends on suppressed on_prop_changed
|
||||
|
||||
if ct.FlowKind.Info in socket_kinds:
|
||||
info = self.compute_data(kind=ct.FlowKind.Info)
|
||||
|
@ -371,7 +405,10 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
|
||||
See `MaxwellSimTree` for more detail on the link callbacks.
|
||||
"""
|
||||
## NODE: Depends on suppressed on_prop_changed
|
||||
## NOTE: Depends on suppressed on_prop_changed
|
||||
if ('selected_value_range', 'invalidate') in cleared_blfields:
|
||||
self.active_kind = self.selected_value_range
|
||||
self.on_active_kind_changed()
|
||||
|
||||
# Conditional Unit-Conversion
|
||||
## -> This is niche functionality, but the only way to convert units.
|
||||
|
@ -757,7 +794,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
@bl_cache.cached_bl_property(
|
||||
depends_on={
|
||||
'value',
|
||||
'symbols',
|
||||
'sorted_sp_symbols',
|
||||
'sorted_symbols',
|
||||
'output_sym',
|
||||
|
@ -769,83 +805,88 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be the arguments of the `FuncFlow`.
|
||||
Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`.
|
||||
"""
|
||||
# Symbolic
|
||||
## -> `self.value` is guaranteed to be an expression with unknowns.
|
||||
## -> The function computes `self.value` with unknowns as arguments.
|
||||
if self.symbols:
|
||||
value = self.value
|
||||
has_value = not ct.FlowSignal.check(value)
|
||||
|
||||
output_sym = self.output_sym
|
||||
if output_sym is not None and has_value:
|
||||
if self.output_sym is not None:
|
||||
match self.active_kind:
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func if (
|
||||
self.sorted_symbols and not ct.FlowSignal.check(self.value)
|
||||
):
|
||||
return ct.FuncFlow(
|
||||
func=sp.lambdify(
|
||||
self.sorted_sp_symbols,
|
||||
output_sym.conform(value, strip_unit=True),
|
||||
self.output_sym.conform(self.value, strip_unit=True),
|
||||
'jax',
|
||||
),
|
||||
func_args=list(self.sorted_symbols),
|
||||
func_output=self.output_sym,
|
||||
supports_jax=True,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
# Constant
|
||||
## -> When a `self.value` has no unknowns, use a dummy function.
|
||||
## -> ("Dummy" as in returns the same argument that it takes).
|
||||
## -> This is an excuse to let `ParamsFlow` pass `self.value` verbatim.
|
||||
## -> Generally only useful for operations with other expressions.
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func if not self.sorted_symbols:
|
||||
return ct.FuncFlow(
|
||||
func=lambda v: v,
|
||||
func_args=[self.output_sym],
|
||||
func_output=self.output_sym,
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'sorted_symbols'})
|
||||
def is_differentiable(self) -> bool:
|
||||
"""Whether all symbols are differentiable.
|
||||
case ct.FlowKind.Range if self.sorted_symbols:
|
||||
msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
If there are no symbols, then there is nothing to differentiate, and thus the expression is differentiable.
|
||||
"""
|
||||
if not self.sorted_symbols:
|
||||
return True
|
||||
|
||||
return all(
|
||||
sym.mathtype in [spux.MathType.Real, spux.MathType.Complex]
|
||||
for sym in self.sorted_symbols
|
||||
case ct.FlowKind.Range if (
|
||||
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
|
||||
):
|
||||
return ct.FuncFlow(
|
||||
func=lambda v: v,
|
||||
func_args=[self.output_sym],
|
||||
func_output=self.output_sym,
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'sorted_symbols', 'output_sym', 'value'})
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
@bl_cache.cached_bl_property(
|
||||
depends_on={'sorted_symbols', 'output_sym', 'value', 'lazy_range'}
|
||||
)
|
||||
def params(self) -> ct.ParamsFlow:
|
||||
"""Returns parameter symbols/values to accompany `self.lazy_func`.
|
||||
|
||||
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use).
|
||||
Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`.
|
||||
"""
|
||||
# Symbolic
|
||||
## -> The Expr socket does not declare actual values for the symbols.
|
||||
## -> They should be realized later, ex. in a Viz node.
|
||||
## -> Therefore, we just dump the symbols. Easy!
|
||||
## -> NOTE: func_args must have the same symbol order as was lambdified.
|
||||
if self.sorted_symbols:
|
||||
output_sym = self.output_sym
|
||||
if output_sym is not None:
|
||||
match self.active_kind:
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
|
||||
return ct.ParamsFlow(
|
||||
arg_targets=list(self.sorted_symbols),
|
||||
func_args=[sym.sp_symbol for sym in self.sorted_symbols],
|
||||
symbols=self.sorted_symbols,
|
||||
is_differentiable=self.is_differentiable,
|
||||
symbols=set(self.sorted_symbols),
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
# Constant
|
||||
## -> Simply pass self.value verbatim as a function argument.
|
||||
## -> Easy dice, easy life!
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func if (
|
||||
not self.sorted_symbols and not ct.FlowSignal.check(self.value)
|
||||
):
|
||||
return ct.ParamsFlow(
|
||||
arg_targets=[self.output_sym],
|
||||
func_args=[self.value],
|
||||
is_differentiable=self.is_differentiable,
|
||||
)
|
||||
|
||||
case ct.FlowKind.Range if self.sorted_symbols:
|
||||
msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
case ct.FlowKind.Range if (
|
||||
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
|
||||
):
|
||||
return ct.ParamsFlow(
|
||||
arg_targets=[self.output_sym],
|
||||
func_args=[self.output_sym.sp_symbol_matsym],
|
||||
symbols={self.output_sym},
|
||||
).realize_partial({self.output_sym: self.lazy_range})
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'sorted_symbols', 'output_sym'})
|
||||
def info(self) -> ct.InfoFlow:
|
||||
r"""Returns parameter symbols/values to accompany `self.lazy_func`.
|
||||
|
@ -858,22 +899,34 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
|
||||
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
|
||||
"""
|
||||
# Constant
|
||||
## -> The input SimSymbols become continuous dimensional indices.
|
||||
## -> All domain validity information is defined on the SimSymbol keys.
|
||||
if self.sorted_symbols:
|
||||
output_sym = self.output_sym
|
||||
if output_sym is not None:
|
||||
match self.active_kind:
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
|
||||
return ct.InfoFlow(
|
||||
dims={sym: None for sym in self.sorted_symbols},
|
||||
output=self.output_sym,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
# Constant
|
||||
## -> We only need the output symbol to describe the raw data.
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func if (
|
||||
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
|
||||
):
|
||||
return ct.InfoFlow(output=self.output_sym)
|
||||
|
||||
case ct.FlowKind.Range if self.sorted_symbols:
|
||||
msg = 'InfoFlow support not yet implemented for when self.sorted_symbols is not empty'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
case ct.FlowKind.Range if (
|
||||
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
|
||||
):
|
||||
return ct.InfoFlow(
|
||||
dims={self.output_sym: self.lazy_range},
|
||||
output=self.output_sym.update(rows=1),
|
||||
)
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind: Capabilities
|
||||
####################
|
||||
|
@ -1039,6 +1092,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
|
||||
However, `draw_value` may also be called by the `draw_*` methods of other `FlowKinds`, who may choose to layer more flexibility around this base UI.
|
||||
"""
|
||||
if self.use_value_range_swapper:
|
||||
col.prop(self, self.blfields['selected_value_range'], text='')
|
||||
|
||||
if self.symbols:
|
||||
col.prop(self, self.blfields['raw_value_spstr'], text='')
|
||||
|
||||
|
@ -1097,6 +1153,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
If `self.steps == 0`, then the `Range` is considered to have a to-be-determined number of steps.
|
||||
As such, `self.steps` won't be exposed in the UI.
|
||||
"""
|
||||
if self.use_value_range_swapper:
|
||||
col.prop(self, self.blfields['selected_value_range'], text='')
|
||||
|
||||
if self.symbols:
|
||||
col.prop(self, self.blfields['raw_min_spstr'], text='')
|
||||
col.prop(self, self.blfields['raw_max_spstr'], text='')
|
||||
|
@ -1198,13 +1257,11 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
# - Socket Configuration
|
||||
####################
|
||||
class ExprSocketDef(base.SocketDef):
|
||||
"""Interface for defining an `ExprSocket`."""
|
||||
|
||||
socket_type: ct.SocketType = ct.SocketType.Expr
|
||||
active_kind: typ.Literal[
|
||||
ct.FlowKind.Value,
|
||||
ct.FlowKind.Range,
|
||||
ct.FlowKind.Func,
|
||||
] = ct.FlowKind.Value
|
||||
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName.Expr
|
||||
use_value_range_swapper: bool = False
|
||||
|
||||
# Socket Interface
|
||||
size: spux.NumberSize1D = spux.NumberSize1D.Scalar
|
||||
|
@ -1458,7 +1515,7 @@ class ExprSocketDef(base.SocketDef):
|
|||
# Check ActiveKind and Size
|
||||
## -> NOTE: This doesn't protect against dynamic changes to either.
|
||||
if (
|
||||
self.active_kind == ct.FlowKind.Range
|
||||
self.active_kind is ct.FlowKind.Range
|
||||
and self.size is not spux.NumberSize1D.Scalar
|
||||
):
|
||||
msg = "Can't have a non-Scalar size when Range is set as the active kind."
|
||||
|
@ -1504,9 +1561,9 @@ class ExprSocketDef(base.SocketDef):
|
|||
# - Initialization
|
||||
####################
|
||||
def init(self, bl_socket: ExprBLSocket) -> None:
|
||||
bl_socket.active_kind = self.active_kind
|
||||
bl_socket.output_name = self.output_name
|
||||
bl_socket.use_linked_capabilities = True
|
||||
bl_socket.use_value_range_swapper = self.use_value_range_swapper
|
||||
|
||||
# Socket Interface
|
||||
## -> Recall that auto-updates are turned off during init()
|
||||
|
@ -1543,6 +1600,25 @@ class ExprSocketDef(base.SocketDef):
|
|||
# Info Draw
|
||||
bl_socket.use_info_draw = True
|
||||
|
||||
def local_compare(self, bl_socket: ExprBLSocket) -> None:
|
||||
"""Determine whether an updateable socket should be re-initialized from this `SocketDef`."""
|
||||
|
||||
def cmp(attr: str):
|
||||
return getattr(bl_socket, attr) == getattr(self, attr)
|
||||
|
||||
return (
|
||||
bl_socket.use_linked_capabilities
|
||||
and cmp('output_name')
|
||||
and cmp('use_value_range_swapper')
|
||||
and cmp('size')
|
||||
and cmp('mathtype')
|
||||
and cmp('physical_type')
|
||||
and cmp('show_func_ui')
|
||||
and cmp('show_info_columns')
|
||||
and cmp('show_name_selector')
|
||||
and bl_socket.use_info_draw
|
||||
)
|
||||
|
||||
|
||||
####################
|
||||
# - Blender Registration
|
||||
|
|
|
@ -117,6 +117,16 @@ class MaxwellBoundCondsBLSocket(base.MaxwellSimSocket):
|
|||
),
|
||||
)
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'value'})
|
||||
def lazy_func(self) -> ct.FuncFlow:
|
||||
return ct.FuncFlow(
|
||||
func=lambda: self.value,
|
||||
)
|
||||
|
||||
@bl_cache.cached_bl_property()
|
||||
def params(self) -> ct.FuncFlow:
|
||||
return ct.ParamsFlow()
|
||||
|
||||
|
||||
####################
|
||||
# - Socket Configuration
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import typing as typ
|
||||
|
||||
from ... import contracts as ct
|
||||
from .. import base
|
||||
|
||||
|
@ -32,6 +34,9 @@ class MaxwellFDTDSimSocketDef(base.SocketDef):
|
|||
def init(self, bl_socket: MaxwellFDTDSimBLSocket) -> None:
|
||||
pass
|
||||
|
||||
def local_compare(self, _: MaxwellFDTDSimBLSocket) -> None:
|
||||
return True
|
||||
|
||||
|
||||
####################
|
||||
# - Blender Registration
|
||||
|
|
|
@ -73,7 +73,7 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket):
|
|||
def value(self, eps_rel: tuple[float, float]) -> None:
|
||||
self.eps_rel = eps_rel
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'value', 'differentiable'})
|
||||
@bl_cache.cached_bl_property(depends_on={'value'})
|
||||
def lazy_func(self) -> ct.FuncFlow:
|
||||
return ct.FuncFlow(
|
||||
func=lambda: self.value,
|
||||
|
@ -82,7 +82,7 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket):
|
|||
|
||||
@bl_cache.cached_bl_property(depends_on={'differentiable'})
|
||||
def params(self) -> ct.FuncFlow:
|
||||
return ct.ParamsFlow(is_differentiable=self.differentiable)
|
||||
return ct.ParamsFlow()
|
||||
|
||||
####################
|
||||
# - UI
|
||||
|
|
|
@ -29,11 +29,11 @@ class MaxwellMonitorBLSocket(base.MaxwellSimSocket):
|
|||
class MaxwellMonitorSocketDef(base.SocketDef):
|
||||
socket_type: ct.SocketType = ct.SocketType.MaxwellMonitor
|
||||
|
||||
is_list: bool = False
|
||||
|
||||
def init(self, bl_socket: MaxwellMonitorBLSocket) -> None:
|
||||
if self.is_list:
|
||||
bl_socket.active_kind = ct.FlowKind.Array
|
||||
pass
|
||||
|
||||
def local_compare(self, _: MaxwellMonitorBLSocket) -> None:
|
||||
return True
|
||||
|
||||
|
||||
####################
|
||||
|
|
|
@ -55,6 +55,16 @@ class MaxwellSimGridBLSocket(base.MaxwellSimSocket):
|
|||
min_steps_per_wvl=self.min_steps_per_wl,
|
||||
)
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'value'})
|
||||
def lazy_func(self) -> ct.FuncFlow:
|
||||
return ct.FuncFlow(
|
||||
func=lambda: self.value,
|
||||
)
|
||||
|
||||
@bl_cache.cached_bl_property()
|
||||
def params(self) -> ct.FuncFlow:
|
||||
return ct.ParamsFlow()
|
||||
|
||||
|
||||
####################
|
||||
# - Socket Configuration
|
||||
|
|
|
@ -29,11 +29,11 @@ class MaxwellSourceBLSocket(base.MaxwellSimSocket):
|
|||
class MaxwellSourceSocketDef(base.SocketDef):
|
||||
socket_type: ct.SocketType = ct.SocketType.MaxwellSource
|
||||
|
||||
is_list: bool = False
|
||||
|
||||
def init(self, bl_socket: MaxwellSourceBLSocket) -> None:
|
||||
if self.is_list:
|
||||
bl_socket.active_kind = ct.FlowKind.Array
|
||||
pass
|
||||
|
||||
def local_compare(self, _: MaxwellSourceBLSocket) -> None:
|
||||
return True
|
||||
|
||||
|
||||
####################
|
||||
|
|
|
@ -29,11 +29,11 @@ class MaxwellStructureBLSocket(base.MaxwellSimSocket):
|
|||
class MaxwellStructureSocketDef(base.SocketDef):
|
||||
socket_type: ct.SocketType = ct.SocketType.MaxwellStructure
|
||||
|
||||
is_list: bool = False
|
||||
|
||||
def init(self, bl_socket: MaxwellStructureBLSocket) -> None:
|
||||
if self.is_list:
|
||||
bl_socket.active_kind = ct.FlowKind.Array
|
||||
pass
|
||||
|
||||
def local_compare(self, _: MaxwellStructureBLSocket) -> None:
|
||||
return True
|
||||
|
||||
|
||||
####################
|
||||
|
|
|
@ -16,10 +16,10 @@
|
|||
|
||||
"""Package providing various tools to handle cached data on Blender objects, especially nodes and node socket classes."""
|
||||
|
||||
from ..keyed_cache import KeyedCache, keyed_cache
|
||||
from .bl_field import BLField
|
||||
from .bl_prop import BLProp, BLPropType
|
||||
from .cached_bl_property import CachedBLProperty, cached_bl_property
|
||||
from .keyed_cache import KeyedCache, keyed_cache
|
||||
from .managed_cache import invalidate_nonpersist_instance_id
|
||||
from .signal import Signal
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ from types import MappingProxyType
|
|||
import bpy
|
||||
|
||||
from blender_maxwell.utils import bl_cache, logger
|
||||
from blender_maxwell.utils.keyed_cache import keyed_cache
|
||||
|
||||
InstanceID: typ.TypeAlias = str ## Stringified UUID4
|
||||
|
||||
|
@ -220,11 +221,14 @@ class BLInstance:
|
|||
for str_search_prop_name in self.blfields_str_search:
|
||||
setattr(self, str_search_prop_name, bl_cache.Signal.ResetStrSearch)
|
||||
|
||||
@keyed_cache(
|
||||
exclude={'self'}, ## No dynamic elements of 'self' can be used.
|
||||
)
|
||||
def trace_blfields_to_clear(
|
||||
self,
|
||||
prop_name: str,
|
||||
prev_blfields_to_clear: list[
|
||||
tuple[str, typ.Literal['invalidate', 'reset_enum', 'reset_strsearch']]
|
||||
prev_blfields_to_clear: tuple[
|
||||
tuple[str, typ.Literal['invalidate', 'reset_enum', 'reset_strsearch']], ...
|
||||
] = (),
|
||||
) -> list[str]:
|
||||
"""Invalidates all properties that depend on `prop_name`.
|
||||
|
@ -239,7 +243,7 @@ class BLInstance:
|
|||
All of these are filled when creating the `BLInstance` subclass, using `self.declare_blfield_dep()`, generally via the `BLField` descriptor (which internally uses `BLProp`).
|
||||
"""
|
||||
if prev_blfields_to_clear:
|
||||
blfields_to_clear = prev_blfields_to_clear.copy()
|
||||
blfields_to_clear = list(prev_blfields_to_clear)
|
||||
else:
|
||||
blfields_to_clear = []
|
||||
|
||||
|
@ -268,7 +272,7 @@ class BLInstance:
|
|||
if dst_prop_name in self.blfields:
|
||||
blfields_to_clear += self.trace_blfields_to_clear(
|
||||
dst_prop_name,
|
||||
prev_blfields_to_clear=blfields_to_clear,
|
||||
prev_blfields_to_clear=tuple(blfields_to_clear),
|
||||
)
|
||||
|
||||
match (bool(prev_blfields_to_clear), bool(blfields_to_clear)):
|
||||
|
@ -297,7 +301,7 @@ class BLInstance:
|
|||
## -> As such, deduplication would not be wrong, just extraneous.
|
||||
## -> Since invalidation is in a hot-loop, don't do such things.
|
||||
case (True, True):
|
||||
return blfields_to_clear
|
||||
return list(reversed(dict.fromkeys(reversed(blfields_to_clear))))
|
||||
|
||||
def clear_blfields_after(self, prop_name: str) -> list[str]:
|
||||
"""Clear (invalidate) all `BLField`s that have become invalid as a result of a change to `prop_name`.
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
"""Useful image processing operations for use in the addon."""
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import typing as typ
|
||||
|
||||
import jax
|
||||
|
@ -26,7 +27,6 @@ import matplotlib
|
|||
import matplotlib.axis as mpl_ax
|
||||
import matplotlib.backends.backend_agg
|
||||
import matplotlib.figure
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
|
||||
from blender_maxwell import contracts as ct
|
||||
|
@ -138,7 +138,7 @@ def rgba_image_from_2d_map(
|
|||
####################
|
||||
# - MPL Helpers
|
||||
####################
|
||||
# @functools.lru_cache(maxsize=16)
|
||||
@functools.lru_cache(maxsize=4)
|
||||
def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
|
||||
fig = matplotlib.figure.Figure(
|
||||
figsize=[width_inches, height_inches], dpi=dpi, layout='tight'
|
||||
|
|
|
@ -18,11 +18,15 @@ import functools
|
|||
import inspect
|
||||
import typing as typ
|
||||
|
||||
from blender_maxwell.utils import bl_instance, logger, serialize
|
||||
from blender_maxwell.utils import logger, serialize
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
||||
class BLInstance(typ.Protocol):
|
||||
instance_id: str
|
||||
|
||||
|
||||
class KeyedCache:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -75,8 +79,8 @@ class KeyedCache:
|
|||
|
||||
def __get__(
|
||||
self,
|
||||
bl_instance: bl_instance.BLInstance | None,
|
||||
owner: type[bl_instance.BLInstance],
|
||||
bl_instance: BLInstance | None,
|
||||
owner: type[BLInstance],
|
||||
) -> typ.Callable:
|
||||
_func = functools.partial(self, bl_instance)
|
||||
_func.invalidate = functools.partial(
|
||||
|
@ -110,7 +114,7 @@ class KeyedCache:
|
|||
|
||||
def invalidate(
|
||||
self,
|
||||
bl_instance: bl_instance.BLInstance | None,
|
||||
bl_instance: BLInstance | None,
|
||||
**arguments: dict[str, typ.Any],
|
||||
) -> dict[str, typ.Any]:
|
||||
# Determine Wildcard Arguments
|
|
@ -264,13 +264,16 @@ class SimSymbol(pyd.BaseModel):
|
|||
interval_closed_im: tuple[bool, bool] = (False, False)
|
||||
|
||||
####################
|
||||
# - Labels
|
||||
# - Core
|
||||
####################
|
||||
@functools.cached_property
|
||||
def name(self) -> str:
|
||||
"""Usable name for the symbol."""
|
||||
return self.sym_name.name
|
||||
|
||||
####################
|
||||
# - Labels
|
||||
####################
|
||||
@functools.cached_property
|
||||
def name_pretty(self) -> str:
|
||||
"""Pretty (possibly unicode) name for the thing."""
|
||||
|
@ -307,6 +310,8 @@ class SimSymbol(pyd.BaseModel):
|
|||
@functools.cached_property
|
||||
def plot_label(self) -> str:
|
||||
"""Pretty plot-oriented label."""
|
||||
if self.unit is None:
|
||||
return self.name_pretty
|
||||
return f'{self.name_pretty} ({self.unit_label})'
|
||||
|
||||
####################
|
||||
|
@ -420,6 +425,11 @@ class SimSymbol(pyd.BaseModel):
|
|||
|
||||
@functools.cached_property
|
||||
def is_nonzero(self) -> bool:
|
||||
"""Whether or not the value of this symbol can ever be $0$.
|
||||
|
||||
Notes:
|
||||
Most notably, this symbol cannot be used as the right hand side of a division operation when this property is `False`.
|
||||
"""
|
||||
if self.exclude_zero:
|
||||
return True
|
||||
|
||||
|
@ -441,6 +451,18 @@ class SimSymbol(pyd.BaseModel):
|
|||
)
|
||||
return check_real_domain(self.domain)
|
||||
|
||||
@functools.cached_property
|
||||
def can_diff(self) -> bool:
|
||||
"""Whether this symbol can be used as the input / output variable when differentiating."""
|
||||
# Check Constants
|
||||
## -> Constants (w/pinned values) are never differentiable.
|
||||
if self.is_constant:
|
||||
return False
|
||||
|
||||
# TODO: Discontinuities (especially across 0)?
|
||||
|
||||
return self.mathtype in [spux.MathType.Real, spux.MathType.Complex]
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
|
@ -664,8 +686,10 @@ class SimSymbol(pyd.BaseModel):
|
|||
res = spux.strip_unit_system(sp_obj)
|
||||
|
||||
# Broadcast Expansion
|
||||
if self.rows > 1 or self.cols > 1 and not isinstance(res, spux.MatrixBase):
|
||||
res = sp_obj * sp.ImmutableMatrix.ones(self.rows, self.cols)
|
||||
if (self.rows > 1 or self.cols > 1) and not isinstance(
|
||||
res, sp.MatrixBase | sp.MatrixSymbol
|
||||
):
|
||||
res = res * sp.ImmutableMatrix.ones(self.rows, self.cols)
|
||||
|
||||
return res
|
||||
|
||||
|
@ -753,7 +777,9 @@ class SimSymbol(pyd.BaseModel):
|
|||
unit = None
|
||||
|
||||
# Rows/Cols from Expr (if Matrix)
|
||||
rows, cols = expr.shape if isinstance(expr, sp.MatrixBase) else (1, 1)
|
||||
rows, cols = (
|
||||
expr.shape if isinstance(expr, sp.MatrixBase | sp.MatrixSymbol) else (1, 1)
|
||||
)
|
||||
|
||||
return SimSymbol(
|
||||
sym_name=sym_name,
|
||||
|
|
Loading…
Reference in New Issue