feat: various sym-flow modifications
parent
830b316e01
commit
38e70a60d3
|
@ -14,12 +14,12 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
import functools
|
import functools
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
import jaxtyping as jtyp
|
import jaxtyping as jtyp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pydantic as pyd
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
|
|
||||||
from blender_maxwell.utils import extra_sympy_units as spux
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
|
@ -29,8 +29,7 @@ log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Our handling of 'is_sorted' is sloppy and probably wrong.
|
# TODO: Our handling of 'is_sorted' is sloppy and probably wrong.
|
||||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
class ArrayFlow(pyd.BaseModel):
|
||||||
class ArrayFlow:
|
|
||||||
"""A homogeneous, realized array of numerical values with an optionally-attached unit and sort-tracking.
|
"""A homogeneous, realized array of numerical values with an optionally-attached unit and sort-tracking.
|
||||||
|
|
||||||
While the principle is simple, arrays-with-units ends up being a powerful basis for derived and computed features/methods/processing.
|
While the principle is simple, arrays-with-units ends up being a powerful basis for derived and computed features/methods/processing.
|
||||||
|
@ -41,7 +40,9 @@ class ArrayFlow:
|
||||||
None if unitless.
|
None if unitless.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
values: jtyp.Shaped[jtyp.Array, '...']
|
model_config = pyd.ConfigDict(frozen=True, arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
values: jtyp.Inexact[jtyp.Array, '...'] ## TODO: Custom field type
|
||||||
unit: spux.Unit | None = None
|
unit: spux.Unit | None = None
|
||||||
|
|
||||||
is_sorted: bool = False
|
is_sorted: bool = False
|
||||||
|
|
|
@ -18,6 +18,7 @@ import enum
|
||||||
import functools
|
import functools
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
|
from blender_maxwell.contracts import BLEnumElement
|
||||||
from blender_maxwell.utils import extra_sympy_units as spux
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
from blender_maxwell.utils import logger
|
from blender_maxwell.utils import logger
|
||||||
from blender_maxwell.utils.staticproperty import staticproperty
|
from blender_maxwell.utils.staticproperty import staticproperty
|
||||||
|
@ -99,6 +100,17 @@ class FlowKind(enum.StrEnum):
|
||||||
def to_icon(_: typ.Self) -> str:
|
def to_icon(_: typ.Self) -> str:
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return FlowKind.to_name(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def icon(self) -> str:
|
||||||
|
return FlowKind.to_icon(self)
|
||||||
|
|
||||||
|
def bl_enum_element(self, i) -> BLEnumElement:
|
||||||
|
return (str(self), self.name, self.name, self.icon, i)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Static Properties
|
# - Static Properties
|
||||||
####################
|
####################
|
||||||
|
@ -162,7 +174,7 @@ class FlowKind(enum.StrEnum):
|
||||||
def socket_shape(self) -> str:
|
def socket_shape(self) -> str:
|
||||||
"""Return the socket shape associated with this `FlowKind`.
|
"""Return the socket shape associated with this `FlowKind`.
|
||||||
|
|
||||||
**ONLY** valid for `FlowKind`s that can be considered "active".
|
Should generally only be used with `active_kinds`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If this `FlowKind` cannot ever be considered "active".
|
ValueError: If this `FlowKind` cannot ever be considered "active".
|
||||||
|
@ -172,7 +184,7 @@ class FlowKind(enum.StrEnum):
|
||||||
FlowKind.Array: 'SQUARE',
|
FlowKind.Array: 'SQUARE',
|
||||||
FlowKind.Range: 'SQUARE',
|
FlowKind.Range: 'SQUARE',
|
||||||
FlowKind.Func: 'DIAMOND',
|
FlowKind.Func: 'DIAMOND',
|
||||||
}[self]
|
}.get(self, 'CIRCLE')
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Class Methods
|
# - Class Methods
|
||||||
|
|
|
@ -14,13 +14,217 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import dataclasses
|
r"""Implements the core of the math system via `FuncFlow`, which allows high-performance, fully-expressive workflows with data that can be "very large", and/or whose input parameters are not yet fully known.
|
||||||
|
|
||||||
|
# Introduction
|
||||||
|
When using nodes to do math, it becomes immediately obvious to express **flows of data as composed function chains**.
|
||||||
|
Doing so has several advantages:
|
||||||
|
|
||||||
|
- **Interactive**: Since no large-array math is being done, the UI can be designed to feel fast and snappy, greatly boosting the will to experiment and ultimate productivity.
|
||||||
|
- **Symbolic**: Since no numerical math is being done yet, we can inject symbolic variables at-will, enabling effortless ex. parameter sweeping, band-structure generation, differentiable can choose to keep our input parameters as symbolic variables with no performance impact.
|
||||||
|
- **Performant**: Since the data pipeline is built as a single function w/o side effects, that function can be often be JIT-optimized for highly-optimized execution on the instruction sets used by modern massively-parallel devices, like modern CPUs (SSE/AVX), GPUs (ex. PTX), and HPC clusters (network-sharding to ARM/x86).
|
||||||
|
|
||||||
|
The result is a math system optimized for the analysis typically needed in electrodynamic contexts, prioritizing clarity and flexibility at soft-real-time, even with gigabytes of data on relatively weak hardware.
|
||||||
|
|
||||||
|
## Strongly Related FlowKinds
|
||||||
|
For doing math, `Func` relies on two other `FlowKind`s, which must run in parallel:
|
||||||
|
|
||||||
|
- `FlowKind.Info`: Tracks the name, `spux.MathType`, unit (if any), length, and index coordinates for the raw data object produced by `Func`.
|
||||||
|
- `FlowKind.Params`: Tracks the particular values of input parameters to the lazy function, each of which can also be symbolic.
|
||||||
|
|
||||||
|
For more, please see the documentation for each.
|
||||||
|
|
||||||
|
## Non-Mathematical Use
|
||||||
|
Of course, there are many interesting uses of incremental function composition that aren't mathematical.
|
||||||
|
|
||||||
|
For such cases, the usage is identical, but the complexity is lessened; for example, `Info` no longer effectively needs to flow in parallel.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Lazy Math: Theoretical Foundation
|
||||||
|
This `FlowKind` is the critical component of a functional-inspired system for lazy multilinear math.
|
||||||
|
Thus, it makes sense to describe the math system here.
|
||||||
|
|
||||||
|
## `depth=0`: Root Function
|
||||||
|
To start a composition chain, a function with no inputs must be defined as the "root", or "bottom".
|
||||||
|
|
||||||
|
$$
|
||||||
|
f_0:\ \ \ \ \biggl(
|
||||||
|
\underbrace{a_1, a_2, ..., a_p}_{\texttt{args}},\
|
||||||
|
\underbrace{
|
||||||
|
\begin{bmatrix} k_1 \\ v_1\end{bmatrix},
|
||||||
|
\begin{bmatrix} k_2 \\ v_2\end{bmatrix},
|
||||||
|
...,
|
||||||
|
\begin{bmatrix} k_q \\ v_q\end{bmatrix}
|
||||||
|
}_{\texttt{kwargs}}
|
||||||
|
\biggr) \to \text{output}_0
|
||||||
|
$$
|
||||||
|
|
||||||
|
In Python, such a construction would look like this:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Presume 'A0', 'KV0' contain only the args/kwargs for f_0
|
||||||
|
## 'A0', 'KV0' are of length 'p' and 'q'
|
||||||
|
def f_0(*args, **kwargs): ...
|
||||||
|
|
||||||
|
lazy_func_0 = FuncFlow(
|
||||||
|
func=f_0,
|
||||||
|
func_args=[(a_i, type(a_i)) for a_i in A0],
|
||||||
|
func_kwargs={k: v for k,v in KV0},
|
||||||
|
)
|
||||||
|
output_0 = lazy_func.func(*A0_computed, **KV0_computed)
|
||||||
|
```
|
||||||
|
|
||||||
|
## `depth>0`: Composition Chaining
|
||||||
|
So far, so easy.
|
||||||
|
Now, let's add a function that uses the result of $f_0$, without yet computing it.
|
||||||
|
|
||||||
|
$$
|
||||||
|
f_1:\ \ \ \ \biggl(
|
||||||
|
f_0(...),\ \
|
||||||
|
\underbrace{\{a_i\}_p^{p+r}}_{\texttt{args[p:]}},\
|
||||||
|
\underbrace{\biggl\{
|
||||||
|
\begin{bmatrix} k_i \\ v_i\end{bmatrix}
|
||||||
|
\biggr\}_q^{q+s}}_{\texttt{kwargs[p:]}}
|
||||||
|
\biggr) \to \text{output}_1
|
||||||
|
$$
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- $f_1$ must take the arguments of both $f_0$ and $f_1$.
|
||||||
|
- The complexity is getting notationally complex; we already have to use `...` to represent "the last function's arguments".
|
||||||
|
|
||||||
|
In other words, **there's suddenly a lot to manage**.
|
||||||
|
Even worse, the bigger the $n$, the more complexity we must real with.
|
||||||
|
|
||||||
|
This is where the Python version starts to show its purpose:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Presume 'A1', 'K1' contain only the args/kwarg names for f_1
|
||||||
|
## 'A1', 'KV1' are therefore of length 'r' and 's'
|
||||||
|
def f_1(output_0, *args, **kwargs): ...
|
||||||
|
|
||||||
|
lazy_func_1 = lazy_func_0.compose_within(
|
||||||
|
enclosing_func=f_1,
|
||||||
|
enclosing_func_args=[(a_i, type(a_i)) for a_i in A1],
|
||||||
|
enclosing_func_kwargs={k: type(v) for k,v in K1},
|
||||||
|
)
|
||||||
|
|
||||||
|
A_computed = A0_computed + A1_computed
|
||||||
|
KW_computed = KV0_computed + KV1_computed
|
||||||
|
output_1 = lazy_func_1.func(*A_computed, **KW_computed)
|
||||||
|
```
|
||||||
|
|
||||||
|
By using `Func`, we've guaranteed that even hugely deep $n$s won't ever look more complicated than this.
|
||||||
|
|
||||||
|
## `max depth`: "Realization"
|
||||||
|
So, we've composed a bunch of functions of functions of ...
|
||||||
|
We've also tracked their arguments, either manually (as above), or with the help of a handy `ParamsFlow` object.
|
||||||
|
|
||||||
|
But it'd be pointless to just compose away forever.
|
||||||
|
We do actually need the data that they claim to compute now:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# A_all and KW_all must be tracked on the side.
|
||||||
|
output_n = lazy_func_n.func(*A_all, **KW_all)
|
||||||
|
```
|
||||||
|
|
||||||
|
Of course, this comes with enormous overhead.
|
||||||
|
Aside from the function calls themselves (which can be non-trivial), we must also contend with the enormous inefficiency of performing array operations sequentially.
|
||||||
|
|
||||||
|
That brings us to the killer feature of `FuncFlow`, and the motivating reason for doing any of this at all:
|
||||||
|
|
||||||
|
```python
|
||||||
|
output_n = lazy_func_n.func_jax(*A_all, **KW_all)
|
||||||
|
```
|
||||||
|
|
||||||
|
What happened was, **the entire pipeline** was compiled, optimized, and computed with bare-metal performance on either a CPU, GPU, or TPU.
|
||||||
|
With the help of the `jax` library (and its underlying OpenXLA bytecode), all of that inefficiency has been optimized based on _what we're trying to do_, not _exactly how we're doing it_, in order to maximize the use of modern massively-parallel devices.
|
||||||
|
|
||||||
|
See the documentation of `Func.func_jax()` for more information on this process.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Lazy Math: Practical Considerations
|
||||||
|
By using nodes to express a lazily-composed chain of mathematical operations on tensor-like data, we strike a difficult balance between UX, flexibility, and performance.
|
||||||
|
|
||||||
|
## UX
|
||||||
|
UX is often more a matter of art/taste than science, so don't trust these philosophies too much - a lot of the analysis is entirely personal and subjective.
|
||||||
|
|
||||||
|
The goal for our UX is to minimize the "frictions" that cause cascading, small-scale _user anxiety_.
|
||||||
|
|
||||||
|
Especially of concern in a visual math system on large data volumes is **UX latency** - also known as **lag**.
|
||||||
|
In particular, the most important facet to minimize is _emotional burden_ rather than quantitative milliseconds.
|
||||||
|
Any repeated moment-to-moment friction can be very damaging to a user's ability to be productive in a piece of software.
|
||||||
|
|
||||||
|
Unfortunately, in a node-based architecture, data must generally be computed one step at a time, whenever any part of it is needed, and it must do so before any feedback can be provided.
|
||||||
|
In a math system like this, that data is presumed "big", and as such we're left with the unfortunate experience of even the most well-cached, high-performance operations causing _just about anything_ to **feel** like a highly unpleasant slog as soon as the data gets big enough.
|
||||||
|
**This discomfort scales with the size of data**, by the way, which might just cause users to never even attempt working with the data volume that they actually need.
|
||||||
|
|
||||||
|
For electrodynamic field analysis, it's not uncommon for toy examples to expend hundreds of megabytes of memory, all of which needs all manner of interesting things done to it.
|
||||||
|
It can therefore be very easy to stumble across that feeling of "slogging through" any program that does real-world EM field analysis.
|
||||||
|
This has consequences: The user tries fewer ideas, becomes more easily frustrated, and might ultimately accomplish less.
|
||||||
|
|
||||||
|
Lazy evaluation allows _delaying_ a computation to a point in time where the user both expects and understands the time that the computation takes.
|
||||||
|
For example, the user experience of pressing a button clearly marked with terminology like "load", "save", "compute", "run", seems to be paired to a greatly increased emotional tolerance towards the latency introduced by pressing that button (so long as it is only clickable when it works).
|
||||||
|
To a lesser degree, attaching a node link also seems to have this property, though that tolerance seems to fall as proficiency with the node-based tool rises.
|
||||||
|
As a more nuanced example, when lag occurs due to the computing an image-based plot based on live-computed math, then the visual feedback of _the plot actually changing_ seems to have a similar effect, not least because it's emotionally well-understood that detaching the `Viewer` node would also remove the lag.
|
||||||
|
|
||||||
|
In short: Even if lazy evaluation didn't make any math faster, it will still _feel_ faster (to a point - raw performance obviously still matters).
|
||||||
|
Without `FuncFlow`, the point of evaluation cannot be chosen at all, which is a huge issue for all the named reasons.
|
||||||
|
With `FuncFlow`, better-chosen evaluation points can be chosen to cause the _user experience_ of high performance, simply because we were able to shift the exact same computation to a point in time where the user either understands or tolerates the delay better.
|
||||||
|
|
||||||
|
## Flexibility
|
||||||
|
Large-scale math is done on tensors, whether one knows (or likes!) it or not.
|
||||||
|
To this end, the indexed arrays produced by `FuncFlow.func_jax` aren't quite sufficient for most operations we want to do:
|
||||||
|
|
||||||
|
- **Naming**: What _is_ each axis?
|
||||||
|
Unnamed index axes are sometimes easy to decode, but in general, names have an unexpectedly critical function when operating on arrays.
|
||||||
|
Lack of names is a huge part of why perfectly elegant array math in ex. `MATLAB` or `numpy` can so easily feel so incredibly convoluted.
|
||||||
|
_Sometimes arrays with named axes are called "structured arrays".
|
||||||
|
|
||||||
|
- **Coordinates**: What do the indices of each axis really _mean_?
|
||||||
|
For example, an array of $500$ by-wavelength observations of power (watts) can't be limited to between $200nm$ to $700nm$.
|
||||||
|
But they can be limited to between index `23` to `298`.
|
||||||
|
I'm **just supposed to know** that `23` means $200nm$, and that `298` indicates the observation just after $700nm$, and _hope_ that this is exact enough.
|
||||||
|
|
||||||
|
Not only do we endeavor to track these, but we also introduce unit-awareness to the coordinates, and design the entire math system to visually communicate the state of arrays before/after every single computation, as well as only expose operations that this tracked data indicates possible.
|
||||||
|
|
||||||
|
In practice, this happens in `FlowKind.Info`, which due to having its own `FlowKind` "lane" can be adjusted without triggering changes to (and therefore recompilation of) the `FlowKind.Func` chain.
|
||||||
|
**Please consult the `InfoFlow` documentation for more**.
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
All values introduced while processing are kept in a seperate `FlowKind` lane, with its own incremental caching: `FlowKind.Params`.
|
||||||
|
|
||||||
|
It's a simple mechanism, but for the cost of introducing an extra `FlowKind` "lane", all of the values used to process data can be live-adjusted without the overhead of recompiling the entire `Func` every time anything changes.
|
||||||
|
Moreover, values used to process data don't even have to be numbers yet: They can be expressions of symbolic variables, complete with units, which are only realized at the very end of the chain, by the node that absolutely cannot function without the actual numerical data.
|
||||||
|
|
||||||
|
See the `ParamFlow` documentation for more information.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Conclusion
|
||||||
|
There is, of course, a lot more to say about the math system in general.
|
||||||
|
A few teasers of what nodes can do with this system:
|
||||||
|
|
||||||
|
**Auto-Differentiation**: `jax.jit` isn't even really the killer feature of `jax`.
|
||||||
|
`jax` can automatically differentiate `FuncFlow.func_jax` with respect to any input parameter, including for fwd/bck jacobians/hessians, with robust numerical stability.
|
||||||
|
When used in
|
||||||
|
**Symbolic Interop**: Any `sympy` expression containing symbolic variables can be compiled, by `sympy`, into a `jax`-compatible function which takes
|
||||||
|
We make use of this in the `Expr` socket, enabling true symbolic math to be used in high-performance lazy `jax` computations.
|
||||||
|
**Tidy3D Interop**: For some parameters of some simulation objects, `tidy3d` actually supports adjoint-driven differentiation _through the cloud simulation_.
|
||||||
|
This enables our humble interface to implement fully functional **inverse design** of parameterized structures, using only nodes.
|
||||||
|
|
||||||
|
But above all, we hope that this math system is fun, practical, and maybe even interesting.
|
||||||
|
"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import typing as typ
|
import typing as typ
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jaxtyping as jtyp
|
import jaxtyping as jtyp
|
||||||
|
import pydantic as pyd
|
||||||
|
import sympy as sp
|
||||||
|
|
||||||
from blender_maxwell.utils import extra_sympy_units as spux
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
from blender_maxwell.utils import logger, sim_symbols
|
from blender_maxwell.utils import logger, sim_symbols
|
||||||
|
@ -35,210 +239,12 @@ log = logger.get(__name__)
|
||||||
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
|
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
class FuncFlow(pyd.BaseModel):
|
||||||
class FuncFlow:
|
|
||||||
r"""Defines a flow of data as incremental function composition.
|
r"""Defines a flow of data as incremental function composition.
|
||||||
|
|
||||||
|
For theoretical information, please see the documentation of this module.
|
||||||
For specific math system usage instructions, please consult the documentation of relevant nodes.
|
For specific math system usage instructions, please consult the documentation of relevant nodes.
|
||||||
|
|
||||||
# Introduction
|
|
||||||
When using nodes to do math, it becomes immediately obvious to express **flows of data as composed function chains**.
|
|
||||||
Doing so has several advantages:
|
|
||||||
|
|
||||||
- **Interactive**: Since no large-array math is being done, the UI can be designed to feel fast and snappy.
|
|
||||||
- **Symbolic**: Since no numerical math is being done yet, we can choose to keep our input parameters as symbolic variables with no performance impact.
|
|
||||||
- **Performant**: Since no operations are happening, the UI feels fast and snappy.
|
|
||||||
|
|
||||||
## Strongly Related FlowKinds
|
|
||||||
For doing math, `Func` relies on two other `FlowKind`s, which must run in parallel:
|
|
||||||
|
|
||||||
- `FlowKind.Info`: Tracks the name, `spux.MathType`, unit (if any), length, and index coordinates for the raw data object produced by `Func`.
|
|
||||||
- `FlowKind.Params`: Tracks the particular values of input parameters to the lazy function, each of which can also be symbolic.
|
|
||||||
|
|
||||||
For more, please see the documentation for each.
|
|
||||||
|
|
||||||
## Non-Mathematical Use
|
|
||||||
Of course, there are many interesting uses of incremental function composition that aren't mathematical.
|
|
||||||
|
|
||||||
For such cases, the usage is identical, but the complexity is lessened; for example, `Info` no longer effectively needs to flow in parallel.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Lazy Math: Theoretical Foundation
|
|
||||||
This `FlowKind` is the critical component of a functional-inspired system for lazy multilinear math.
|
|
||||||
Thus, it makes sense to describe the math system here.
|
|
||||||
|
|
||||||
## `depth=0`: Root Function
|
|
||||||
To start a composition chain, a function with no inputs must be defined as the "root", or "bottom".
|
|
||||||
|
|
||||||
$$
|
|
||||||
f_0:\ \ \ \ \biggl(
|
|
||||||
\underbrace{a_1, a_2, ..., a_p}_{\texttt{args}},\
|
|
||||||
\underbrace{
|
|
||||||
\begin{bmatrix} k_1 \\ v_1\end{bmatrix},
|
|
||||||
\begin{bmatrix} k_2 \\ v_2\end{bmatrix},
|
|
||||||
...,
|
|
||||||
\begin{bmatrix} k_q \\ v_q\end{bmatrix}
|
|
||||||
}_{\texttt{kwargs}}
|
|
||||||
\biggr) \to \text{output}_0
|
|
||||||
$$
|
|
||||||
|
|
||||||
In Python, such a construction would look like this:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Presume 'A0', 'KV0' contain only the args/kwargs for f_0
|
|
||||||
## 'A0', 'KV0' are of length 'p' and 'q'
|
|
||||||
def f_0(*args, **kwargs): ...
|
|
||||||
|
|
||||||
lazy_func_0 = FuncFlow(
|
|
||||||
func=f_0,
|
|
||||||
func_args=[(a_i, type(a_i)) for a_i in A0],
|
|
||||||
func_kwargs={k: v for k,v in KV0},
|
|
||||||
)
|
|
||||||
output_0 = lazy_func.func(*A0_computed, **KV0_computed)
|
|
||||||
```
|
|
||||||
|
|
||||||
## `depth>0`: Composition Chaining
|
|
||||||
So far, so easy.
|
|
||||||
Now, let's add a function that uses the result of $f_0$, without yet computing it.
|
|
||||||
|
|
||||||
$$
|
|
||||||
f_1:\ \ \ \ \biggl(
|
|
||||||
f_0(...),\ \
|
|
||||||
\underbrace{\{a_i\}_p^{p+r}}_{\texttt{args[p:]}},\
|
|
||||||
\underbrace{\biggl\{
|
|
||||||
\begin{bmatrix} k_i \\ v_i\end{bmatrix}
|
|
||||||
\biggr\}_q^{q+s}}_{\texttt{kwargs[p:]}}
|
|
||||||
\biggr) \to \text{output}_1
|
|
||||||
$$
|
|
||||||
|
|
||||||
Note:
|
|
||||||
- $f_1$ must take the arguments of both $f_0$ and $f_1$.
|
|
||||||
- The complexity is getting notationally complex; we already have to use `...` to represent "the last function's arguments".
|
|
||||||
|
|
||||||
In other words, **there's suddenly a lot to manage**.
|
|
||||||
Even worse, the bigger the $n$, the more complexity we must real with.
|
|
||||||
|
|
||||||
This is where the Python version starts to show its purpose:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Presume 'A1', 'K1' contain only the args/kwarg names for f_1
|
|
||||||
## 'A1', 'KV1' are therefore of length 'r' and 's'
|
|
||||||
def f_1(output_0, *args, **kwargs): ...
|
|
||||||
|
|
||||||
lazy_func_1 = lazy_func_0.compose_within(
|
|
||||||
enclosing_func=f_1,
|
|
||||||
enclosing_func_args=[(a_i, type(a_i)) for a_i in A1],
|
|
||||||
enclosing_func_kwargs={k: type(v) for k,v in K1},
|
|
||||||
)
|
|
||||||
|
|
||||||
A_computed = A0_computed + A1_computed
|
|
||||||
KW_computed = KV0_computed + KV1_computed
|
|
||||||
output_1 = lazy_func_1.func(*A_computed, **KW_computed)
|
|
||||||
```
|
|
||||||
|
|
||||||
By using `Func`, we've guaranteed that even hugely deep $n$s won't ever look more complicated than this.
|
|
||||||
|
|
||||||
## `max depth`: "Realization"
|
|
||||||
So, we've composed a bunch of functions of functions of ...
|
|
||||||
We've also tracked their arguments, either manually (as above), or with the help of a handy `ParamsFlow` object.
|
|
||||||
|
|
||||||
But it'd be pointless to just compose away forever.
|
|
||||||
We do actually need the data that they claim to compute now:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# A_all and KW_all must be tracked on the side.
|
|
||||||
output_n = lazy_func_n.func(*A_all, **KW_all)
|
|
||||||
```
|
|
||||||
|
|
||||||
Of course, this comes with enormous overhead.
|
|
||||||
Aside from the function calls themselves (which can be non-trivial), we must also contend with the enormous inefficiency of performing array operations sequentially.
|
|
||||||
|
|
||||||
That brings us to the killer feature of `FuncFlow`, and the motivating reason for doing any of this at all:
|
|
||||||
|
|
||||||
```python
|
|
||||||
output_n = lazy_func_n.func_jax(*A_all, **KW_all)
|
|
||||||
```
|
|
||||||
|
|
||||||
What happened was, **the entire pipeline** was compiled, optimized, and computed with bare-metal performance on either a CPU, GPU, or TPU.
|
|
||||||
With the help of the `jax` library (and its underlying OpenXLA bytecode), all of that inefficiency has been optimized based on _what we're trying to do_, not _exactly how we're doing it_, in order to maximize the use of modern massively-parallel devices.
|
|
||||||
|
|
||||||
See the documentation of `Func.func_jax()` for more information on this process.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Lazy Math: Practical Considerations
|
|
||||||
By using nodes to express a lazily-composed chain of mathematical operations on tensor-like data, we strike a difficult balance between UX, flexibility, and performance.
|
|
||||||
|
|
||||||
## UX
|
|
||||||
UX is often more a matter of art/taste than science, so don't trust these philosophies too much - a lot of the analysis is entirely personal and subjective.
|
|
||||||
|
|
||||||
The goal for our UX is to minimize the "frictions" that cause cascading, small-scale _user anxiety_.
|
|
||||||
|
|
||||||
Especially of concern in a visual math system on large data volumes is **UX latency** - also known as **lag**.
|
|
||||||
In particular, the most important facet to minimize is _emotional burden_ rather than quantitative milliseconds.
|
|
||||||
Any repeated moment-to-moment friction can be very damaging to a user's ability to be productive in a piece of software.
|
|
||||||
|
|
||||||
Unfortunately, in a node-based architecture, data must generally be computed one step at a time, whenever any part of it is needed, and it must do so before any feedback can be provided.
|
|
||||||
In a math system like this, that data is presumed "big", and as such we're left with the unfortunate experience of even the most well-cached, high-performance operations causing _just about anything_ to **feel** like a highly unpleasant slog as soon as the data gets big enough.
|
|
||||||
**This discomfort scales with the size of data**, by the way, which might just cause users to never even attempt working with the data volume that they actually need.
|
|
||||||
|
|
||||||
For electrodynamic field analysis, it's not uncommon for toy examples to expend hundreds of megabytes of memory, all of which needs all manner of interesting things done to it.
|
|
||||||
It can therefore be very easy to stumble across that feeling of "slogging through" any program that does real-world EM field analysis.
|
|
||||||
This has consequences: The user tries fewer ideas, becomes more easily frustrated, and might ultimately accomplish less.
|
|
||||||
|
|
||||||
Lazy evaluation allows _delaying_ a computation to a point in time where the user both expects and understands the time that the computation takes.
|
|
||||||
For example, the user experience of pressing a button clearly marked with terminology like "load", "save", "compute", "run", seems to be paired to a greatly increased emotional tolerance towards the latency introduced by pressing that button (so long as it is only clickable when it works).
|
|
||||||
To a lesser degree, attaching a node link also seems to have this property, though that tolerance seems to fall as proficiency with the node-based tool rises.
|
|
||||||
As a more nuanced example, when lag occurs due to the computing an image-based plot based on live-computed math, then the visual feedback of _the plot actually changing_ seems to have a similar effect, not least because it's emotionally well-understood that detaching the `Viewer` node would also remove the lag.
|
|
||||||
|
|
||||||
In short: Even if lazy evaluation didn't make any math faster, it will still _feel_ faster (to a point - raw performance obviously still matters).
|
|
||||||
Without `FuncFlow`, the point of evaluation cannot be chosen at all, which is a huge issue for all the named reasons.
|
|
||||||
With `FuncFlow`, better-chosen evaluation points can be chosen to cause the _user experience_ of high performance, simply because we were able to shift the exact same computation to a point in time where the user either understands or tolerates the delay better.
|
|
||||||
|
|
||||||
## Flexibility
|
|
||||||
Large-scale math is done on tensors, whether one knows (or likes!) it or not.
|
|
||||||
To this end, the indexed arrays produced by `FuncFlow.func_jax` aren't quite sufficient for most operations we want to do:
|
|
||||||
|
|
||||||
- **Naming**: What _is_ each axis?
|
|
||||||
Unnamed index axes are sometimes easy to decode, but in general, names have an unexpectedly critical function when operating on arrays.
|
|
||||||
Lack of names is a huge part of why perfectly elegant array math in ex. `MATLAB` or `numpy` can so easily feel so incredibly convoluted.
|
|
||||||
_Sometimes arrays with named axes are called "structured arrays".
|
|
||||||
|
|
||||||
- **Coordinates**: What do the indices of each axis really _mean_?
|
|
||||||
For example, an array of $500$ by-wavelength observations of power (watts) can't be limited to between $200nm$ to $700nm$.
|
|
||||||
But they can be limited to between index `23` to `298`.
|
|
||||||
I'm **just supposed to know** that `23` means $200nm$, and that `298` indicates the observation just after $700nm$, and _hope_ that this is exact enough.
|
|
||||||
|
|
||||||
Not only do we endeavor to track these, but we also introduce unit-awareness to the coordinates, and design the entire math system to visually communicate the state of arrays before/after every single computation, as well as only expose operations that this tracked data indicates possible.
|
|
||||||
|
|
||||||
In practice, this happens in `FlowKind.Info`, which due to having its own `FlowKind` "lane" can be adjusted without triggering changes to (and therefore recompilation of) the `FlowKind.Func` chain.
|
|
||||||
**Please consult the `InfoFlow` documentation for more**.
|
|
||||||
|
|
||||||
## Performance
|
|
||||||
All values introduced while processing are kept in a seperate `FlowKind` lane, with its own incremental caching: `FlowKind.Params`.
|
|
||||||
|
|
||||||
It's a simple mechanism, but for the cost of introducing an extra `FlowKind` "lane", all of the values used to process data can be live-adjusted without the overhead of recompiling the entire `Func` every time anything changes.
|
|
||||||
Moreover, values used to process data don't even have to be numbers yet: They can be expressions of symbolic variables, complete with units, which are only realized at the very end of the chain, by the node that absolutely cannot function without the actual numerical data.
|
|
||||||
|
|
||||||
See the `ParamFlow` documentation for more information.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Conclusion
|
|
||||||
There is, of course, a lot more to say about the math system in general.
|
|
||||||
A few teasers of what nodes can do with this system:
|
|
||||||
|
|
||||||
**Auto-Differentiation**: `jax.jit` isn't even really the killer feature of `jax`.
|
|
||||||
`jax` can automatically differentiate `FuncFlow.func_jax` with respect to any input parameter, including for fwd/bck jacobians/hessians, with robust numerical stability.
|
|
||||||
When used in
|
|
||||||
**Symbolic Interop**: Any `sympy` expression containing symbolic variables can be compiled, by `sympy`, into a `jax`-compatible function which takes
|
|
||||||
We make use of this in the `Expr` socket, enabling true symbolic math to be used in high-performance lazy `jax` computations.
|
|
||||||
**Tidy3D Interop**: For some parameters of some simulation objects, `tidy3d` actually supports adjoint-driven differentiation _through the cloud simulation_.
|
|
||||||
This enables our humble interface to implement fully functional **inverse design** of parameterized structures, using only nodes.
|
|
||||||
|
|
||||||
But above all, we hope that this math system is fun, practical, and maybe even interesting.
|
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
func: The function that generates the represented value.
|
func: The function that generates the represented value.
|
||||||
func_args: The constrained identity of all positional arguments to the function.
|
func_args: The constrained identity of all positional arguments to the function.
|
||||||
|
@ -247,14 +253,16 @@ class FuncFlow:
|
||||||
See the documentation of `self.func_jax()`.
|
See the documentation of `self.func_jax()`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
model_config = pyd.ConfigDict(frozen=True)
|
||||||
|
|
||||||
func: LazyFunction
|
func: LazyFunction
|
||||||
func_args: list[sim_symbols.SimSymbol] = dataclasses.field(default_factory=list)
|
func_args: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
|
||||||
func_kwargs: dict[str, sim_symbols.SimSymbol] = dataclasses.field(
|
func_kwargs: dict[str, sim_symbols.SimSymbol] = pyd.Field(default_factory=dict)
|
||||||
default_factory=dict
|
func_output: sim_symbols.SimSymbol | None = None
|
||||||
)
|
|
||||||
supports_jax: bool = False
|
supports_jax: bool = False
|
||||||
|
|
||||||
concatenated: bool = False
|
is_concatenated: bool = False
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Functions
|
# - Functions
|
||||||
|
@ -318,6 +326,7 @@ class FuncFlow:
|
||||||
{}
|
{}
|
||||||
),
|
),
|
||||||
) -> typ.Self:
|
) -> typ.Self:
|
||||||
|
"""Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols."""
|
||||||
if self.supports_jax:
|
if self.supports_jax:
|
||||||
return self.func_jax(
|
return self.func_jax(
|
||||||
*params.scaled_func_args(symbol_values),
|
*params.scaled_func_args(symbol_values),
|
||||||
|
@ -371,14 +380,55 @@ class FuncFlow:
|
||||||
|
|
||||||
return data | {info.output: self.realize(params, symbol_values=symbol_values)}
|
return data | {info.output: self.realize(params, symbol_values=symbol_values)}
|
||||||
|
|
||||||
|
def realize_partial(
|
||||||
|
self, params: ParamsFlow
|
||||||
|
) -> typ.Callable[
|
||||||
|
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
|
||||||
|
jtyp.Inexact[jtyp.Array, '...'],
|
||||||
|
]:
|
||||||
|
"""Create a purely-numerical function, which takes only numerical.
|
||||||
|
|
||||||
|
The units/types/shape/etc. of the returned numerical type conforms to the `SimSymbol` specification of relevant `self.func_args` entries and `self.func_output`.
|
||||||
|
|
||||||
|
This function should be used whenever the unrealized result of a `FuncFlow` needs to be used as the argument to another `FuncFlow`.
|
||||||
|
By using `realize_partial()`, two things are ensured:
|
||||||
|
|
||||||
|
- Since the function defined in `.compose_within()` must be purely numerical, the usual `.realize()` mechanism can't be used to sweep away the pre-realized symbol values.
|
||||||
|
- Since this `FuncFlow` is completely consumed, with no symbols / arguments / etc. explicitly surviving, its impact on the data flow can be considered to have been effectively terminated after using this function.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Be **very careful about units**.
|
||||||
|
Ideally, the bottom function should use `.scale_to_unit()` before invoking `.compose_within()` with the output of this function.
|
||||||
|
"""
|
||||||
|
pre_realized_syms = list(
|
||||||
|
params.realize_symbols(params.realized_symbols, allow_partial=True).values()
|
||||||
|
)
|
||||||
|
|
||||||
|
def realizer(
|
||||||
|
*sym_args: int | float | complex | jtyp.Inexact[jtyp.Array, '...'],
|
||||||
|
) -> jtyp.Inexact[jtyp.Array, '...']:
|
||||||
|
return self.func(
|
||||||
|
*[
|
||||||
|
func_arg_n(*sym_args, *pre_realized_syms)
|
||||||
|
for func_arg_n in params.func_args_n
|
||||||
|
],
|
||||||
|
**{
|
||||||
|
func_arg_name: func_kwarg_n(*sym_args, *pre_realized_syms)
|
||||||
|
for func_arg_name, func_kwarg_n in params.func_kwargs_n.items()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return realizer
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Composition Operations
|
# - Operations
|
||||||
####################
|
####################
|
||||||
def compose_within(
|
def compose_within(
|
||||||
self,
|
self,
|
||||||
enclosing_func: LazyFunction,
|
enclosing_func: LazyFunction,
|
||||||
enclosing_func_args: list[type] = (),
|
enclosing_func_args: list[sim_symbols.SimSymbol] = (),
|
||||||
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}),
|
enclosing_func_kwargs: dict[str, sim_symbols.SimSymbol] = MappingProxyType({}),
|
||||||
|
enclosing_func_output: sim_symbols.SimSymbol | None = None,
|
||||||
supports_jax: bool = False,
|
supports_jax: bool = False,
|
||||||
) -> typ.Self:
|
) -> typ.Self:
|
||||||
"""Compose `self.func` within the given enclosing function, which itself takes arguments, and create a new `FuncFlow` to contain it.
|
"""Compose `self.func` within the given enclosing function, which itself takes arguments, and create a new `FuncFlow` to contain it.
|
||||||
|
@ -415,6 +465,10 @@ class FuncFlow:
|
||||||
Returns:
|
Returns:
|
||||||
A lazy function that takes both the enclosed and enclosing arguments, and returns the value of the enclosing function (whose first argument is the output value of the enclosed function).
|
A lazy function that takes both the enclosed and enclosing arguments, and returns the value of the enclosing function (whose first argument is the output value of the enclosed function).
|
||||||
"""
|
"""
|
||||||
|
## TODO: Support unit system conversion at the point of composition.
|
||||||
|
## -- This may require us to track the units of the function output.
|
||||||
|
## TODO: Support JAX-evaluation when jax support changes from True to False.
|
||||||
|
## -- This would allow big data flows to compose performantly as arguments into non-JAX functions.
|
||||||
return FuncFlow(
|
return FuncFlow(
|
||||||
func=lambda *args, **kwargs: enclosing_func(
|
func=lambda *args, **kwargs: enclosing_func(
|
||||||
self.func(
|
self.func(
|
||||||
|
@ -426,6 +480,7 @@ class FuncFlow:
|
||||||
),
|
),
|
||||||
func_args=self.func_args + list(enclosing_func_args),
|
func_args=self.func_args + list(enclosing_func_args),
|
||||||
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
|
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
|
||||||
|
func_output=enclosing_func_output,
|
||||||
supports_jax=self.supports_jax and supports_jax,
|
supports_jax=self.supports_jax and supports_jax,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -472,7 +527,7 @@ class FuncFlow:
|
||||||
*list(args[: len(self.func_args)]),
|
*list(args[: len(self.func_args)]),
|
||||||
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
|
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
|
||||||
)
|
)
|
||||||
if not self.concatenated:
|
if not self.is_concatenated:
|
||||||
return (ret,)
|
return (ret,)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
@ -487,5 +542,83 @@ class FuncFlow:
|
||||||
func_args=self.func_args + other.func_args,
|
func_args=self.func_args + other.func_args,
|
||||||
func_kwargs=self.func_kwargs | other.func_kwargs,
|
func_kwargs=self.func_kwargs | other.func_kwargs,
|
||||||
supports_jax=self.supports_jax and other.supports_jax,
|
supports_jax=self.supports_jax and other.supports_jax,
|
||||||
concatenated=True,
|
is_concatenated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def scale_to_unit(self, unit: spux.Unit | None = None) -> typ.Self:
|
||||||
|
"""Encloses this function in a unit-converting function, whose output is a converted, unitless scalar.
|
||||||
|
|
||||||
|
`unit` must be manually guaranteed to be compatible with `self.unit`.
|
||||||
|
"""
|
||||||
|
if self.func_output is not None:
|
||||||
|
# Retrieve Output Unit
|
||||||
|
output_unit = self.func_output.unit
|
||||||
|
|
||||||
|
# Compile Efficient Unit-Conversion Function
|
||||||
|
a = self.func_output.mathtype.sp_symbol_a
|
||||||
|
unit_convert_expr = (
|
||||||
|
spux.scale_to_unit(a * output_unit, unit)
|
||||||
|
if self.func_output.unit is not None
|
||||||
|
else a
|
||||||
|
)
|
||||||
|
unit_convert_func = sp.lambdify(a, unit_convert_expr.n(), 'jax')
|
||||||
|
|
||||||
|
# Compose Unit-Converted FuncFlow
|
||||||
|
return self.compose_within(
|
||||||
|
enclosing_func=unit_convert_func,
|
||||||
|
supports_jax=True,
|
||||||
|
enclosing_func_output=self.func_output.update(unit=unit),
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = f'Tried to scale a FuncFlow to a unit system, but it has no tracked output SimSymbol. ({self})'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
def scale_to_unit_system(
|
||||||
|
self, unit_system: spux.UnitSystem | None = None
|
||||||
|
) -> typ.Self:
|
||||||
|
"""Encloses this function in a unit-converting function, whose output is a converted, unitless scalar.
|
||||||
|
|
||||||
|
Using `self.output_symbol`, which tracks the units of the output, we can determine a scaling factor to multiply the (numerical) function output by in order to conform it to the given unit system.
|
||||||
|
|
||||||
|
In general, **don't use this**.
|
||||||
|
Any superfluous numerical operations in a data pipeline can enhance instabilities and interfere with JIT-optimization (floating-point arithmetic isn't commutative, for example).
|
||||||
|
However, occasionally, we need to "intercept" a lazy data flow, for example when realizing a `FlowKind.Value` that doesn't understand symbols or units - but which only accepts a float/complex scalar/array with pre-determined unit convention.
|
||||||
|
|
||||||
|
For this purpose alone, this method is provided to pre-scale a `FuncFlow`, just before using `realize()` / `__or__` and then `realize()`.
|
||||||
|
**To encourage proper usage** (and ease implementation), the output unit in `self.func_output` of the output will be reset to `None` - indicating that the output can only be handled as a unitless scalar w/semantic meaning tracked elsewhere.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
**ONLY** use with output types that support meaningful arbitrary multiplication.
|
||||||
|
|
||||||
|
A scale-only sympy expression will be used to produce an optimized JAX function of a single variable, which will then be composed onto the existing `FuncFlow`.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
unit_system: The unit system to conform the function output to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new `FuncFlow` that conforms to the new unit, but is itself now considered unitless.
|
||||||
|
"""
|
||||||
|
if self.func_output is not None:
|
||||||
|
# Retrieve Output Unit
|
||||||
|
output_unit = self.func_output.unit
|
||||||
|
|
||||||
|
# Compile Efficient Unit-Conversion Function
|
||||||
|
a = self.func_output.mathtype.sp_symbol_a
|
||||||
|
unit_convert_expr = (
|
||||||
|
spux.strip_unit_system(
|
||||||
|
spux.convert_to_unit_system(a * output_unit, unit_system)
|
||||||
|
)
|
||||||
|
if self.func_output.unit is not None
|
||||||
|
else a
|
||||||
|
)
|
||||||
|
unit_convert_func = sp.lambdify(a, unit_convert_expr.n(), 'jax')
|
||||||
|
|
||||||
|
# Compose Unit-Converted FuncFlow
|
||||||
|
return self.compose_within(
|
||||||
|
enclosing_func=unit_convert_func,
|
||||||
|
supports_jax=True,
|
||||||
|
enclosing_func_output=self.func_output.update(unit=None),
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = f'Tried to scale a FuncFlow to a unit system, but it has no tracked output SimSymbol. ({self})'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
|
@ -14,17 +14,15 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
import enum
|
import enum
|
||||||
import functools
|
import functools
|
||||||
import typing as typ
|
import typing as typ
|
||||||
from fractions import Fraction
|
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jaxtyping as jtyp
|
import jaxtyping as jtyp
|
||||||
|
import pydantic as pyd
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
import sympy.physics.units as spu
|
|
||||||
|
|
||||||
from blender_maxwell.utils import extra_sympy_units as spux
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
from blender_maxwell.utils import logger, sim_symbols
|
from blender_maxwell.utils import logger, sim_symbols
|
||||||
|
@ -61,8 +59,7 @@ class ScalingMode(enum.StrEnum):
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
class RangeFlow(pyd.BaseModel):
|
||||||
class RangeFlow:
|
|
||||||
r"""Represents a finite spaced array using symbolic boundary expressions.
|
r"""Represents a finite spaced array using symbolic boundary expressions.
|
||||||
|
|
||||||
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
|
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
|
||||||
|
@ -92,8 +89,10 @@ class RangeFlow:
|
||||||
symbols: Set of variables from which `start` and/or `stop` are determined.
|
symbols: Set of variables from which `start` and/or `stop` are determined.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start: spux.ScalarUnitlessComplexExpr
|
model_config = pyd.ConfigDict(frozen=True)
|
||||||
stop: spux.ScalarUnitlessComplexExpr
|
|
||||||
|
start: spux.ScalarUnitlessRealExpr
|
||||||
|
stop: spux.ScalarUnitlessRealExpr
|
||||||
steps: int = 0
|
steps: int = 0
|
||||||
scaling: ScalingMode = ScalingMode.Lin
|
scaling: ScalingMode = ScalingMode.Lin
|
||||||
|
|
||||||
|
@ -102,7 +101,7 @@ class RangeFlow:
|
||||||
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
|
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
|
||||||
|
|
||||||
# Helper Attributes
|
# Helper Attributes
|
||||||
pre_fourier_ideal_midpoint: spux.ScalarUnitlessComplexExpr | None = None
|
pre_fourier_ideal_midpoint: spux.ScalarUnitlessRealExpr | None = None
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - SimSymbol Interop
|
# - SimSymbol Interop
|
||||||
|
@ -218,14 +217,26 @@ class RangeFlow:
|
||||||
)
|
)
|
||||||
return combined_mathtype
|
return combined_mathtype
|
||||||
|
|
||||||
@property
|
@functools.cached_property
|
||||||
def ideal_midpoint(self) -> spux.SympyExpr:
|
def ideal_midpoint(self) -> spux.SympyExpr:
|
||||||
return (self.stop + self.start) / 2
|
return (self.stop + self.start) / 2
|
||||||
|
|
||||||
@property
|
@functools.cached_property
|
||||||
def ideal_range(self) -> spux.SympyExpr:
|
def ideal_range(self) -> spux.SympyExpr:
|
||||||
return self.stop - self.start
|
return self.stop - self.start
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def ideal_step_size(self) -> spux.SympyExpr:
|
||||||
|
return self.ideal_range / (self.steps - 1)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def is_always_nonzero(self) -> spux.SympyExpr:
|
||||||
|
if self.start > 0 or self.stop < 0:
|
||||||
|
return True
|
||||||
|
|
||||||
|
is_zero = (self.start % self.ideal_step_size).is_zero
|
||||||
|
return is_zero if is_zero is not None else False
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Methods
|
# - Methods
|
||||||
####################
|
####################
|
||||||
|
@ -452,7 +463,7 @@ class RangeFlow:
|
||||||
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||||
{}
|
{}
|
||||||
),
|
),
|
||||||
) -> dict[sp.Symbol, spux.ScalarUnitlessComplexExpr]:
|
) -> dict[sp.Symbol, spux.ScalarUnitlessRealExpr]:
|
||||||
"""Realize **all** input symbols to the `RangeFlow`.
|
"""Realize **all** input symbols to the `RangeFlow`.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
|
@ -480,7 +491,7 @@ class RangeFlow:
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
realized_syms |= {sym: v}
|
realized_syms |= {sym: v}
|
||||||
|
return realized_syms
|
||||||
msg = f'RangeFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
|
msg = f'RangeFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
|
@ -14,13 +14,13 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
import functools
|
import functools
|
||||||
import typing as typ
|
import typing as typ
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
|
|
||||||
import jaxtyping as jtyp
|
import jaxtyping as jtyp
|
||||||
|
import pydantic as pyd
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
|
|
||||||
from blender_maxwell.utils import extra_sympy_units as spux
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
|
@ -28,31 +28,35 @@ from blender_maxwell.utils import logger, sim_symbols
|
||||||
|
|
||||||
from .array import ArrayFlow
|
from .array import ArrayFlow
|
||||||
from .expr_info import ExprInfo
|
from .expr_info import ExprInfo
|
||||||
from .flow_kinds import FlowKind
|
|
||||||
from .lazy_range import RangeFlow
|
from .lazy_range import RangeFlow
|
||||||
|
|
||||||
log = logger.get(__name__)
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
class ParamsFlow(pyd.BaseModel):
|
||||||
class ParamsFlow:
|
|
||||||
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
All symbols valid for use in the expression.
|
All symbols valid for use in the expression.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
arg_targets: list[sim_symbols.SimSymbol] = dataclasses.field(default_factory=list)
|
model_config = pyd.ConfigDict(frozen=True)
|
||||||
kwarg_targets: list[str, sim_symbols.SimSymbol] = dataclasses.field(
|
|
||||||
default_factory=dict
|
|
||||||
)
|
|
||||||
|
|
||||||
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
|
arg_targets: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
|
||||||
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
|
kwarg_targets: dict[str, sim_symbols.SimSymbol] = pyd.Field(default_factory=dict)
|
||||||
|
|
||||||
|
func_args: list[spux.SympyExpr] = pyd.Field(default_factory=list)
|
||||||
|
func_kwargs: dict[str, spux.SympyExpr] = pyd.Field(default_factory=dict)
|
||||||
|
|
||||||
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
|
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
|
||||||
|
realized_symbols: dict[
|
||||||
|
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
|
||||||
|
] = pyd.Field(default_factory=dict)
|
||||||
|
|
||||||
is_differentiable: bool = False
|
@functools.cached_property
|
||||||
|
def diff_symbols(self) -> set[sim_symbols.SimSymbol]:
|
||||||
|
"""Set of all unrealized `SimSymbol`s that can act as inputs when differentiating the function for which this `ParamsFlow` tracks arguments."""
|
||||||
|
return {sym for sym in self.symbols if sym.can_diff}
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Symbols
|
# - Symbols
|
||||||
|
@ -78,6 +82,27 @@ class ParamsFlow:
|
||||||
"""
|
"""
|
||||||
return [sym.sp_symbol_matsym for sym in self.sorted_symbols]
|
return [sym.sp_symbol_matsym for sym in self.sorted_symbols]
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def all_sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
|
||||||
|
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
All symbols valid for use in the expression.
|
||||||
|
"""
|
||||||
|
key_func = lambda sym: sym.name # noqa: E731
|
||||||
|
return sorted(self.symbols, key=key_func) + sorted(
|
||||||
|
self.realized_symbols.keys(), key=key_func
|
||||||
|
)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def all_sorted_sp_symbols(self) -> list[sim_symbols.SimSymbol]:
|
||||||
|
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
All symbols valid for use in the expression.
|
||||||
|
"""
|
||||||
|
return [sym.sp_symbol_matsym for sym in self.all_sorted_symbols]
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - JIT'ed Callables for Numerical Function Arguments
|
# - JIT'ed Callables for Numerical Function Arguments
|
||||||
####################
|
####################
|
||||||
|
@ -101,7 +126,7 @@ class ParamsFlow:
|
||||||
"""
|
"""
|
||||||
return [
|
return [
|
||||||
sp.lambdify(
|
sp.lambdify(
|
||||||
self.sorted_sp_symbols,
|
self.all_sorted_sp_symbols,
|
||||||
target_sym.conform(func_arg, strip_unit=True),
|
target_sym.conform(func_arg, strip_unit=True),
|
||||||
'jax',
|
'jax',
|
||||||
)
|
)
|
||||||
|
@ -127,7 +152,7 @@ class ParamsFlow:
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
key: sp.lambdify(
|
key: sp.lambdify(
|
||||||
self.sorted_sp_symbols,
|
self.all_sorted_sp_symbols,
|
||||||
self.kwarg_targets[key].conform(func_arg, strip_unit=True),
|
self.kwarg_targets[key].conform(func_arg, strip_unit=True),
|
||||||
'jax',
|
'jax',
|
||||||
)
|
)
|
||||||
|
@ -142,8 +167,9 @@ class ParamsFlow:
|
||||||
symbol_values: dict[
|
symbol_values: dict[
|
||||||
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
|
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
|
||||||
] = MappingProxyType({}),
|
] = MappingProxyType({}),
|
||||||
|
allow_partial: bool = False,
|
||||||
) -> dict[
|
) -> dict[
|
||||||
sp.Symbol,
|
sim_symbols.SimSymbol,
|
||||||
int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...'] :,
|
int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...'] :,
|
||||||
]:
|
]:
|
||||||
"""Fully realize all symbols by assigning them a value.
|
"""Fully realize all symbols by assigning them a value.
|
||||||
|
@ -160,10 +186,12 @@ class ParamsFlow:
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary almost with `.subs()`, other than `jax` arrays.
|
A dictionary almost with `.subs()`, other than `jax` arrays.
|
||||||
"""
|
"""
|
||||||
if set(self.symbols) == set(symbol_values.keys()):
|
if allow_partial or set(self.all_sorted_symbols) == set(symbol_values.keys()):
|
||||||
realized_syms = {}
|
realized_syms = {}
|
||||||
for sym in self.sorted_symbols:
|
for sym in self.all_sorted_symbols:
|
||||||
sym_value = symbol_values[sym]
|
sym_value = symbol_values.get(sym)
|
||||||
|
if sym_value is None and allow_partial:
|
||||||
|
continue
|
||||||
|
|
||||||
if isinstance(sym_value, spux.SympyType):
|
if isinstance(sym_value, spux.SympyType):
|
||||||
v = sym.scale(sym_value)
|
v = sym.scale(sym_value)
|
||||||
|
@ -214,7 +242,9 @@ class ParamsFlow:
|
||||||
Parameters:
|
Parameters:
|
||||||
symbol_values: Particular values for all symbols in `self.symbols`, which will be conformed and used to compute the function arguments (before they are conformed to `self.target_syms`).
|
symbol_values: Particular values for all symbols in `self.symbols`, which will be conformed and used to compute the function arguments (before they are conformed to `self.target_syms`).
|
||||||
"""
|
"""
|
||||||
realized_symbols = list(self.realize_symbols(symbol_values).values())
|
realized_symbols = list(
|
||||||
|
self.realize_symbols(symbol_values | self.realized_symbols).values()
|
||||||
|
)
|
||||||
return [func_arg_n(*realized_symbols) for func_arg_n in self.func_args_n]
|
return [func_arg_n(*realized_symbols) for func_arg_n in self.func_args_n]
|
||||||
|
|
||||||
def scaled_func_kwargs(
|
def scaled_func_kwargs(
|
||||||
|
@ -227,10 +257,11 @@ class ParamsFlow:
|
||||||
|
|
||||||
Other than the `dict[str, ...]` key, the semantics are identical to `self.scaled_func_args()`.
|
Other than the `dict[str, ...]` key, the semantics are identical to `self.scaled_func_args()`.
|
||||||
"""
|
"""
|
||||||
realized_symbols = self.realize_symbols(symbol_values)
|
realized_symbols = self.realize_symbols(symbol_values | self.realized_symbols)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
func_arg_name: func_arg_n(**realized_symbols)
|
func_arg_name: func_kwarg_n(**realized_symbols)
|
||||||
for func_arg_name, func_arg_n in self.func_kwargs_n.items()
|
for func_arg_name, func_kwarg_n in self.func_kwargs_n.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -251,7 +282,7 @@ class ParamsFlow:
|
||||||
func_args=self.func_args + other.func_args,
|
func_args=self.func_args + other.func_args,
|
||||||
func_kwargs=self.func_kwargs | other.func_kwargs,
|
func_kwargs=self.func_kwargs | other.func_kwargs,
|
||||||
symbols=self.symbols | other.symbols,
|
symbols=self.symbols | other.symbols,
|
||||||
is_differentiable=self.is_differentiable and other.is_differentiable,
|
realized_symbols=self.realized_symbols | other.realized_symbols,
|
||||||
)
|
)
|
||||||
|
|
||||||
def compose_within(
|
def compose_within(
|
||||||
|
@ -261,7 +292,6 @@ class ParamsFlow:
|
||||||
enclosing_func_args: list[spux.SympyExpr] = (),
|
enclosing_func_args: list[spux.SympyExpr] = (),
|
||||||
enclosing_func_kwargs: dict[str, spux.SympyExpr] = MappingProxyType({}),
|
enclosing_func_kwargs: dict[str, spux.SympyExpr] = MappingProxyType({}),
|
||||||
enclosing_symbols: frozenset[sim_symbols.SimSymbol] = frozenset(),
|
enclosing_symbols: frozenset[sim_symbols.SimSymbol] = frozenset(),
|
||||||
enclosing_is_differentiable: bool = False,
|
|
||||||
) -> typ.Self:
|
) -> typ.Self:
|
||||||
return ParamsFlow(
|
return ParamsFlow(
|
||||||
arg_targets=self.arg_targets + list(enclosing_arg_targets),
|
arg_targets=self.arg_targets + list(enclosing_arg_targets),
|
||||||
|
@ -269,13 +299,41 @@ class ParamsFlow:
|
||||||
func_args=self.func_args + list(enclosing_func_args),
|
func_args=self.func_args + list(enclosing_func_args),
|
||||||
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
|
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
|
||||||
symbols=self.symbols | enclosing_symbols,
|
symbols=self.symbols | enclosing_symbols,
|
||||||
is_differentiable=(
|
realized_symbols=self.realized_symbols,
|
||||||
self.is_differentiable
|
|
||||||
if not enclosing_symbols
|
|
||||||
else (self.is_differentiable & enclosing_is_differentiable)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def realize_partial(
|
||||||
|
self,
|
||||||
|
symbol_values: dict[
|
||||||
|
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
|
||||||
|
],
|
||||||
|
) -> typ.Self:
|
||||||
|
"""Provide a particular expression/range/array to realize some symbols.
|
||||||
|
|
||||||
|
Essentially removes symbols from `self.symbols`, and adds the symbol w/value to `self.realized_symbols`.
|
||||||
|
As a result, only the still-unrealized symbols need to be passed at the time of realization (using ex. `self.scaled_func_args()`).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
symbol_values: The value to realize for each `SimSymbol`.
|
||||||
|
**All keys** must be identically matched to a single element of `self.symbol`.
|
||||||
|
Can be empty, in which case an identical new `ParamsFlow` will be returned.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any symbol in `symbol_values`
|
||||||
|
"""
|
||||||
|
syms = set(symbol_values.keys())
|
||||||
|
if syms.issubset(self.symbols) or not syms:
|
||||||
|
return ParamsFlow(
|
||||||
|
arg_targets=self.arg_targets,
|
||||||
|
kwarg_targets=self.kwarg_targets,
|
||||||
|
func_args=self.func_args,
|
||||||
|
func_kwargs=self.func_kwargs,
|
||||||
|
symbols=self.symbols - syms,
|
||||||
|
realized_symbols=self.realized_symbols | symbol_values,
|
||||||
|
)
|
||||||
|
msg = f'ParamsFlow: Not all partially realized symbols are defined on the ParamsFlow (symbols={self.symbols}, symbol_values={symbol_values})'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Generate ExprSocketDef
|
# - Generate ExprSocketDef
|
||||||
####################
|
####################
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
"""Declares `ManagedBLImage`."""
|
"""Declares `ManagedBLImage`."""
|
||||||
|
|
||||||
# import time
|
import time
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
import bpy
|
import bpy
|
||||||
|
@ -261,7 +261,7 @@ class ManagedBLImage(base.ManagedObj):
|
||||||
dpi: int | None = None,
|
dpi: int | None = None,
|
||||||
bl_select: bool = False,
|
bl_select: bool = False,
|
||||||
):
|
):
|
||||||
# times = [time.perf_counter()]
|
times = ['START', time.perf_counter()]
|
||||||
|
|
||||||
# Compute Plot Dimensions
|
# Compute Plot Dimensions
|
||||||
# aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = (
|
# aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = (
|
||||||
|
@ -277,22 +277,22 @@ class ManagedBLImage(base.ManagedObj):
|
||||||
# _width_inches, _height_inches, _dpi
|
# _width_inches, _height_inches, _dpi
|
||||||
# )
|
# )
|
||||||
fig, canvas, ax = image_ops.mpl_fig_canvas_ax(width_inches, height_inches, dpi)
|
fig, canvas, ax = image_ops.mpl_fig_canvas_ax(width_inches, height_inches, dpi)
|
||||||
# times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]])
|
times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]])
|
||||||
|
|
||||||
# fig.clear()
|
# fig.clear()
|
||||||
ax.clear()
|
ax.clear()
|
||||||
# times.append(['Clear Axis', time.perf_counter() - times[0]])
|
times.append(['Clear Axis', time.perf_counter() - times[0]])
|
||||||
|
|
||||||
# Plot w/User Parameter
|
# Plot w/User Parameter
|
||||||
func_plotter(ax)
|
func_plotter(ax)
|
||||||
# times.append(['Plot!', time.perf_counter() - times[0]])
|
times.append(['Plot!', time.perf_counter() - times[0]])
|
||||||
|
|
||||||
# Save Figure to BytesIO
|
# Save Figure to BytesIO
|
||||||
canvas.draw()
|
canvas.draw()
|
||||||
# times.append(['Draw Pixels', time.perf_counter() - times[0]])
|
times.append(['Draw Pixels', time.perf_counter() - times[0]])
|
||||||
|
|
||||||
canvas_width_px, canvas_height_px = fig.canvas.get_width_height()
|
canvas_width_px, canvas_height_px = fig.canvas.get_width_height()
|
||||||
# times.append(['Get Canvas Dims', time.perf_counter() - times[0]])
|
times.append(['Get Canvas Dims', time.perf_counter() - times[0]])
|
||||||
image_data = (
|
image_data = (
|
||||||
np.float32(
|
np.float32(
|
||||||
np.flipud(
|
np.flipud(
|
||||||
|
@ -303,7 +303,7 @@ class ManagedBLImage(base.ManagedObj):
|
||||||
)
|
)
|
||||||
/ 255
|
/ 255
|
||||||
)
|
)
|
||||||
# times.append(['Load Data from Canvas', time.perf_counter() - times[0]])
|
times.append(['Load Data from Canvas', time.perf_counter() - times[0]])
|
||||||
|
|
||||||
# Optimized Write to Blender Image
|
# Optimized Write to Blender Image
|
||||||
bl_image = self.bl_image(canvas_width_px, canvas_height_px, 'RGBA', 'uint8')
|
bl_image = self.bl_image(canvas_width_px, canvas_height_px, 'RGBA', 'uint8')
|
||||||
|
|
|
@ -15,6 +15,8 @@
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import functools
|
||||||
|
import queue
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
import bpy
|
import bpy
|
||||||
|
@ -26,6 +28,14 @@ from .managed_objs.managed_bl_image import ManagedBLImage
|
||||||
|
|
||||||
log = logger.get(__name__)
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
link_action_queue = queue.Queue()
|
||||||
|
|
||||||
|
|
||||||
|
def set_link_validity(link: bpy.types.NodeLink, validity: bool) -> None:
|
||||||
|
log.critical('Set %s validity to %s', str(link), str(validity))
|
||||||
|
link.is_valid = validity
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Cache Management
|
# - Cache Management
|
||||||
####################
|
####################
|
||||||
|
@ -45,58 +55,93 @@ class DeltaNodeLinkCache(typ.TypedDict):
|
||||||
|
|
||||||
|
|
||||||
class NodeLinkCache:
|
class NodeLinkCache:
|
||||||
"""A pointer-based cache of node links in a node tree.
|
"""A volatile pointer-based cache of node links in a node tree.
|
||||||
|
|
||||||
|
Warnings:
|
||||||
|
Everything here is **extremely** unsafe.
|
||||||
|
Even a single mistake **will** cause a use-after-free crash of Blender.
|
||||||
|
|
||||||
|
Used perfectly, it allows for powerful features; anything less, and it's an epic liability.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
_node_tree: Reference to the owning node tree.
|
_node_tree: Reference to the node tree for which this cache is valid.
|
||||||
link_ptrs_as_links:
|
link_ptrs: Memory-address identifiers for all node links that currently exist in `_node_tree`.
|
||||||
link_ptrs: Pointers (as in integer memory adresses) to `NodeLink`s.
|
link_ptrs_as_links: Mapping from pointers (integers) to actual `NodeLink` objects.
|
||||||
link_ptrs_as_links: Map from pointers to actual `NodeLink`s.
|
**WARNING**: If the pointer-referenced object no longer exists, then Blender **will crash immediately** upon attempting to use it. There is no way to mitigate this.
|
||||||
link_ptrs_from_sockets: Map from pointers to `NodeSocket`s, representing the source of each `NodeLink`.
|
socket_ptrs: Memory-address identifiers for all sockets that currently exist in `_node_tree`.
|
||||||
link_ptrs_from_sockets: Map from pointers to `NodeSocket`s, representing the destination of each `NodeLink`.
|
socket_ptrs_as_sockets: Mapping from pointers (integers) to actual `NodeSocket` objects.
|
||||||
|
**WARNING**: If the pointer-referenced object no longer exists, then Blender **will crash immediately** upon attempting to use it. There is no way to mitigate this.
|
||||||
|
socket_ptr_refcount: The amount of links currently connected to a given socket pointer.
|
||||||
|
Used to drive the deletion of socket pointers using only knowledge about `link_ptr` removal.
|
||||||
|
link_ptrs_as_from_socket_ptrs: The pointer of the source socket, defined for every node link pointer.
|
||||||
|
link_ptrs_as_to_socket_ptrs: The pointer of the destination socket, defined for every node link pointer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_tree: bpy.types.NodeTree):
|
def __init__(self, node_tree: bpy.types.NodeTree):
|
||||||
"""Initialize the cache from a node tree.
|
"""Defines and fills the cache from a live node tree."""
|
||||||
|
|
||||||
Parameters:
|
|
||||||
node_tree: The Blender node tree whose `NodeLink`s will be cached.
|
|
||||||
"""
|
|
||||||
self._node_tree = node_tree
|
self._node_tree = node_tree
|
||||||
|
|
||||||
# Link PTR and PTR->REF
|
|
||||||
self.link_ptrs: set[MemAddr] = set()
|
self.link_ptrs: set[MemAddr] = set()
|
||||||
self.link_ptrs_as_links: dict[MemAddr, bpy.types.NodeLink] = {}
|
self.link_ptrs_as_links: dict[MemAddr, bpy.types.NodeLink] = {}
|
||||||
|
|
||||||
# Socket PTR and PTR->REF
|
|
||||||
self.socket_ptrs: set[MemAddr] = set()
|
self.socket_ptrs: set[MemAddr] = set()
|
||||||
self.socket_ptrs_as_sockets: dict[MemAddr, bpy.types.NodeSocket] = {}
|
self.socket_ptrs_as_sockets: dict[MemAddr, bpy.types.NodeSocket] = {}
|
||||||
self.socket_ptr_refcount: dict[MemAddr, int] = {}
|
self.socket_ptr_refcount: dict[MemAddr, int] = {}
|
||||||
|
|
||||||
# Link PTR -> Socket PTR
|
|
||||||
self.link_ptrs_as_from_socket_ptrs: dict[MemAddr, MemAddr] = {}
|
self.link_ptrs_as_from_socket_ptrs: dict[MemAddr, MemAddr] = {}
|
||||||
self.link_ptrs_as_to_socket_ptrs: dict[MemAddr, MemAddr] = {}
|
self.link_ptrs_as_to_socket_ptrs: dict[MemAddr, MemAddr] = {}
|
||||||
|
|
||||||
|
self.link_ptrs_invalid: set[MemAddr] = set()
|
||||||
|
|
||||||
# Fill Cache
|
# Fill Cache
|
||||||
self.regenerate()
|
self.regenerate()
|
||||||
|
|
||||||
def remove_link(self, link_ptr: MemAddr) -> None:
|
def remove_link(self, link_ptr: MemAddr) -> None:
|
||||||
"""Removes a link pointer from the cache, indicating that the link doesn't exist anymore.
|
"""Reports a link as removed, causing it to be removed from the cache.
|
||||||
|
|
||||||
|
This **must** be run whenever a node link is deleted.
|
||||||
|
**Failure to to so WILL result in segmentation fault** at an unknown future time.
|
||||||
|
|
||||||
|
In particular, the following actions are taken:
|
||||||
|
- The entry in `self.link_ptrs_as_links` is deleted.
|
||||||
|
- Any entry in `self.link_ptrs_invalid` is deleted (if exists).
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
- **DOES NOT** remove PTR->REF dictionary entries
|
Invoking this method directly causes the removed node links to not be reported as "removed" by `NodeLinkCache.regenerate()`.
|
||||||
- Invoking this method directly causes the removed node links to not be reported as "removed" by `NodeLinkCache.regenerate()`.
|
In some cases, this may be desirable, ex. for internal methods that shouldn't trip a `DataChanged` flow event.
|
||||||
- This **must** be done whenever a node link is deleted.
|
|
||||||
- Failure to do so may result in a segmentation fault at arbitrary future time.
|
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
link_ptr: Pointer to remove from the cache.
|
link_ptr: The pointer (integer) to remove from the cache.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If `link_ptr` is not a member of either `self.link_ptrs`, or of `self.link_ptrs_as_links`.
|
||||||
"""
|
"""
|
||||||
self.link_ptrs.remove(link_ptr)
|
self.link_ptrs.remove(link_ptr)
|
||||||
self.link_ptrs_as_links.pop(link_ptr)
|
self.link_ptrs_as_links.pop(link_ptr)
|
||||||
|
|
||||||
|
if link_ptr in self.link_ptrs_invalid:
|
||||||
|
self.link_ptrs_invalid.remove(link_ptr)
|
||||||
|
|
||||||
def remove_sockets_by_link_ptr(self, link_ptr: MemAddr) -> None:
|
def remove_sockets_by_link_ptr(self, link_ptr: MemAddr) -> None:
|
||||||
"""Removes a single pointer's reference to its from/to sockets."""
|
"""Deassociate from all sockets referenced by a link, respecting the socket pointer reference-count.
|
||||||
|
|
||||||
|
The `NodeLinkCache` stores references to all socket pointers referenced by any link.
|
||||||
|
Since several links can be associated with each socket, we must keep a "reference count" per-socket.
|
||||||
|
When the "reference count" drops to zero, then there are no longer any `NodeLink`s that refer to it, and therefore it should be removed from the `NodeLinkCache`.
|
||||||
|
|
||||||
|
This method facilitates that process by:
|
||||||
|
- Extracting (with removal) the from / to socket pointers associated with `link_ptr`.
|
||||||
|
- If the socket pointer has a reference count of `1`, then it is **completely removed**.
|
||||||
|
- If the socket pointer has a reference count of `>1`, then the reference count is decremented by `1`.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
In general, this should be called together with `remove_link`.
|
||||||
|
However, in certain cases, this process also needs to happen by itself.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
link_ptr: The pointer (integer) to remove from the cache.
|
||||||
|
"""
|
||||||
|
# Remove Socket Pointers
|
||||||
from_socket_ptr = self.link_ptrs_as_from_socket_ptrs.pop(link_ptr, None)
|
from_socket_ptr = self.link_ptrs_as_from_socket_ptrs.pop(link_ptr, None)
|
||||||
to_socket_ptr = self.link_ptrs_as_to_socket_ptrs.pop(link_ptr, None)
|
to_socket_ptr = self.link_ptrs_as_to_socket_ptrs.pop(link_ptr, None)
|
||||||
|
|
||||||
|
@ -113,31 +158,40 @@ class NodeLinkCache:
|
||||||
self.socket_ptr_refcount[socket_ptr] -= 1
|
self.socket_ptr_refcount[socket_ptr] -= 1
|
||||||
|
|
||||||
def regenerate(self) -> DeltaNodeLinkCache:
|
def regenerate(self) -> DeltaNodeLinkCache:
|
||||||
"""Regenerates the cache from the internally-linked node tree.
|
"""Efficiently scans the internally referenced node tree to thoroughly update all attributes of this `NodeLinkCache`.
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
- This is designed to run within the `update()` invocation of the node tree.
|
This runs in a **very** hot loop, within the `update()` function of the node tree.
|
||||||
- This should be a very fast function, since it is called so much.
|
Anytime anything happens in the node tree, `update()` (and therefore this method) is called.
|
||||||
|
|
||||||
|
Thus, performance is of the utmost importance.
|
||||||
|
Just a few microseconds too much may be amplified dozens of times over in practice, causing big stutters.
|
||||||
"""
|
"""
|
||||||
# Compute All NodeLink Pointers
|
# Compute All NodeLink Pointers
|
||||||
|
## -> It can be very inefficient to do any full-scan of the node tree.
|
||||||
|
## -> However, simply extracting the pointer: link ends up being fast.
|
||||||
|
## -> This pattern seems to be the best we can do, efficiency-wise.
|
||||||
all_link_ptrs_as_links = {
|
all_link_ptrs_as_links = {
|
||||||
link.as_pointer(): link for link in self._node_tree.links
|
link.as_pointer(): link for link in self._node_tree.links
|
||||||
}
|
}
|
||||||
all_link_ptrs = set(all_link_ptrs_as_links.keys())
|
all_link_ptrs = set(all_link_ptrs_as_links.keys())
|
||||||
|
|
||||||
# Compute Added/Removed Links
|
# Compute Added/Removed Links
|
||||||
|
## -> In essence, we've created a 'diff' here.
|
||||||
|
## -> Set operations are fast, and expressive!
|
||||||
added_link_ptrs = all_link_ptrs - self.link_ptrs
|
added_link_ptrs = all_link_ptrs - self.link_ptrs
|
||||||
removed_link_ptrs = self.link_ptrs - all_link_ptrs
|
removed_link_ptrs = self.link_ptrs - all_link_ptrs
|
||||||
|
|
||||||
# Edge Case: 'from_socket' Reassignment
|
# Edge Case: 'from_socket' Reassignment
|
||||||
## (Reverse engineered) When all:
|
## (Reverse Engineered) When all are true:
|
||||||
## - Created a new link between the same two nodes.
|
## - Created a new link between the same nodes as previous link.
|
||||||
## - Matching 'to_socket'.
|
## - Matching 'to_socket' as the previous link.
|
||||||
## - Non-matching 'from_socket' on the same node.
|
## - Non-matching 'from_socket', but on the same node.
|
||||||
## -> THEN the link_ptr will not change, but the from_socket ptr should.
|
## -> THEN the link_ptr will not change, but the from_socket ptr does.
|
||||||
if len(added_link_ptrs) == 0 and len(removed_link_ptrs) == 0:
|
if not added_link_ptrs and not removed_link_ptrs:
|
||||||
# Find the Link w/Reassigned 'from_socket' PTR
|
# Find the Link w/Reassigned 'from_socket' PTR
|
||||||
## A bit of a performance hit from the search, but it's an edge case.
|
## -> This isn't very fast, but the edge case isn't so common.
|
||||||
|
## -> Comprehensions are still quite optimized.
|
||||||
_link_ptr_as_from_socket_ptrs = {
|
_link_ptr_as_from_socket_ptrs = {
|
||||||
link_ptr: (
|
link_ptr: (
|
||||||
from_socket_ptr,
|
from_socket_ptr,
|
||||||
|
@ -149,9 +203,9 @@ class NodeLinkCache:
|
||||||
}
|
}
|
||||||
|
|
||||||
# Completely Remove the Old Link (w/Reassigned 'from_socket')
|
# Completely Remove the Old Link (w/Reassigned 'from_socket')
|
||||||
## This effectively reclassifies the edge case as a normal 're-add'.
|
## -> Casts the edge case to look like a typical 're-add'.
|
||||||
for link_ptr in _link_ptr_as_from_socket_ptrs:
|
for link_ptr in _link_ptr_as_from_socket_ptrs:
|
||||||
log.info(
|
log.debug(
|
||||||
'Edge-Case - "from_socket" Reassigned in NodeLink w/o New NodeLink Pointer: %s',
|
'Edge-Case - "from_socket" Reassigned in NodeLink w/o New NodeLink Pointer: %s',
|
||||||
link_ptr,
|
link_ptr,
|
||||||
)
|
)
|
||||||
|
@ -159,21 +213,25 @@ class NodeLinkCache:
|
||||||
self.remove_sockets_by_link_ptr(link_ptr)
|
self.remove_sockets_by_link_ptr(link_ptr)
|
||||||
|
|
||||||
# Recompute Added/Removed Links
|
# Recompute Added/Removed Links
|
||||||
## The algorithm will now detect an "added link".
|
## -> Guide the usual algorithm to detect an "added link".
|
||||||
added_link_ptrs = all_link_ptrs - self.link_ptrs
|
added_link_ptrs = all_link_ptrs - self.link_ptrs
|
||||||
removed_link_ptrs = self.link_ptrs - all_link_ptrs
|
removed_link_ptrs = self.link_ptrs - all_link_ptrs
|
||||||
|
|
||||||
# Shuffle Cache based on Change in Links
|
# Delete Removed Links
|
||||||
## Remove Entries for Removed Pointers
|
## -> NOTE: We leave dangling socket information on purpose.
|
||||||
|
## -> This information will be used to ask for 'removal consent'.
|
||||||
|
## -> To truly remove, must call 'remove_socket_by_link_ptr' later.
|
||||||
for removed_link_ptr in removed_link_ptrs:
|
for removed_link_ptr in removed_link_ptrs:
|
||||||
self.remove_link(removed_link_ptr)
|
self.remove_link(removed_link_ptr)
|
||||||
## User must manually call 'remove_socket_by_link_ptr' later.
|
|
||||||
## For now, leave dangling socket information by-link.
|
|
||||||
|
|
||||||
# Add New Link Pointers
|
# Create Added Links
|
||||||
|
## -> First, simply concatenate the added link pointers.
|
||||||
self.link_ptrs |= added_link_ptrs
|
self.link_ptrs |= added_link_ptrs
|
||||||
for link_ptr in added_link_ptrs:
|
for link_ptr in added_link_ptrs:
|
||||||
# Add Link PTR->REF
|
# Create Pointer -> Reference Entry
|
||||||
|
## -> This allows us to efficiently access the link by-pointer.
|
||||||
|
## -> Doing so otherwise requires a full search.
|
||||||
|
## -> **If link is deleted w/o report, access will cause crash**.
|
||||||
new_link = all_link_ptrs_as_links[link_ptr]
|
new_link = all_link_ptrs_as_links[link_ptr]
|
||||||
self.link_ptrs_as_links[link_ptr] = new_link
|
self.link_ptrs_as_links[link_ptr] = new_link
|
||||||
|
|
||||||
|
@ -183,34 +241,69 @@ class NodeLinkCache:
|
||||||
to_socket = new_link.to_socket
|
to_socket = new_link.to_socket
|
||||||
to_socket_ptr = to_socket.as_pointer()
|
to_socket_ptr = to_socket.as_pointer()
|
||||||
|
|
||||||
# Add Socket PTR, PTR -> REF
|
# Add Socket Information
|
||||||
for socket_ptr, bl_socket in zip( # noqa: B905
|
for socket_ptr, bl_socket in zip( # noqa: B905
|
||||||
[from_socket_ptr, to_socket_ptr],
|
[from_socket_ptr, to_socket_ptr],
|
||||||
[from_socket, to_socket],
|
[from_socket, to_socket],
|
||||||
):
|
):
|
||||||
# Increment RefCount of Socket PTR
|
# RefCount > 0: Increment RefCount of Socket PTR
|
||||||
## This happens if another link also uses the same socket.
|
## This happens if another link also uses the same socket.
|
||||||
## 1. An output socket links to several inputs.
|
## 1. An output socket links to several inputs.
|
||||||
## 2. A multi-input socket links from several inputs.
|
## 2. A multi-input socket links from several inputs.
|
||||||
if socket_ptr in self.socket_ptr_refcount:
|
if socket_ptr in self.socket_ptr_refcount:
|
||||||
self.socket_ptr_refcount[socket_ptr] += 1
|
self.socket_ptr_refcount[socket_ptr] += 1
|
||||||
|
|
||||||
|
# RefCount == 0: Create Socket Pointer w/Reference
|
||||||
|
## -> Also initialize the refcount for the socket pointer.
|
||||||
else:
|
else:
|
||||||
## RefCount == 0: Add PTR, PTR -> REF
|
|
||||||
self.socket_ptrs.add(socket_ptr)
|
self.socket_ptrs.add(socket_ptr)
|
||||||
self.socket_ptrs_as_sockets[socket_ptr] = bl_socket
|
self.socket_ptrs_as_sockets[socket_ptr] = bl_socket
|
||||||
self.socket_ptr_refcount[socket_ptr] = 1
|
self.socket_ptr_refcount[socket_ptr] = 1
|
||||||
|
|
||||||
# Add Link PTR -> Socket PTR
|
# Add Entry from Link Pointer -> Socket Pointer
|
||||||
self.link_ptrs_as_from_socket_ptrs[link_ptr] = from_socket_ptr
|
self.link_ptrs_as_from_socket_ptrs[link_ptr] = from_socket_ptr
|
||||||
self.link_ptrs_as_to_socket_ptrs[link_ptr] = to_socket_ptr
|
self.link_ptrs_as_to_socket_ptrs[link_ptr] = to_socket_ptr
|
||||||
|
|
||||||
return {'added': added_link_ptrs, 'removed': removed_link_ptrs}
|
return {'added': added_link_ptrs, 'removed': removed_link_ptrs}
|
||||||
|
|
||||||
|
def update_validity(self) -> DeltaNodeLinkCache:
|
||||||
|
"""Query all cached links to determine whether they are valid."""
|
||||||
|
self.link_ptrs_invalid = {
|
||||||
|
link_ptr for link_ptr, link in self.link_ptrs_as_links if not link.is_valid
|
||||||
|
}
|
||||||
|
|
||||||
|
def report_validity(self, link_ptr: MemAddr, validity: bool) -> None:
|
||||||
|
"""Report a link as invalid."""
|
||||||
|
if validity and link_ptr in self.link_ptrs_invalid:
|
||||||
|
self.link_ptrs_invalid.remove(link_ptr)
|
||||||
|
elif not validity and link_ptr not in self.link_ptrs_invalid:
|
||||||
|
self.link_ptrs_invalid.add(link_ptr)
|
||||||
|
|
||||||
|
def set_validities(self) -> None:
|
||||||
|
"""Set the validity of links in the node tree according to the internal cache.
|
||||||
|
|
||||||
|
Validity doesn't need to be removed, as update() automatically cleans up by default.
|
||||||
|
"""
|
||||||
|
for link in [
|
||||||
|
link
|
||||||
|
for link_ptr, link in self.link_ptrs_as_links.items()
|
||||||
|
if link_ptr in self.link_ptrs_invalid
|
||||||
|
]:
|
||||||
|
if link.is_valid:
|
||||||
|
link.is_valid = False
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Node Tree Definition
|
# - Node Tree Definition
|
||||||
####################
|
####################
|
||||||
class MaxwellSimTree(bpy.types.NodeTree):
|
class MaxwellSimTree(bpy.types.NodeTree):
|
||||||
|
"""Node tree containing a node-based program for design and analysis of Maxwell PDE simulations.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
is_active: Whether the node tree should be considered to be in a usable state, capable of updating Blender data.
|
||||||
|
In general, only one `MaxwellSimTree` should be active at a time.
|
||||||
|
"""
|
||||||
|
|
||||||
bl_idname = ct.TreeType.MaxwellSim.value
|
bl_idname = ct.TreeType.MaxwellSim.value
|
||||||
bl_label = 'Maxwell Sim Editor'
|
bl_label = 'Maxwell Sim Editor'
|
||||||
bl_icon = ct.Icon.SimNodeEditor
|
bl_icon = ct.Icon.SimNodeEditor
|
||||||
|
@ -219,63 +312,6 @@ class MaxwellSimTree(bpy.types.NodeTree):
|
||||||
default=True,
|
default=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
####################
|
|
||||||
# - Lock Methods
|
|
||||||
####################
|
|
||||||
def unlock_all(self) -> None:
|
|
||||||
"""Unlock all nodes in the node tree, making them editable."""
|
|
||||||
log.info('Unlocking All Nodes in NodeTree "%s"', self.bl_label)
|
|
||||||
for node in self.nodes:
|
|
||||||
if node.type in ['REROUTE', 'FRAME']:
|
|
||||||
continue
|
|
||||||
node.locked = False
|
|
||||||
for bl_socket in [*node.inputs, *node.outputs]:
|
|
||||||
bl_socket.locked = False
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def replot(self) -> None:
|
|
||||||
self.is_currently_replotting = True
|
|
||||||
self.something_plotted = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
self.is_currently_replotting = False
|
|
||||||
if not self.something_plotted:
|
|
||||||
ManagedBLImage.hide_preview()
|
|
||||||
|
|
||||||
def report_show_plot(self, node: bpy.types.Node) -> None:
|
|
||||||
if hasattr(self, 'is_currently_replotting') and self.is_currently_replotting:
|
|
||||||
self.something_plotted = True
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def repreview_all(self) -> None:
|
|
||||||
all_nodes_with_preview_active = {
|
|
||||||
node.instance_id: node
|
|
||||||
for node in self.nodes
|
|
||||||
if node.type not in ['REROUTE', 'FRAME'] and node.preview_active
|
|
||||||
}
|
|
||||||
self.is_currently_repreviewing = True
|
|
||||||
self.newly_previewed_nodes = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
self.is_currently_repreviewing = False
|
|
||||||
for dangling_previewed_node in [
|
|
||||||
node
|
|
||||||
for node_instance_id, node in all_nodes_with_preview_active.items()
|
|
||||||
if node_instance_id not in self.newly_previewed_nodes
|
|
||||||
]:
|
|
||||||
dangling_previewed_node.preview_active = False
|
|
||||||
|
|
||||||
def report_show_preview(self, node: bpy.types.Node) -> None:
|
|
||||||
if (
|
|
||||||
hasattr(self, 'is_currently_repreviewing')
|
|
||||||
and self.is_currently_repreviewing
|
|
||||||
):
|
|
||||||
self.newly_previewed_nodes[node.instance_id] = node
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Init Methods
|
# - Init Methods
|
||||||
####################
|
####################
|
||||||
|
@ -290,7 +326,54 @@ class MaxwellSimTree(bpy.types.NodeTree):
|
||||||
self.node_link_cache = NodeLinkCache(self)
|
self.node_link_cache = NodeLinkCache(self)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Update Methods
|
# - Lock Methods
|
||||||
|
####################
|
||||||
|
def unlock_all(self) -> None:
|
||||||
|
"""Unlock all nodes in the node tree, making them editable.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
All `MaxwellSimNode`s have a `.locked` attribute, which prevents the entire UI from being modified.
|
||||||
|
|
||||||
|
This method simply sets the `locked` attribute to `False` on all nodes.
|
||||||
|
"""
|
||||||
|
log.info('Unlocking All Nodes in NodeTree "%s"', self.bl_label)
|
||||||
|
for node in self.nodes:
|
||||||
|
if node.type in ['REROUTE', 'FRAME']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Unlock Node
|
||||||
|
if node.locked:
|
||||||
|
node.locked = False
|
||||||
|
|
||||||
|
# Unlock Node Sockets
|
||||||
|
for bl_socket in [*node.inputs, *node.outputs]:
|
||||||
|
if bl_socket.locked:
|
||||||
|
bl_socket.locked = False
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Link Update Methods
|
||||||
|
####################
|
||||||
|
def report_link_validity(self, link: bpy.types.NodeLink, validity: bool) -> None:
|
||||||
|
"""Report that a particular `NodeLink` should be considered to be either valid or invalid.
|
||||||
|
|
||||||
|
The `NodeLink.is_valid` attribute is generally (and automatically) used to indicate the detection of cycles in the node tree.
|
||||||
|
However, visually, it causes a very clear "error red" highlight to appear on the node link, which can extremely useful when determining the reasons behind unexpected outout.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Run by `MaxwellSimSocket` when a link should be shown to be "invalid".
|
||||||
|
"""
|
||||||
|
## TODO: Doesn't quite work.
|
||||||
|
# log.debug(
|
||||||
|
# 'Reported Link Validity %s (is_valid=%s, from_socket=%s, to_socket=%s)',
|
||||||
|
# validity,
|
||||||
|
# link.is_valid,
|
||||||
|
# link.from_socket,
|
||||||
|
# link.to_socket,
|
||||||
|
# )
|
||||||
|
# self.node_link_cache.report_validity(link.as_pointer(), validity)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Node Update Methods
|
||||||
####################
|
####################
|
||||||
def on_node_removed(self, node: bpy.types.Node):
|
def on_node_removed(self, node: bpy.types.Node):
|
||||||
"""Run by `MaxwellSimNode.free()` when a node is being removed.
|
"""Run by `MaxwellSimNode.free()` when a node is being removed.
|
||||||
|
@ -327,32 +410,36 @@ class MaxwellSimTree(bpy.types.NodeTree):
|
||||||
self.node_link_cache.remove_link(link_ptr)
|
self.node_link_cache.remove_link(link_ptr)
|
||||||
self.node_link_cache.remove_sockets_by_link_ptr(link_ptr)
|
self.node_link_cache.remove_sockets_by_link_ptr(link_ptr)
|
||||||
|
|
||||||
def update(self) -> None:
|
def update(self) -> None: # noqa: PLR0912, C901
|
||||||
"""Monitors all changes to the node tree, potentially responding with appropriate callbacks.
|
"""Monitors all changes to the node tree, potentially responding with appropriate callbacks.
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
- Run by Blender when "anything" changes in the node tree.
|
- Run by Blender when "anything" changes in the node tree.
|
||||||
- Responds to node link changes with callbacks, with the help of a performant node link cache.
|
- Responds to node link changes with callbacks, with the help of a performant node link cache.
|
||||||
"""
|
"""
|
||||||
|
# Perform Initial Load
|
||||||
|
## -> Presume update() is run before the first link is altered.
|
||||||
|
## -> Else, the first link of the session will not update caches.
|
||||||
|
## -> We still remain slightly unsure of the exact semantics.
|
||||||
|
## -> Therefore, self.on_load() is also called as a load_post handler.
|
||||||
|
if not hasattr(self, 'node_link_cache'):
|
||||||
|
self.on_load()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Register Validity Updater
|
||||||
|
## -> They will be run after the update() method.
|
||||||
|
## -> Between update() and set_validities, all is_valid=True are cleared.
|
||||||
|
## -> Therefore, 'set_validities' only needs to set all is_valid=False.
|
||||||
|
bpy.app.timers.register(self.node_link_cache.set_validities)
|
||||||
|
|
||||||
|
# Ignore Updates
|
||||||
|
## -> Certain corrective processes require suppressing the next update.
|
||||||
|
## -> Otherwise, link corrections may trigger some nasty recursions.
|
||||||
if not hasattr(self, 'ignore_update'):
|
if not hasattr(self, 'ignore_update'):
|
||||||
self.ignore_update = False
|
self.ignore_update = False
|
||||||
|
|
||||||
if not hasattr(self, 'node_link_cache'):
|
# Regenerate NodeLinkCache
|
||||||
self.on_load()
|
|
||||||
## We presume update() is run before the first link is altered.
|
|
||||||
## - Else, the first link of the session will not update caches.
|
|
||||||
## - We remain slightly unsure of the semantics.
|
|
||||||
## - Therefore, self.on_load() is also called as a load_post handler.
|
|
||||||
return
|
|
||||||
|
|
||||||
# Ignore Update
|
|
||||||
## Manually set to implement link corrections w/o recursion.
|
|
||||||
if self.ignore_update:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Compute Changes to Node Links
|
|
||||||
delta_links = self.node_link_cache.regenerate()
|
delta_links = self.node_link_cache.regenerate()
|
||||||
|
|
||||||
link_corrections = {
|
link_corrections = {
|
||||||
'to_remove': [],
|
'to_remove': [],
|
||||||
'to_add': [],
|
'to_add': [],
|
||||||
|
|
|
@ -358,6 +358,11 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
## -> Those string labels explain the integer as ex. Ex, Ey, Hy.
|
## -> Those string labels explain the integer as ex. Ex, Ey, Hy.
|
||||||
idx_labels = valid_monitor_attrs(sim_data, monitor_name)
|
idx_labels = valid_monitor_attrs(sim_data, monitor_name)
|
||||||
|
|
||||||
|
# Extract Info
|
||||||
|
## -> We only need the output symbol.
|
||||||
|
## -> All labelled outputs have the same output SimSymbol.
|
||||||
|
info = extract_info(monitor_data, idx_labels[0])
|
||||||
|
|
||||||
# Generate FuncFlow Per Index Label
|
# Generate FuncFlow Per Index Label
|
||||||
## -> We extract each XArray as an attribute of monitor_data.
|
## -> We extract each XArray as an attribute of monitor_data.
|
||||||
## -> We then bind its values into a unique func_flow.
|
## -> We then bind its values into a unique func_flow.
|
||||||
|
@ -377,7 +382,8 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
## -> Then, 'compose_within' lets us stack them along axis=0.
|
## -> Then, 'compose_within' lets us stack them along axis=0.
|
||||||
## -> The "new" axis=0 is int-indexed axis w/idx_labels labels!
|
## -> The "new" axis=0 is int-indexed axis w/idx_labels labels!
|
||||||
return functools.reduce(lambda a, b: a | b, func_flows).compose_within(
|
return functools.reduce(lambda a, b: a | b, func_flows).compose_within(
|
||||||
enclosing_func=lambda data: jnp.stack(data, axis=0)
|
lambda data: jnp.stack(data, axis=0),
|
||||||
|
func_output=info.output,
|
||||||
)
|
)
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
|
@ -65,12 +65,12 @@ class FilterOperation(enum.StrEnum):
|
||||||
FO = FilterOperation
|
FO = FilterOperation
|
||||||
return {
|
return {
|
||||||
# Slice
|
# Slice
|
||||||
FO.Slice: '=a[i:j]',
|
FO.Slice: '≈a[v₁:v₂]',
|
||||||
FO.SliceIdx: '≈a[v₁:v₂]',
|
FO.SliceIdx: '=a[i:j]',
|
||||||
# Pin
|
# Pin
|
||||||
FO.PinLen1: 'pinₐ',
|
FO.PinLen1: 'a[0] → a',
|
||||||
FO.Pin: 'pinₐ ≈v',
|
FO.Pin: 'a[v] ⇝ a',
|
||||||
FO.PinIdx: 'pinₐ =i',
|
FO.PinIdx: 'a[i] → a',
|
||||||
# Reinterpret
|
# Reinterpret
|
||||||
FO.Swap: 'a₁ ↔ a₂',
|
FO.Swap: 'a₁ ↔ a₂',
|
||||||
}[value]
|
}[value]
|
||||||
|
@ -517,6 +517,7 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
return lazy_func.compose_within(
|
return lazy_func.compose_within(
|
||||||
operation.jax_func(axis_0, axis_1, slice_tuple=slice_tuple),
|
operation.jax_func(axis_0, axis_1, slice_tuple=slice_tuple),
|
||||||
enclosing_func_args=operation.func_args,
|
enclosing_func_args=operation.func_args,
|
||||||
|
enclosing_func_output=info.output,
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
|
@ -547,22 +547,30 @@ class MapMathNode(base.MaxwellSimNode):
|
||||||
####################
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Expr',
|
'Expr',
|
||||||
|
# Loaded
|
||||||
kind=ct.FlowKind.Func,
|
kind=ct.FlowKind.Func,
|
||||||
props={'operation'},
|
props={'operation'},
|
||||||
input_sockets={'Expr'},
|
input_sockets={'Expr'},
|
||||||
input_socket_kinds={
|
input_socket_kinds={
|
||||||
'Expr': ct.FlowKind.Func,
|
'Expr': ct.FlowKind.Func,
|
||||||
},
|
},
|
||||||
|
output_sockets={'Expr'},
|
||||||
|
output_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||||
)
|
)
|
||||||
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
|
def compute_func(
|
||||||
operation = props['operation']
|
self, props, input_sockets, output_sockets
|
||||||
|
) -> ct.FuncFlow | ct.FlowSignal:
|
||||||
expr = input_sockets['Expr']
|
expr = input_sockets['Expr']
|
||||||
|
output_info = output_sockets['Expr']
|
||||||
|
|
||||||
has_expr = not ct.FlowSignal.check(expr)
|
has_expr = not ct.FlowSignal.check(expr)
|
||||||
|
has_output_info = not ct.FlowSignal.check(output_info)
|
||||||
|
|
||||||
|
operation = props['operation']
|
||||||
if has_expr and operation is not None:
|
if has_expr and operation is not None:
|
||||||
return expr.compose_within(
|
return expr.compose_within(
|
||||||
operation.jax_func,
|
operation.jax_func,
|
||||||
|
enclosing_func_output=output_info.output,
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
|
@ -146,7 +146,6 @@ class BinaryOperation(enum.StrEnum):
|
||||||
outl = info_l.output
|
outl = info_l.output
|
||||||
outr = info_r.output
|
outr = info_r.output
|
||||||
match (outl.shape_len, outr.shape_len):
|
match (outl.shape_len, outr.shape_len):
|
||||||
# match (ol.shape_len, info_r.output.shape_len):
|
|
||||||
# Number | *
|
# Number | *
|
||||||
## Number | Number
|
## Number | Number
|
||||||
case (0, 0):
|
case (0, 0):
|
||||||
|
@ -154,15 +153,25 @@ class BinaryOperation(enum.StrEnum):
|
||||||
BO.Add,
|
BO.Add,
|
||||||
BO.Sub,
|
BO.Sub,
|
||||||
BO.Mul,
|
BO.Mul,
|
||||||
BO.Div,
|
|
||||||
BO.Pow,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Check Non-Zero Right Hand Side
|
||||||
|
## -> Obviously, we can't ever divide by zero.
|
||||||
|
## -> Sympy's assumptions system must always guarantee rhs != 0.
|
||||||
|
## -> If it can't, then we simply don't expose division.
|
||||||
|
## -> The is_zero assumption must be provided elsewhere.
|
||||||
|
## -> NOTE: This may prevent some valid uses of division.
|
||||||
|
## -> Watch out for "division is missing" bugs.
|
||||||
|
if info_r.output.is_nonzero:
|
||||||
|
ops.append(BO.Div)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
info_l.output.physical_type == spux.PhysicalType.Length
|
info_l.output.physical_type == spux.PhysicalType.Length
|
||||||
and info_l.output.unit == info_r.output.unit
|
and info_l.output.unit == info_r.output.unit
|
||||||
):
|
):
|
||||||
ops += [BO.Atan2]
|
ops += [BO.Atan2]
|
||||||
return ops
|
|
||||||
|
return [*ops, BO.Pow]
|
||||||
|
|
||||||
## Number | Vector
|
## Number | Vector
|
||||||
case (0, 1):
|
case (0, 1):
|
||||||
|
@ -336,7 +345,13 @@ class BinaryOperation(enum.StrEnum):
|
||||||
# - InfoFlow Transform
|
# - InfoFlow Transform
|
||||||
####################
|
####################
|
||||||
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
|
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
|
||||||
"""Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression."""
|
"""Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression.
|
||||||
|
|
||||||
|
Warnings:
|
||||||
|
`self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r).
|
||||||
|
|
||||||
|
If not, bad things will happen.
|
||||||
|
"""
|
||||||
return info_l.operate_output(
|
return info_l.operate_output(
|
||||||
info_r,
|
info_r,
|
||||||
lambda a, b: self.sp_func([a, b]),
|
lambda a, b: self.sp_func([a, b]),
|
||||||
|
@ -479,29 +494,35 @@ class OperateMathNode(base.MaxwellSimNode):
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Expr',
|
'Expr',
|
||||||
kind=ct.FlowKind.Func,
|
kind=ct.FlowKind.Func,
|
||||||
|
# Loaded
|
||||||
props={'operation'},
|
props={'operation'},
|
||||||
input_sockets={'Expr L', 'Expr R'},
|
input_sockets={'Expr L', 'Expr R'},
|
||||||
input_socket_kinds={
|
input_socket_kinds={
|
||||||
'Expr L': ct.FlowKind.Func,
|
'Expr L': ct.FlowKind.Func,
|
||||||
'Expr R': ct.FlowKind.Func,
|
'Expr R': ct.FlowKind.Func,
|
||||||
},
|
},
|
||||||
|
output_sockets={'Expr'},
|
||||||
|
output_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||||
)
|
)
|
||||||
def compose_func(self, props: dict, input_sockets: dict):
|
def compute_func(self, props, input_sockets, output_sockets):
|
||||||
operation = props['operation']
|
operation = props['operation']
|
||||||
if operation is None:
|
if operation is None:
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
expr_l = input_sockets['Expr L']
|
expr_l = input_sockets['Expr L']
|
||||||
expr_r = input_sockets['Expr R']
|
expr_r = input_sockets['Expr R']
|
||||||
|
output_info = output_sockets['Expr']
|
||||||
|
|
||||||
has_expr_l = not ct.FlowSignal.check(expr_l)
|
has_expr_l = not ct.FlowSignal.check(expr_l)
|
||||||
has_expr_r = not ct.FlowSignal.check(expr_r)
|
has_expr_r = not ct.FlowSignal.check(expr_r)
|
||||||
|
has_output_info = not ct.FlowSignal.check(output_info)
|
||||||
|
|
||||||
# Compute Jax Function
|
# Compute Jax Function
|
||||||
## -> The operation enum directly provides the appropriate function.
|
## -> The operation enum directly provides the appropriate function.
|
||||||
if has_expr_l and has_expr_r:
|
if has_expr_l and has_expr_r and has_output_info:
|
||||||
return (expr_l | expr_r).compose_within(
|
return (expr_l | expr_r).compose_within(
|
||||||
enclosing_func=operation.jax_func,
|
operation.jax_func,
|
||||||
|
enclosing_func_output=output_info.output,
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
@ -520,6 +541,8 @@ class OperateMathNode(base.MaxwellSimNode):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
|
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
|
||||||
|
BO = BinaryOperation
|
||||||
|
|
||||||
operation = props['operation']
|
operation = props['operation']
|
||||||
info_l = input_sockets['Expr L']
|
info_l = input_sockets['Expr L']
|
||||||
info_r = input_sockets['Expr R']
|
info_r = input_sockets['Expr R']
|
||||||
|
@ -533,7 +556,7 @@ class OperateMathNode(base.MaxwellSimNode):
|
||||||
has_info_l
|
has_info_l
|
||||||
and has_info_r
|
and has_info_r
|
||||||
and operation is not None
|
and operation is not None
|
||||||
and operation in BinaryOperation.by_infos(info_l, info_r)
|
and operation in BO.by_infos(info_l, info_r)
|
||||||
):
|
):
|
||||||
return operation.transform_infos(info_l, info_r)
|
return operation.transform_infos(info_l, info_r)
|
||||||
|
|
||||||
|
|
|
@ -606,27 +606,32 @@ class TransformMathNode(base.MaxwellSimNode):
|
||||||
input_socket_kinds={
|
input_socket_kinds={
|
||||||
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info},
|
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info},
|
||||||
},
|
},
|
||||||
|
output_sockets={'Expr'},
|
||||||
|
output_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||||
)
|
)
|
||||||
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
|
def compute_func(
|
||||||
|
self, props, input_sockets, output_sockets
|
||||||
|
) -> ct.FuncFlow | ct.FlowSignal:
|
||||||
"""Transform the input `InfoFlow` depending on the transform operation."""
|
"""Transform the input `InfoFlow` depending on the transform operation."""
|
||||||
TO = TransformOperation
|
TO = TransformOperation
|
||||||
operation = props['operation']
|
|
||||||
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
|
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
|
||||||
info = input_sockets['Expr'][ct.FlowKind.Info]
|
info = input_sockets['Expr'][ct.FlowKind.Info]
|
||||||
|
output_info = output_sockets['Expr']
|
||||||
|
|
||||||
has_info = not ct.FlowSignal.check(info)
|
has_info = not ct.FlowSignal.check(info)
|
||||||
has_lazy_func = not ct.FlowSignal.check(lazy_func)
|
has_lazy_func = not ct.FlowSignal.check(lazy_func)
|
||||||
|
has_output_info = not ct.FlowSignal.check(output_info)
|
||||||
|
|
||||||
if operation is not None and has_lazy_func and has_info:
|
operation = props['operation']
|
||||||
# Retrieve Properties
|
if operation is not None and has_lazy_func and has_info and has_output_info:
|
||||||
dim = props['dim']
|
dim = props['dim']
|
||||||
|
|
||||||
# Match Pattern by Operation
|
|
||||||
match operation:
|
match operation:
|
||||||
case TO.FreqToVacWL | TO.VacWLToFreq | TO.FT1D | TO.InvFT1D:
|
case TO.FreqToVacWL | TO.VacWLToFreq | TO.FT1D | TO.InvFT1D:
|
||||||
if dim is not None and info.has_idx_discrete(dim):
|
if dim is not None and info.has_idx_discrete(dim):
|
||||||
return lazy_func.compose_within(
|
return lazy_func.compose_within(
|
||||||
operation.jax_func(axis=info.dim_axis(dim)),
|
operation.jax_func(axis=info.dim_axis(dim)),
|
||||||
|
enclosing_func_output=output_info.output,
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
@ -634,6 +639,7 @@ class TransformMathNode(base.MaxwellSimNode):
|
||||||
case _:
|
case _:
|
||||||
return lazy_func.compose_within(
|
return lazy_func.compose_within(
|
||||||
operation.jax_func(),
|
operation.jax_func(),
|
||||||
|
enclosing_func_output=output_info.output,
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -406,7 +406,7 @@ class VizNode(base.MaxwellSimNode):
|
||||||
},
|
},
|
||||||
all_loose_input_sockets=True,
|
all_loose_input_sockets=True,
|
||||||
)
|
)
|
||||||
def compute_dummy_value(self, props, input_sockets, loose_input_sockets):
|
def compute_previews(self, props, input_sockets, loose_input_sockets):
|
||||||
"""Needed for the plot to regenerate in the viewer."""
|
"""Needed for the plot to regenerate in the viewer."""
|
||||||
return ct.PreviewsFlow(bl_image_name=props['sim_node_name'])
|
return ct.PreviewsFlow(bl_image_name=props['sim_node_name'])
|
||||||
|
|
||||||
|
@ -433,7 +433,7 @@ class VizNode(base.MaxwellSimNode):
|
||||||
def on_show_plot(
|
def on_show_plot(
|
||||||
self, managed_objs, props, input_sockets, loose_input_sockets
|
self, managed_objs, props, input_sockets, loose_input_sockets
|
||||||
) -> None:
|
) -> None:
|
||||||
log.critical('Show Plot (too many times)')
|
log.debug('Show Plot')
|
||||||
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
|
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
|
||||||
info = input_sockets['Expr'][ct.FlowKind.Info]
|
info = input_sockets['Expr'][ct.FlowKind.Info]
|
||||||
params = input_sockets['Expr'][ct.FlowKind.Params]
|
params = input_sockets['Expr'][ct.FlowKind.Params]
|
||||||
|
@ -456,6 +456,7 @@ class VizNode(base.MaxwellSimNode):
|
||||||
sym: loose_input_sockets[sym.name] for sym in params.sorted_symbols
|
sym: loose_input_sockets[sym.name] for sym in params.sorted_symbols
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
## TODO: CACHE entries that don't change, PLEASEEE
|
||||||
|
|
||||||
# Match Viz Type & Perform Visualization
|
# Match Viz Type & Perform Visualization
|
||||||
## -> Viz Target determines how to plot.
|
## -> Viz Target determines how to plot.
|
||||||
|
|
|
@ -207,12 +207,12 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
||||||
stop_propagation=True,
|
stop_propagation=True,
|
||||||
)
|
)
|
||||||
def _on_sim_node_name_changed(self, props):
|
def _on_sim_node_name_changed(self, props):
|
||||||
log.debug(
|
# log.debug(
|
||||||
'Changed Sim Node Name of a "%s" to "%s" (self=%s)',
|
# 'Changed Sim Node Name of a "%s" to "%s" (self=%s)',
|
||||||
self.bl_idname,
|
# self.bl_idname,
|
||||||
props['sim_node_name'],
|
# props['sim_node_name'],
|
||||||
str(self),
|
# str(self),
|
||||||
)
|
# )
|
||||||
|
|
||||||
# (Re)Construct Managed Objects
|
# (Re)Construct Managed Objects
|
||||||
## -> Due to 'prev_name', the new MObjs will be renamed on construction
|
## -> Due to 'prev_name', the new MObjs will be renamed on construction
|
||||||
|
@ -360,27 +360,48 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
||||||
####################
|
####################
|
||||||
# - Socket Management
|
# - Socket Management
|
||||||
####################
|
####################
|
||||||
## TODO: Check for namespace collisions in sockets to prevent silent errors
|
|
||||||
def _prune_inactive_sockets(self):
|
def _prune_inactive_sockets(self):
|
||||||
"""Remove all "inactive" sockets from the node.
|
"""Remove all inactive sockets from the node, while only updating sockets that can be non-destructively updated.
|
||||||
|
|
||||||
A socket is considered "inactive" when it shouldn't be defined (per `self.active_socket_defs), but is present nonetheless.
|
The first step is easy: We determine, by-name, which sockets should no longer be defined, then remove them correctly.
|
||||||
|
|
||||||
|
The second step is harder: When new sockets have overlapping names, should they be removed, or should they merely have some properties updated?
|
||||||
|
Removing and re-adding the same socket is an accurate, generally robust approach, but it comes with a big caveat: **Existing node links will be cut**, even when it might semantically make sense to simply alter the socket's properties, keeping the links.
|
||||||
|
|
||||||
|
Different `bl_socket.socket_type`s can never be updated - they must be removed.
|
||||||
|
Otherwise, `SocketDef.compare(bl_socket)` allows us to granularly determine whether a particular `bl_socket` has changed with respect to the desired specification.
|
||||||
|
When the comparison is `False`, we can carefully utilize `SocketDef.init()` to re-initialize the socket, guaranteeing that the altered socket is up to the new specification.
|
||||||
"""
|
"""
|
||||||
node_tree = self.id_data
|
node_tree = self.id_data
|
||||||
for direc in ['input', 'output']:
|
for direc in ['input', 'output']:
|
||||||
all_bl_sockets = self._bl_sockets(direc)
|
bl_sockets = self._bl_sockets(direc)
|
||||||
active_bl_socket_defs = self.active_socket_defs(direc)
|
active_socket_defs = self.active_socket_defs(direc)
|
||||||
|
|
||||||
# Determine Sockets to Remove
|
# Determine Sockets to Remove
|
||||||
|
## -> Name: If the existing socket name isn't "active".
|
||||||
|
## -> Type: If the existing socket_type != "active" SocketDef.
|
||||||
bl_sockets_to_remove = [
|
bl_sockets_to_remove = [
|
||||||
bl_socket
|
bl_socket
|
||||||
for socket_name, bl_socket in all_bl_sockets.items()
|
for socket_name, bl_socket in bl_sockets.items()
|
||||||
if socket_name not in active_bl_socket_defs
|
if (
|
||||||
or socket_name
|
socket_name not in active_socket_defs
|
||||||
in (
|
or bl_socket.socket_type
|
||||||
self.loose_input_sockets
|
is not active_socket_defs[socket_name].socket_type
|
||||||
if direc == 'input'
|
)
|
||||||
else self.loose_output_sockets
|
]
|
||||||
|
|
||||||
|
# Determine Sockets to Update
|
||||||
|
## -> Name: If the existing socket name is "active".
|
||||||
|
## -> Type: If the existing socket_type == "active" SocketDef.
|
||||||
|
## -> Compare: If the existing socket differs from the SocketDef.
|
||||||
|
bl_sockets_to_update = [
|
||||||
|
bl_socket
|
||||||
|
for socket_name, bl_socket in bl_sockets.items()
|
||||||
|
if (
|
||||||
|
socket_name in active_socket_defs
|
||||||
|
and bl_socket.socket_type
|
||||||
|
is active_socket_defs[socket_name].socket_type
|
||||||
|
and not active_socket_defs[socket_name].compare(bl_socket)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -392,24 +413,25 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
||||||
## -> The NodeLinkCache needs to be adjusted manually.
|
## -> The NodeLinkCache needs to be adjusted manually.
|
||||||
node_tree.on_node_socket_removed(bl_socket)
|
node_tree.on_node_socket_removed(bl_socket)
|
||||||
|
|
||||||
# 2. Invalidate the input socket cache across all kinds.
|
# 2. Perform the removal using Blender's API.
|
||||||
|
## -> Actually removes the socket.
|
||||||
|
bl_sockets.remove(bl_socket)
|
||||||
|
|
||||||
|
# 3. Invalidate the input socket cache across all kinds.
|
||||||
## -> Prevents phantom values from remaining available.
|
## -> Prevents phantom values from remaining available.
|
||||||
|
## -> Done after socket removal to protect from race condition.
|
||||||
self._compute_input.invalidate(
|
self._compute_input.invalidate(
|
||||||
input_socket_name=bl_socket_name,
|
input_socket_name=bl_socket_name,
|
||||||
kind=...,
|
kind=...,
|
||||||
unit_system=...,
|
unit_system=...,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Perform the removal using Blender's API.
|
|
||||||
## -> Actually removes the socket.
|
|
||||||
all_bl_sockets.remove(bl_socket)
|
|
||||||
|
|
||||||
if direc == 'input':
|
if direc == 'input':
|
||||||
# 4. Run all trigger-only `on_value_changed` callbacks.
|
# 4. Run all trigger-only `on_value_changed` callbacks.
|
||||||
## -> Runs any event methods that relied on the socket.
|
## -> Runs any event methods that relied on the socket.
|
||||||
## -> Only methods that don't **require** the socket.
|
## -> Only methods that don't **require** the socket.
|
||||||
## Trigger-Only: If method loads no socket data, it runs.
|
## Only Trigger: If method loads no socket data, it runs.
|
||||||
## `optional`: If method optional-loads socket, it runs.
|
## Optional: If method optional-loads socket, it runs.
|
||||||
triggered_event_methods = [
|
triggered_event_methods = [
|
||||||
event_method
|
event_method
|
||||||
for event_method in self.filtered_event_methods_by_event(
|
for event_method in self.filtered_event_methods_by_event(
|
||||||
|
@ -419,32 +441,52 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
||||||
not in event_method.callback_info.must_load_sockets
|
not in event_method.callback_info.must_load_sockets
|
||||||
]
|
]
|
||||||
for event_method in triggered_event_methods:
|
for event_method in triggered_event_methods:
|
||||||
log.critical(
|
|
||||||
'%s: Running %s',
|
|
||||||
self.sim_node_name,
|
|
||||||
str(event_method),
|
|
||||||
)
|
|
||||||
event_method(self)
|
event_method(self)
|
||||||
|
|
||||||
|
# Update Sockets
|
||||||
|
for bl_socket in bl_sockets_to_update:
|
||||||
|
bl_socket_name = bl_socket.name
|
||||||
|
socket_def = active_socket_defs[bl_socket_name]
|
||||||
|
|
||||||
|
# 1. Pretend to Initialize for the First Time
|
||||||
|
## -> NOTE: The socket's caches will be completely regenerated.
|
||||||
|
## -> NOTE: A full FlowKind update will occur, but only one.
|
||||||
|
bl_socket.is_initializing = True
|
||||||
|
socket_def.preinit(bl_socket)
|
||||||
|
socket_def.init(bl_socket)
|
||||||
|
socket_def.postinit(bl_socket)
|
||||||
|
|
||||||
|
# 2. Re-Test Socket Capabilities
|
||||||
|
## -> Factors influencing CapabilitiesFlow may have changed.
|
||||||
|
## -> Therefore, we must re-test all link capabilities.
|
||||||
|
bl_socket.remove_invalidated_links()
|
||||||
|
|
||||||
|
# 3. Invalidate the input socket cache across all kinds.
|
||||||
|
## -> Prevents phantom values from remaining available.
|
||||||
|
self._compute_input.invalidate(
|
||||||
|
input_socket_name=bl_socket_name,
|
||||||
|
kind=...,
|
||||||
|
unit_system=...,
|
||||||
|
)
|
||||||
|
|
||||||
def _add_new_active_sockets(self):
|
def _add_new_active_sockets(self):
|
||||||
"""Add and initialize all "active" sockets that aren't on the node.
|
"""Add and initialize all "active" sockets that aren't on the node.
|
||||||
|
|
||||||
Existing sockets within the given direction are not re-created.
|
Existing sockets within the given direction are not re-created.
|
||||||
"""
|
"""
|
||||||
for direc in ['input', 'output']:
|
for direc in ['input', 'output']:
|
||||||
all_bl_sockets = self._bl_sockets(direc)
|
bl_sockets = self._bl_sockets(direc)
|
||||||
active_bl_socket_defs = self.active_socket_defs(direc)
|
active_socket_defs = self.active_socket_defs(direc)
|
||||||
|
|
||||||
# Define BL Sockets
|
# Define BL Sockets
|
||||||
created_sockets = {}
|
created_sockets = {}
|
||||||
for socket_name, socket_def in active_bl_socket_defs.items():
|
for socket_name, socket_def in active_socket_defs.items():
|
||||||
# Skip Existing Sockets
|
# Skip Existing Sockets
|
||||||
if socket_name in all_bl_sockets:
|
if socket_name in bl_sockets:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Create BL Socket from Socket
|
# Create BL Socket from Socket
|
||||||
## Set 'display_shape' from 'socket_shape'
|
bl_sockets.new(
|
||||||
all_bl_sockets.new(
|
|
||||||
str(socket_def.socket_type.value),
|
str(socket_def.socket_type.value),
|
||||||
socket_name,
|
socket_name,
|
||||||
)
|
)
|
||||||
|
@ -454,9 +496,9 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
||||||
|
|
||||||
# Initialize Just-Created BL Sockets
|
# Initialize Just-Created BL Sockets
|
||||||
for bl_socket_name, socket_def in created_sockets.items():
|
for bl_socket_name, socket_def in created_sockets.items():
|
||||||
socket_def.preinit(all_bl_sockets[bl_socket_name])
|
socket_def.preinit(bl_sockets[bl_socket_name])
|
||||||
socket_def.init(all_bl_sockets[bl_socket_name])
|
socket_def.init(bl_sockets[bl_socket_name])
|
||||||
socket_def.postinit(all_bl_sockets[bl_socket_name])
|
socket_def.postinit(bl_sockets[bl_socket_name])
|
||||||
|
|
||||||
# Invalidate Cached NoFlows
|
# Invalidate Cached NoFlows
|
||||||
self._compute_input.invalidate(
|
self._compute_input.invalidate(
|
||||||
|
@ -637,9 +679,10 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
||||||
lambda a, b: a | b,
|
lambda a, b: a | b,
|
||||||
[
|
[
|
||||||
self._compute_input(
|
self._compute_input(
|
||||||
socket, kind=ct.FlowKind.Previews, unit_system=None
|
socket_name,
|
||||||
|
kind=ct.FlowKind.Previews,
|
||||||
)
|
)
|
||||||
for socket in [bl_socket.name for bl_socket in self.inputs]
|
for socket_name in [bl_socket.name for bl_socket in self.inputs]
|
||||||
],
|
],
|
||||||
ct.PreviewsFlow(),
|
ct.PreviewsFlow(),
|
||||||
)
|
)
|
||||||
|
@ -897,9 +940,19 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
||||||
)
|
)
|
||||||
altered_socket_kinds[dep_out_sckname].add(dep_out_kind)
|
altered_socket_kinds[dep_out_sckname].add(dep_out_kind)
|
||||||
|
|
||||||
|
# Clear Output Socket Cache(s)
|
||||||
|
## -> We aggregate it manually, so it needs a special invl.
|
||||||
|
## -> See self.compute_output()
|
||||||
|
if socket_kinds is not None and ct.FlowKind.Previews in socket_kinds:
|
||||||
|
for out_sckname in self.outputs.keys(): # noqa: SIM118
|
||||||
|
self.compute_output.invalidate(
|
||||||
|
output_socket_name=out_sckname,
|
||||||
|
kind=ct.FlowKind.Previews,
|
||||||
|
)
|
||||||
|
altered_socket_kinds[out_sckname].add(ct.FlowKind.Previews)
|
||||||
|
|
||||||
# Run Triggered Event Methods
|
# Run Triggered Event Methods
|
||||||
## -> A triggered event method may request to stop propagation.
|
## -> A triggered event method may request to stop propagation.
|
||||||
## -> A triggered event method may request to stop propagation.
|
|
||||||
stop_propagation = False
|
stop_propagation = False
|
||||||
triggered_event_methods = self.filtered_event_methods_by_event(
|
triggered_event_methods = self.filtered_event_methods_by_event(
|
||||||
event, (socket_name, prop_names, None)
|
event, (socket_name, prop_names, None)
|
||||||
|
|
|
@ -266,7 +266,7 @@ def event_decorator( # noqa: PLR0913
|
||||||
)
|
)
|
||||||
|
|
||||||
# Loose Sockets
|
# Loose Sockets
|
||||||
## Compute All Loose Input Sockets
|
## -> Determined by the active_kind of each loose input socket.
|
||||||
method_kw_args |= (
|
method_kw_args |= (
|
||||||
{
|
{
|
||||||
'loose_input_sockets': {
|
'loose_input_sockets': {
|
||||||
|
|
|
@ -29,6 +29,8 @@ from ... import base, events
|
||||||
|
|
||||||
|
|
||||||
class ScientificConstantNode(base.MaxwellSimNode):
|
class ScientificConstantNode(base.MaxwellSimNode):
|
||||||
|
"""A well-known constant usable as itself, or as a symbol."""
|
||||||
|
|
||||||
node_type = ct.NodeType.ScientificConstant
|
node_type = ct.NodeType.ScientificConstant
|
||||||
bl_label = 'Scientific Constant'
|
bl_label = 'Scientific Constant'
|
||||||
|
|
||||||
|
@ -88,6 +90,11 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
||||||
####################
|
####################
|
||||||
# - UI
|
# - UI
|
||||||
####################
|
####################
|
||||||
|
def draw_label(self):
|
||||||
|
if self.sci_constant_str:
|
||||||
|
return f'Const: {self.sci_constant_str}'
|
||||||
|
return self.bl_label
|
||||||
|
|
||||||
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
||||||
col.prop(self, self.blfields['sci_constant_str'], text='')
|
col.prop(self, self.blfields['sci_constant_str'], text='')
|
||||||
|
|
||||||
|
@ -156,6 +163,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
||||||
props={'sci_constant', 'sci_constant_sym'},
|
props={'sci_constant', 'sci_constant_sym'},
|
||||||
)
|
)
|
||||||
def compute_lazy_func(self, props) -> typ.Any:
|
def compute_lazy_func(self, props) -> typ.Any:
|
||||||
|
"""Simple `FuncFlow` that computes the symbol value, with output units tracked correctly."""
|
||||||
sci_constant = props['sci_constant']
|
sci_constant = props['sci_constant']
|
||||||
sci_constant_sym = props['sci_constant_sym']
|
sci_constant_sym = props['sci_constant_sym']
|
||||||
|
|
||||||
|
@ -165,6 +173,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
||||||
[sci_constant_sym.sp_symbol], sci_constant_sym.sp_symbol, 'jax'
|
[sci_constant_sym.sp_symbol], sci_constant_sym.sp_symbol, 'jax'
|
||||||
),
|
),
|
||||||
func_args=[sci_constant_sym],
|
func_args=[sci_constant_sym],
|
||||||
|
func_output=sci_constant_sym,
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
@ -175,6 +184,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
||||||
props={'sci_constant_sym'},
|
props={'sci_constant_sym'},
|
||||||
)
|
)
|
||||||
def compute_info(self, props: dict) -> typ.Any:
|
def compute_info(self, props: dict) -> typ.Any:
|
||||||
|
"""Simple `FuncFlow` that computes the symbol value, with output units tracked correctly."""
|
||||||
sci_constant_sym = props['sci_constant_sym']
|
sci_constant_sym = props['sci_constant_sym']
|
||||||
|
|
||||||
if sci_constant_sym is not None:
|
if sci_constant_sym is not None:
|
||||||
|
@ -193,8 +203,12 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
||||||
if sci_constant is not None and sci_constant_sym is not None:
|
if sci_constant is not None and sci_constant_sym is not None:
|
||||||
return ct.ParamsFlow(
|
return ct.ParamsFlow(
|
||||||
arg_targets=[sci_constant_sym],
|
arg_targets=[sci_constant_sym],
|
||||||
func_args=[sci_constant],
|
func_args=[sci_constant_sym.sp_symbol],
|
||||||
is_differentiable=True,
|
symbols={sci_constant_sym},
|
||||||
|
).realize_partial(
|
||||||
|
{
|
||||||
|
sci_constant_sym: sci_constant,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
|
|
@ -216,10 +216,11 @@ class SymbolConstantNode(base.MaxwellSimNode):
|
||||||
props={'symbol'},
|
props={'symbol'},
|
||||||
)
|
)
|
||||||
def compute_lazy_func(self, props) -> typ.Any:
|
def compute_lazy_func(self, props) -> typ.Any:
|
||||||
sp_sym = props['symbol'].sp_symbol
|
sym = props['symbol']
|
||||||
return ct.FuncFlow(
|
return ct.FuncFlow(
|
||||||
func=sp.lambdify(sp_sym, sp_sym, 'jax'),
|
func=sp.lambdify(sym.sp_symbol_matsym, sym.sp_symbol_matsym, 'jax'),
|
||||||
func_args=[sp_sym],
|
func_args=[sym],
|
||||||
|
func_output=sym,
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -235,6 +236,7 @@ class SymbolConstantNode(base.MaxwellSimNode):
|
||||||
)
|
)
|
||||||
def compute_info(self, props) -> typ.Any:
|
def compute_info(self, props) -> typ.Any:
|
||||||
return ct.InfoFlow(
|
return ct.InfoFlow(
|
||||||
|
dims={props['symbol']: None},
|
||||||
output=props['symbol'],
|
output=props['symbol'],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -251,9 +253,6 @@ class SymbolConstantNode(base.MaxwellSimNode):
|
||||||
arg_targets=[sym],
|
arg_targets=[sym],
|
||||||
func_args=[sym.sp_symbol],
|
func_args=[sym.sp_symbol],
|
||||||
symbols={sym},
|
symbols={sym},
|
||||||
is_differentiable=(
|
|
||||||
sym.mathtype in [spux.MathType.Real, spux.MathType.Complex]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -198,9 +198,10 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
||||||
'Expr',
|
'Expr',
|
||||||
kind=ct.FlowKind.Func,
|
kind=ct.FlowKind.Func,
|
||||||
# Loaded
|
# Loaded
|
||||||
|
props={'output_name', 'output_mathtype', 'output_physical_type', 'output_unit'},
|
||||||
input_sockets={'File Path'},
|
input_sockets={'File Path'},
|
||||||
)
|
)
|
||||||
def compute_func(self, input_sockets) -> td.Simulation:
|
def compute_func(self, props, input_sockets) -> td.Simulation:
|
||||||
"""Declare a lazy, composable function that returns the loaded data.
|
"""Declare a lazy, composable function that returns the loaded data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -209,6 +210,12 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
||||||
file_path = input_sockets['File Path']
|
file_path = input_sockets['File Path']
|
||||||
has_file_path = not ct.FlowSignal.check(file_path)
|
has_file_path = not ct.FlowSignal.check(file_path)
|
||||||
|
|
||||||
|
func_output = sim_symbols.SimSymbol(
|
||||||
|
sym_name=props['output_name'],
|
||||||
|
mathtype=props['output_mathtype'],
|
||||||
|
physical_type=props['output_physical_type'],
|
||||||
|
unit=props['output_unit'],
|
||||||
|
)
|
||||||
if has_file_path and file_path is not None:
|
if has_file_path and file_path is not None:
|
||||||
data_file_format = ct.DataFileFormat.from_path(file_path)
|
data_file_format = ct.DataFileFormat.from_path(file_path)
|
||||||
if data_file_format is not None:
|
if data_file_format is not None:
|
||||||
|
@ -217,13 +224,18 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
||||||
if data_file_format.loader_is_jax_compatible:
|
if data_file_format.loader_is_jax_compatible:
|
||||||
return ct.FuncFlow(
|
return ct.FuncFlow(
|
||||||
func=lambda: data_file_format.loader(file_path),
|
func=lambda: data_file_format.loader(file_path),
|
||||||
|
func_output=func_output,
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# No Jax Compatibility: Eager Data Loading
|
# No Jax Compatibility: Eager Data Loading
|
||||||
## -> Load the data now and bind it.
|
## -> Load the data now and bind it.
|
||||||
data = data_file_format.loader(file_path)
|
data = data_file_format.loader(file_path)
|
||||||
return ct.FuncFlow(func=lambda: data, supports_jax=True)
|
return ct.FuncFlow(
|
||||||
|
func=lambda: data,
|
||||||
|
func_output=func_output,
|
||||||
|
supports_jax=True,
|
||||||
|
)
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
|
|
@ -86,24 +86,47 @@ class ViewerNode(base.MaxwellSimNode):
|
||||||
# - Properties: Computed FlowKinds
|
# - Properties: Computed FlowKinds
|
||||||
####################
|
####################
|
||||||
@events.on_value_changed(
|
@events.on_value_changed(
|
||||||
socket_name='Any',
|
# Trigger
|
||||||
|
prop_name='console_print_kind',
|
||||||
|
# Loaded
|
||||||
|
props={'auto_expr', 'console_print_kind'},
|
||||||
)
|
)
|
||||||
def on_input_changed(self) -> None:
|
def on_print_kind_changed(self, props) -> None:
|
||||||
|
self.inputs['Any'].active_kind = props['console_print_kind']
|
||||||
|
|
||||||
|
if props['auto_expr']:
|
||||||
|
setattr(
|
||||||
|
self,
|
||||||
|
'input_' + props['console_print_kind'].property_name,
|
||||||
|
bl_cache.Signal.InvalidateCache,
|
||||||
|
)
|
||||||
|
|
||||||
|
@events.on_value_changed(
|
||||||
|
# Trugger
|
||||||
|
socket_name='Any',
|
||||||
|
# Loaded
|
||||||
|
props={'auto_expr', 'console_print_kind'},
|
||||||
|
)
|
||||||
|
def on_input_changed(self, props) -> None:
|
||||||
"""Lightweight invalidator, which invalidates the more specific `cached_bl_property` used to determine when something ex. plot-related has changed.
|
"""Lightweight invalidator, which invalidates the more specific `cached_bl_property` used to determine when something ex. plot-related has changed.
|
||||||
|
|
||||||
Calls `get_flow`, which will be called again when regenerating the `cached_bl_property`s.
|
Calls `get_flow`, which will be called again when regenerating the `cached_bl_property`s.
|
||||||
This **does not** call the flow twice, as `self._compute_input()` will be cached the first time.
|
This **does not** call the flow twice, as `self._compute_input()` will be cached the first time.
|
||||||
"""
|
"""
|
||||||
for flow_kind in list(ct.FlowKind):
|
# Invalidate PreviewsFlow
|
||||||
flow = self.get_flow(
|
setattr(
|
||||||
flow_kind, always_load=flow_kind is ct.FlowKind.Previews
|
self,
|
||||||
|
'input_' + ct.FlowKind.Previews.property_name,
|
||||||
|
bl_cache.Signal.InvalidateCache,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Invalidate PreviewsFlow
|
||||||
|
if props['auto_expr']:
|
||||||
|
setattr(
|
||||||
|
self,
|
||||||
|
'input_' + props['console_print_kind'].property_name,
|
||||||
|
bl_cache.Signal.InvalidateCache,
|
||||||
)
|
)
|
||||||
if flow is not None:
|
|
||||||
setattr(
|
|
||||||
self,
|
|
||||||
'input_' + flow_kind.property_name,
|
|
||||||
bl_cache.Signal.InvalidateCache,
|
|
||||||
)
|
|
||||||
|
|
||||||
@bl_cache.cached_bl_property(depends_on={'auto_expr'})
|
@bl_cache.cached_bl_property(depends_on={'auto_expr'})
|
||||||
def input_capabilities(self) -> ct.CapabilitiesFlow | None:
|
def input_capabilities(self) -> ct.CapabilitiesFlow | None:
|
||||||
|
|
|
@ -14,8 +14,10 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import functools
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
|
import bpy
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
|
|
||||||
from blender_maxwell.utils import bl_cache
|
from blender_maxwell.utils import bl_cache
|
||||||
|
@ -26,6 +28,8 @@ from .. import base, events
|
||||||
|
|
||||||
|
|
||||||
class CombineNode(base.MaxwellSimNode):
|
class CombineNode(base.MaxwellSimNode):
|
||||||
|
"""Combine single objects (ex. Source, Monitor, Structure) into a list."""
|
||||||
|
|
||||||
node_type = ct.NodeType.Combine
|
node_type = ct.NodeType.Combine
|
||||||
bl_label = 'Combine'
|
bl_label = 'Combine'
|
||||||
|
|
||||||
|
@ -33,112 +37,222 @@ class CombineNode(base.MaxwellSimNode):
|
||||||
# - Sockets
|
# - Sockets
|
||||||
####################
|
####################
|
||||||
input_socket_sets: typ.ClassVar = {
|
input_socket_sets: typ.ClassVar = {
|
||||||
'Maxwell Sources': {},
|
'Sources': {},
|
||||||
'Maxwell Structures': {},
|
'Structures': {},
|
||||||
'Maxwell Monitors': {},
|
'Monitors': {},
|
||||||
}
|
}
|
||||||
output_socket_sets: typ.ClassVar = {
|
output_socket_sets: typ.ClassVar = {
|
||||||
'Maxwell Sources': {
|
'Sources': {
|
||||||
'Sources': sockets.MaxwellSourceSocketDef(
|
'Sources': sockets.MaxwellSourceSocketDef(
|
||||||
is_list=True,
|
active_kind=ct.FlowKind.Array,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
'Maxwell Structures': {
|
'Structures': {
|
||||||
'Structures': sockets.MaxwellStructureSocketDef(
|
'Structures': sockets.MaxwellStructureSocketDef(
|
||||||
is_list=True,
|
active_kind=ct.FlowKind.Array,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
'Maxwell Monitors': {
|
'Monitors': {
|
||||||
'Monitors': sockets.MaxwellMonitorSocketDef(
|
'Monitors': sockets.MaxwellMonitorSocketDef(
|
||||||
is_list=True,
|
active_kind=ct.FlowKind.Array,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Draw
|
# - Properties
|
||||||
####################
|
####################
|
||||||
amount: int = bl_cache.BLField(2, abs_min=1, prop_ui=True)
|
concatenate_first: bool = bl_cache.BLField(False)
|
||||||
|
value_or_func: ct.FlowKind = bl_cache.BLField(
|
||||||
|
enum_cb=lambda self, _: self._value_or_func(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _value_or_func(self):
|
||||||
|
return [
|
||||||
|
flow_kind.bl_enum_element(i)
|
||||||
|
for i, flow_kind in enumerate([ct.FlowKind.Value, ct.FlowKind.Func])
|
||||||
|
]
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Draw
|
# - Draw
|
||||||
####################
|
####################
|
||||||
def draw_props(self, context, layout):
|
def draw_props(self, _, layout: bpy.types.UILayout):
|
||||||
layout.prop(self, self.blfields['amount'], text='')
|
layout.prop(self, self.blfields['value_or_func'], text='')
|
||||||
|
|
||||||
|
if self.value_or_func is ct.FlowKind.Value:
|
||||||
|
layout.prop(
|
||||||
|
self,
|
||||||
|
self.blfields['concatenate_first'],
|
||||||
|
text='Concatenate',
|
||||||
|
toggle=True,
|
||||||
|
)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Events
|
# - Events
|
||||||
####################
|
####################
|
||||||
@events.on_value_changed(
|
@events.on_value_changed(
|
||||||
# Trigger
|
any_loose_input_socket=True,
|
||||||
prop_name={'active_socket_set', 'amount'},
|
prop_name={'active_socket_set', 'concatenate_first', 'value_or_func'},
|
||||||
props={'active_socket_set', 'amount'},
|
|
||||||
run_on_init=True,
|
run_on_init=True,
|
||||||
|
# Loaded
|
||||||
|
props={'active_socket_set', 'concatenate_first', 'value_or_func'},
|
||||||
)
|
)
|
||||||
def on_inputs_changed(self, props):
|
def on_inputs_changed(self, props) -> None:
|
||||||
if props['active_socket_set'] == 'Maxwell Sources':
|
"""Always create one extra loose input socket."""
|
||||||
if (
|
active_socket_set = props['active_socket_set']
|
||||||
not self.loose_input_sockets
|
|
||||||
or not next(iter(self.loose_input_sockets)).startswith('Source')
|
|
||||||
or len(self.loose_input_sockets) != props['amount']
|
|
||||||
):
|
|
||||||
self.loose_input_sockets = {
|
|
||||||
f'Source #{i}': sockets.MaxwellSourceSocketDef()
|
|
||||||
for i in range(props['amount'])
|
|
||||||
}
|
|
||||||
|
|
||||||
elif props['active_socket_set'] == 'Maxwell Structures':
|
# Deduce SocketDef
|
||||||
if (
|
## -> Cheat by retrieving the class from the output sockets.
|
||||||
not self.loose_input_sockets
|
SocketDef = self.output_socket_sets[active_socket_set][
|
||||||
or not next(iter(self.loose_input_sockets)).startswith('Structure')
|
active_socket_set
|
||||||
or len(self.loose_input_sockets) != props['amount']
|
].__class__
|
||||||
):
|
|
||||||
self.loose_input_sockets = {
|
# Deduce Current "Filled"
|
||||||
f'Structure #{i}': sockets.MaxwellStructureSocketDef()
|
## -> The first linked socket from the end bounds the "filled" region.
|
||||||
for i in range(props['amount'])
|
## -> The length of that region, plus one, will be the new amount.
|
||||||
}
|
reverse_linked_idxs = [
|
||||||
elif props['active_socket_set'] == 'Maxwell Monitors':
|
i
|
||||||
if (
|
for i, bl_socket in enumerate(reversed(self.inputs.values()))
|
||||||
not self.loose_input_sockets
|
if bl_socket.is_linked
|
||||||
or not next(iter(self.loose_input_sockets)).startswith('Monitor')
|
]
|
||||||
or len(self.loose_input_sockets) != props['amount']
|
current_filled = len(self.inputs) - (
|
||||||
):
|
reverse_linked_idxs[0] if reverse_linked_idxs else len(self.inputs)
|
||||||
self.loose_input_sockets = {
|
)
|
||||||
f'Monitor #{i}': sockets.MaxwellMonitorSocketDef()
|
new_amount = current_filled + 1
|
||||||
for i in range(props['amount'])
|
|
||||||
}
|
# Deduce SocketDef | Current Amount
|
||||||
elif self.loose_input_sockets:
|
concatenate_first = props['concatenate_first']
|
||||||
self.loose_input_sockets = {}
|
flow_kind = props['value_or_func']
|
||||||
|
|
||||||
|
self.loose_input_sockets = {
|
||||||
|
'#0': SocketDef(
|
||||||
|
active_kind=flow_kind
|
||||||
|
if flow_kind is ct.FlowKind.Func or not concatenate_first
|
||||||
|
else ct.FlowKind.Array
|
||||||
|
)
|
||||||
|
} | {f'#{i}': SocketDef(active_kind=flow_kind) for i in range(1, new_amount)}
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Output Socket Computation
|
# - FlowKind.Array|Func
|
||||||
|
####################
|
||||||
|
def compute_combined(
|
||||||
|
self,
|
||||||
|
loose_input_sockets,
|
||||||
|
input_flow_kind: typ.Literal[ct.FlowKind.Value, ct.FlowKind.Func],
|
||||||
|
output_flow_kind: typ.Literal[ct.FlowKind.Array, ct.FlowKind.Func],
|
||||||
|
) -> list[typ.Any] | ct.FuncFlow | ct.FlowSignal:
|
||||||
|
"""Correctly compute the combined loose input sockets, given a valid combination of input and output `FlowKind`s.
|
||||||
|
|
||||||
|
If there is no output, or the flows aren't compatible, return `FlowPending`.
|
||||||
|
"""
|
||||||
|
match (input_flow_kind, output_flow_kind):
|
||||||
|
case (ct.FlowKind.Value, ct.FlowKind.Array):
|
||||||
|
value_flows = [
|
||||||
|
inp
|
||||||
|
for inp in loose_input_sockets.values()
|
||||||
|
if not ct.FlowSignal.check(inp)
|
||||||
|
]
|
||||||
|
if value_flows:
|
||||||
|
return value_flows
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
case (ct.FlowKind.Func, ct.FlowKind.Func):
|
||||||
|
func_flows = [
|
||||||
|
inp
|
||||||
|
for inp in loose_input_sockets.values()
|
||||||
|
if not ct.FlowSignal.check(inp)
|
||||||
|
]
|
||||||
|
if func_flows:
|
||||||
|
return functools.reduce(
|
||||||
|
lambda a, b: a | b,
|
||||||
|
func_flows,
|
||||||
|
)
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Output: Sources
|
||||||
####################
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Sources',
|
'Sources',
|
||||||
kind=ct.FlowKind.Array,
|
kind=ct.FlowKind.Array,
|
||||||
all_loose_input_sockets=True,
|
all_loose_input_sockets=True,
|
||||||
props={'amount'},
|
props={'value_or_func'},
|
||||||
)
|
)
|
||||||
def compute_sources(self, loose_input_sockets, props) -> sp.Expr:
|
def compute_sources_array(
|
||||||
return [loose_input_sockets[f'Source #{i}'] for i in range(props['amount'])]
|
self, props, loose_input_sockets
|
||||||
|
) -> list[typ.Any] | ct.FlowSignal:
|
||||||
|
"""Compute sources."""
|
||||||
|
return self.compute_combined(
|
||||||
|
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
|
||||||
|
)
|
||||||
|
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Sources',
|
||||||
|
kind=ct.FlowKind.Func,
|
||||||
|
all_loose_input_sockets=True,
|
||||||
|
props={'value_or_func'},
|
||||||
|
)
|
||||||
|
def compute_sources_func(self, props, loose_input_sockets) -> list[typ.Any]:
|
||||||
|
"""Compute (lazy) sources."""
|
||||||
|
return self.compute_combined(
|
||||||
|
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
|
||||||
|
)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Output: Structures
|
||||||
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Structures',
|
'Structures',
|
||||||
kind=ct.FlowKind.Array,
|
kind=ct.FlowKind.Array,
|
||||||
all_loose_input_sockets=True,
|
all_loose_input_sockets=True,
|
||||||
props={'amount'},
|
props={'value_or_func'},
|
||||||
)
|
)
|
||||||
def compute_structures(self, loose_input_sockets, props) -> sp.Expr:
|
def compute_structures_array(self, props, loose_input_sockets) -> sp.Expr:
|
||||||
return [loose_input_sockets[f'Structure #{i}'] for i in range(props['amount'])]
|
"""Compute structures."""
|
||||||
|
return self.compute_combined(
|
||||||
|
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
|
||||||
|
)
|
||||||
|
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Structures',
|
||||||
|
kind=ct.FlowKind.Func,
|
||||||
|
all_loose_input_sockets=True,
|
||||||
|
props={'value_or_func'},
|
||||||
|
)
|
||||||
|
def compute_structures_func(self, props, loose_input_sockets) -> list[typ.Any]:
|
||||||
|
"""Compute (lazy) structures."""
|
||||||
|
return self.compute_combined(
|
||||||
|
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
|
||||||
|
)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Output: Monitors
|
||||||
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Monitors',
|
'Monitors',
|
||||||
kind=ct.FlowKind.Array,
|
kind=ct.FlowKind.Array,
|
||||||
all_loose_input_sockets=True,
|
all_loose_input_sockets=True,
|
||||||
props={'amount'},
|
props={'value_or_func'},
|
||||||
)
|
)
|
||||||
def compute_monitors(self, loose_input_sockets, props) -> sp.Expr:
|
def compute_monitors_array(self, props, loose_input_sockets) -> sp.Expr:
|
||||||
return [loose_input_sockets[f'Monitor #{i}'] for i in range(props['amount'])]
|
"""Compute monitors."""
|
||||||
|
return self.compute_combined(
|
||||||
|
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
|
||||||
|
)
|
||||||
|
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Monitors',
|
||||||
|
kind=ct.FlowKind.Func,
|
||||||
|
all_loose_input_sockets=True,
|
||||||
|
props={'value_or_func'},
|
||||||
|
)
|
||||||
|
def compute_monitors_func(self, props, loose_input_sockets) -> list[typ.Any]:
|
||||||
|
"""Compute (lazy) monitors."""
|
||||||
|
return self.compute_combined(
|
||||||
|
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
|
@ -14,17 +14,26 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
"""Implements `FDTDSimNode`."""
|
||||||
|
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
import sympy as sp
|
import bpy
|
||||||
import tidy3d as td
|
import tidy3d as td
|
||||||
|
import tidy3d.plugins.adjoint as tdadj
|
||||||
|
|
||||||
|
from blender_maxwell.utils import bl_cache, logger
|
||||||
|
|
||||||
from ... import contracts as ct
|
from ... import contracts as ct
|
||||||
from ... import sockets
|
from ... import sockets
|
||||||
from .. import base, events
|
from .. import base, events
|
||||||
|
|
||||||
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FDTDSimNode(base.MaxwellSimNode):
|
class FDTDSimNode(base.MaxwellSimNode):
|
||||||
|
"""Definition of a complete FDTD simulation, including boundary conditions, domain, sources, structures, monitors, and other configuration."""
|
||||||
|
|
||||||
node_type = ct.NodeType.FDTDSim
|
node_type = ct.NodeType.FDTDSim
|
||||||
bl_label = 'FDTD Simulation'
|
bl_label = 'FDTD Simulation'
|
||||||
|
|
||||||
|
@ -35,51 +44,255 @@ class FDTDSimNode(base.MaxwellSimNode):
|
||||||
'BCs': sockets.MaxwellBoundCondsSocketDef(),
|
'BCs': sockets.MaxwellBoundCondsSocketDef(),
|
||||||
'Domain': sockets.MaxwellSimDomainSocketDef(),
|
'Domain': sockets.MaxwellSimDomainSocketDef(),
|
||||||
'Sources': sockets.MaxwellSourceSocketDef(
|
'Sources': sockets.MaxwellSourceSocketDef(
|
||||||
is_list=True,
|
active_kind=ct.FlowKind.Array,
|
||||||
),
|
),
|
||||||
'Structures': sockets.MaxwellStructureSocketDef(
|
'Structures': sockets.MaxwellStructureSocketDef(
|
||||||
is_list=True,
|
active_kind=ct.FlowKind.Array,
|
||||||
),
|
),
|
||||||
'Monitors': sockets.MaxwellMonitorSocketDef(
|
'Monitors': sockets.MaxwellMonitorSocketDef(
|
||||||
is_list=True,
|
active_kind=ct.FlowKind.Array,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
output_sockets: typ.ClassVar = {
|
output_socket_sets: typ.ClassVar = {
|
||||||
'Sim': sockets.MaxwellFDTDSimSocketDef(),
|
'Single': {
|
||||||
|
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Value),
|
||||||
|
},
|
||||||
|
'Batch': {
|
||||||
|
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Array),
|
||||||
|
},
|
||||||
|
'Lazy': {
|
||||||
|
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Func),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Output Socket Computation
|
# - Properties
|
||||||
|
####################
|
||||||
|
differentiable: bool = bl_cache.BLField(False)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - UI
|
||||||
|
####################
|
||||||
|
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
|
||||||
|
layout.prop(
|
||||||
|
self,
|
||||||
|
self.blfields['differentiable'],
|
||||||
|
text='Differentiable',
|
||||||
|
toggle=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Events
|
||||||
|
####################
|
||||||
|
@events.on_value_changed(
|
||||||
|
# Trigger
|
||||||
|
socket_name={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
|
||||||
|
run_on_init=True,
|
||||||
|
# Loaded
|
||||||
|
props={'active_socket_set'},
|
||||||
|
output_sockets={'Sim'},
|
||||||
|
output_socket_kinds={'Sim': ct.FlowKind.Params},
|
||||||
|
)
|
||||||
|
def on_any_changed(self, props, output_sockets) -> None:
|
||||||
|
"""Create loose input sockets."""
|
||||||
|
params = output_sockets['Sim']
|
||||||
|
has_params = not ct.FlowSignal.check(params)
|
||||||
|
|
||||||
|
# Declare Loose Sockets that Realize Symbols
|
||||||
|
## -> This happens if Params contains not-yet-realized symbols.
|
||||||
|
active_socket_set = props['active_socket_set']
|
||||||
|
if active_socket_set in ['Value', 'Batch'] and has_params and params.symbols:
|
||||||
|
if set(self.loose_input_sockets) != {sym.name for sym in params.symbols}:
|
||||||
|
self.loose_input_sockets = {
|
||||||
|
sym.name: sockets.ExprSocketDef(
|
||||||
|
**(
|
||||||
|
expr_info
|
||||||
|
| {
|
||||||
|
'active_kind': ct.FlowKind.Value,
|
||||||
|
'use_value_range_swapper': (
|
||||||
|
active_socket_set == 'Value'
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for sym, expr_info in params.sym_expr_infos.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
elif self.loose_input_sockets:
|
||||||
|
self.loose_input_sockets = {}
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - FlowKind.Value
|
||||||
####################
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Sim',
|
'Sim',
|
||||||
kind=ct.FlowKind.Value,
|
kind=ct.FlowKind.Value,
|
||||||
|
# Loaded
|
||||||
|
props={'differentiable'},
|
||||||
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
|
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
|
||||||
input_socket_kinds={
|
input_socket_kinds={
|
||||||
'Sources': ct.FlowKind.Array,
|
'Sources': ct.FlowKind.Array,
|
||||||
'Structures': ct.FlowKind.Array,
|
'Structures': ct.FlowKind.Array,
|
||||||
'Domain': ct.FlowKind.Value,
|
|
||||||
'BCs': ct.FlowKind.Value,
|
|
||||||
'Monitors': ct.FlowKind.Array,
|
'Monitors': ct.FlowKind.Array,
|
||||||
},
|
},
|
||||||
|
output_sockets={'Sim'},
|
||||||
|
output_socket_kinds={'Sim': ct.FlowKind.Params},
|
||||||
)
|
)
|
||||||
def compute_fdtd_sim(self, input_sockets: dict) -> sp.Expr:
|
def compute_fdtd_sim_value(
|
||||||
if any(ct.FlowSignal.check(inp) for inp in input_sockets):
|
self, props, input_sockets, output_sockets
|
||||||
return ct.FlowSignal.FlowPending
|
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
|
||||||
|
"""Compute a single FDTD simulation definition, so long as the inputs are neither symbolic or differentiable."""
|
||||||
sim_domain = input_sockets['Domain']
|
sim_domain = input_sockets['Domain']
|
||||||
sources = input_sockets['Sources']
|
sources = input_sockets['Sources']
|
||||||
structures = input_sockets['Structures']
|
structures = input_sockets['Structures']
|
||||||
bounds = input_sockets['BCs']
|
bounds = input_sockets['BCs']
|
||||||
monitors = input_sockets['Monitors']
|
monitors = input_sockets['Monitors']
|
||||||
return td.Simulation(
|
output_params = output_sockets['Sim']
|
||||||
**sim_domain,
|
|
||||||
structures=structures,
|
has_sim_domain = not ct.FlowSignal.check(sim_domain)
|
||||||
sources=sources,
|
has_sources = not ct.FlowSignal.check(sources)
|
||||||
monitors=monitors,
|
has_structures = not ct.FlowSignal.check(structures)
|
||||||
boundary_spec=bounds,
|
has_bounds = not ct.FlowSignal.check(bounds)
|
||||||
)
|
has_monitors = not ct.FlowSignal.check(monitors)
|
||||||
## TODO: Visualize the boundary conditions on top of the sim domain
|
has_output_params = not ct.FlowSignal.check(output_params)
|
||||||
|
|
||||||
|
differentiable = props['differentiable']
|
||||||
|
if (
|
||||||
|
has_sim_domain
|
||||||
|
and has_sources
|
||||||
|
and has_structures
|
||||||
|
and has_bounds
|
||||||
|
and has_monitors
|
||||||
|
and has_output_params
|
||||||
|
and not differentiable
|
||||||
|
):
|
||||||
|
return td.Simulation(
|
||||||
|
**sim_domain,
|
||||||
|
sources=sources,
|
||||||
|
structures=structures,
|
||||||
|
boundary_spec=bounds,
|
||||||
|
monitors=monitors,
|
||||||
|
)
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - FlowKind.Func
|
||||||
|
####################
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Sim',
|
||||||
|
kind=ct.FlowKind.Func,
|
||||||
|
# Loaded
|
||||||
|
props={'differentiable'},
|
||||||
|
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
|
||||||
|
input_socket_kinds={
|
||||||
|
'Sources': ct.FlowKind.Func,
|
||||||
|
'Structures': ct.FlowKind.Func,
|
||||||
|
'Monitors': ct.FlowKind.Func,
|
||||||
|
},
|
||||||
|
output_sockets={'Sim'},
|
||||||
|
output_socket_kinds={'Sim': ct.FlowKind.Params},
|
||||||
|
)
|
||||||
|
def compute_fdtd_sim_func(
|
||||||
|
self, props, input_sockets, output_sockets
|
||||||
|
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
|
||||||
|
"""Compute a single simulation, given that all inputs are non-symbolic."""
|
||||||
|
sim_domain = input_sockets['Domain']
|
||||||
|
sources = input_sockets['Sources']
|
||||||
|
structures = input_sockets['Structures']
|
||||||
|
bounds = input_sockets['BCs']
|
||||||
|
monitors = input_sockets['Monitors']
|
||||||
|
output_params = output_sockets['Sim']
|
||||||
|
|
||||||
|
has_sim_domain = not ct.FlowSignal.check(sim_domain)
|
||||||
|
has_sources = not ct.FlowSignal.check(sources)
|
||||||
|
has_structures = not ct.FlowSignal.check(structures)
|
||||||
|
has_bounds = not ct.FlowSignal.check(bounds)
|
||||||
|
has_monitors = not ct.FlowSignal.check(monitors)
|
||||||
|
has_output_params = not ct.FlowSignal.check(output_params)
|
||||||
|
|
||||||
|
if (
|
||||||
|
has_sim_domain
|
||||||
|
and has_sources
|
||||||
|
and has_structures
|
||||||
|
and has_bounds
|
||||||
|
and has_monitors
|
||||||
|
and has_output_params
|
||||||
|
):
|
||||||
|
differentiable = props['differentiable']
|
||||||
|
if differentiable:
|
||||||
|
return (
|
||||||
|
sim_domain | sources | structures | bounds | monitors
|
||||||
|
).compose_within(
|
||||||
|
enclosing_func=lambda els: tdadj.JaxSimulation(
|
||||||
|
**els[0],
|
||||||
|
sources=els[1],
|
||||||
|
structures=els[2]['static'],
|
||||||
|
input_structures=els[2]['differentiable'],
|
||||||
|
boundary_spec=els[3],
|
||||||
|
monitors=els[4]['static'],
|
||||||
|
output_monitors=els[4]['differentiable'],
|
||||||
|
),
|
||||||
|
supports_jax=True,
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
sim_domain | sources | structures | bounds | monitors
|
||||||
|
).compose_within(
|
||||||
|
enclosing_func=lambda els: td.Simulation(
|
||||||
|
**els[0],
|
||||||
|
sources=els[1],
|
||||||
|
structures=els[2],
|
||||||
|
boundary_spec=els[3],
|
||||||
|
monitors=els[4],
|
||||||
|
),
|
||||||
|
supports_jax=False,
|
||||||
|
)
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - FlowKind.Params
|
||||||
|
####################
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Sim',
|
||||||
|
kind=ct.FlowKind.Params,
|
||||||
|
# Loaded
|
||||||
|
props={'differentiable'},
|
||||||
|
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
|
||||||
|
input_socket_kinds={
|
||||||
|
'Sources': ct.FlowKind.Params,
|
||||||
|
'Structures': ct.FlowKind.Params,
|
||||||
|
'Monitors': ct.FlowKind.Params,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def compute_fdtd_sim_params(
|
||||||
|
self, props, input_sockets
|
||||||
|
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
|
||||||
|
"""Compute a single simulation, given that all inputs are non-symbolic."""
|
||||||
|
sim_domain = input_sockets['Domain']
|
||||||
|
sources = input_sockets['Sources']
|
||||||
|
structures = input_sockets['Structures']
|
||||||
|
bounds = input_sockets['BCs']
|
||||||
|
monitors = input_sockets['Monitors']
|
||||||
|
|
||||||
|
has_sim_domain = not ct.FlowSignal.check(sim_domain)
|
||||||
|
has_sources = not ct.FlowSignal.check(sources)
|
||||||
|
has_structures = not ct.FlowSignal.check(structures)
|
||||||
|
has_bounds = not ct.FlowSignal.check(bounds)
|
||||||
|
has_monitors = not ct.FlowSignal.check(monitors)
|
||||||
|
|
||||||
|
if (
|
||||||
|
has_sim_domain
|
||||||
|
and has_sources
|
||||||
|
and has_structures
|
||||||
|
and has_bounds
|
||||||
|
and has_monitors
|
||||||
|
):
|
||||||
|
# Determine Differentiable Match
|
||||||
|
## -> 'structures' is diff when **any** are diff.
|
||||||
|
## -> 'monitors' is also diff when **any** are diff.
|
||||||
|
## -> Only parameters through diff structs can be diff'ed by.
|
||||||
|
## -> Similarly, only diff monitors will have gradients computed.
|
||||||
|
return sim_domain | sources | structures | bounds | monitors
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
"""Implements `SimDomainNode`."""
|
||||||
|
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
|
@ -31,6 +33,8 @@ log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SimDomainNode(base.MaxwellSimNode):
|
class SimDomainNode(base.MaxwellSimNode):
|
||||||
|
"""The domain of a simulation in space and time, including bounds, discretization strategy, and the ambient medium."""
|
||||||
|
|
||||||
node_type = ct.NodeType.SimDomain
|
node_type = ct.NodeType.SimDomain
|
||||||
bl_label = 'Sim Domain'
|
bl_label = 'Sim Domain'
|
||||||
use_sim_node_name = True
|
use_sim_node_name = True
|
||||||
|
@ -69,26 +73,109 @@ class SimDomainNode(base.MaxwellSimNode):
|
||||||
}
|
}
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Outputs
|
# - FlowKind.Value
|
||||||
####################
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Domain',
|
'Domain',
|
||||||
|
kind=ct.FlowKind.Value,
|
||||||
|
# Loaded
|
||||||
|
output_sockets={'Domain'},
|
||||||
|
output_socket_kinds={'Domain': {ct.FlowKind.Func, ct.FlowKind.Params}},
|
||||||
|
)
|
||||||
|
def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||||
|
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
|
||||||
|
output_func = output_sockets['Domain'][ct.FlowKind.Func]
|
||||||
|
output_params = output_sockets['Domain'][ct.FlowKind.Params]
|
||||||
|
|
||||||
|
has_output_func = not ct.FlowSignal.check(output_func)
|
||||||
|
has_output_params = not ct.FlowSignal.check(output_params)
|
||||||
|
|
||||||
|
if has_output_func and has_output_params and not output_params.symbols:
|
||||||
|
return output_func.realize(output_params)
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - FlowKind.Func
|
||||||
|
####################
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Domain',
|
||||||
|
kind=ct.FlowKind.Func,
|
||||||
|
# Loaded
|
||||||
input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'},
|
input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'},
|
||||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
input_socket_kinds={
|
||||||
scale_input_sockets={
|
'Duration': ct.FlowKind.Func,
|
||||||
'Duration': 'Tidy3DUnits',
|
'Center': ct.FlowKind.Func,
|
||||||
'Center': 'Tidy3DUnits',
|
'Size': ct.FlowKind.Func,
|
||||||
'Size': 'Tidy3DUnits',
|
'Grid': ct.FlowKind.Func,
|
||||||
|
'Ambient Medium': ct.FlowKind.Func,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
def compute_domain(self, input_sockets, unit_systems) -> sp.Expr:
|
def compute_domain_func(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||||
return {
|
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
|
||||||
'run_time': input_sockets['Duration'],
|
duration = input_sockets['Duration']
|
||||||
'center': input_sockets['Center'],
|
center = input_sockets['Center']
|
||||||
'size': input_sockets['Size'],
|
size = input_sockets['Size']
|
||||||
'grid_spec': input_sockets['Grid'],
|
grid = input_sockets['Grid']
|
||||||
'medium': input_sockets['Ambient Medium'],
|
medium = input_sockets['Ambient Medium']
|
||||||
}
|
|
||||||
|
has_duration = not ct.FlowSignal.check(duration)
|
||||||
|
has_center = not ct.FlowSignal.check(center)
|
||||||
|
has_size = not ct.FlowSignal.check(size)
|
||||||
|
has_grid = not ct.FlowSignal.check(grid)
|
||||||
|
has_medium = not ct.FlowSignal.check(medium)
|
||||||
|
|
||||||
|
if has_duration and has_center and has_size and has_grid and has_medium:
|
||||||
|
return (
|
||||||
|
duration.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||||
|
| center.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||||
|
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||||
|
| grid
|
||||||
|
| medium
|
||||||
|
).compose_within(
|
||||||
|
enclosing_func=lambda els: {
|
||||||
|
'run_time': els[0],
|
||||||
|
'center': tuple(els[1].flatten()),
|
||||||
|
'size': tuple(els[2].flatten()),
|
||||||
|
'grid_spec': els[3],
|
||||||
|
'medium': els[4],
|
||||||
|
},
|
||||||
|
supports_jax=False,
|
||||||
|
)
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - FlowKind.Params
|
||||||
|
####################
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Domain',
|
||||||
|
kind=ct.FlowKind.Params,
|
||||||
|
# Loaded
|
||||||
|
input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'},
|
||||||
|
input_socket_kinds={
|
||||||
|
'Duration': ct.FlowKind.Params,
|
||||||
|
'Center': ct.FlowKind.Params,
|
||||||
|
'Size': ct.FlowKind.Params,
|
||||||
|
'Grid': ct.FlowKind.Params,
|
||||||
|
'Ambient Medium': ct.FlowKind.Params,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def compute_domain_params(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||||
|
"""Compute the output `ParamsFlow` of the simulation domain from strictly non-symbolic inputs."""
|
||||||
|
duration = input_sockets['Duration']
|
||||||
|
center = input_sockets['Center']
|
||||||
|
size = input_sockets['Size']
|
||||||
|
grid = input_sockets['Grid']
|
||||||
|
medium = input_sockets['Ambient Medium']
|
||||||
|
|
||||||
|
has_duration = not ct.FlowSignal.check(duration)
|
||||||
|
has_center = not ct.FlowSignal.check(center)
|
||||||
|
has_size = not ct.FlowSignal.check(size)
|
||||||
|
has_grid = not ct.FlowSignal.check(grid)
|
||||||
|
has_medium = not ct.FlowSignal.check(medium)
|
||||||
|
|
||||||
|
if has_duration and has_center and has_size and has_grid and has_medium:
|
||||||
|
return duration | center | size | grid | medium
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Preview
|
# - Preview
|
||||||
|
@ -100,38 +187,40 @@ class SimDomainNode(base.MaxwellSimNode):
|
||||||
props={'sim_node_name'},
|
props={'sim_node_name'},
|
||||||
)
|
)
|
||||||
def compute_previews(self, props):
|
def compute_previews(self, props):
|
||||||
|
"""Mark the managed preview object for preview when `Domain` is linked to a viewer."""
|
||||||
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
|
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
|
||||||
|
|
||||||
@events.on_value_changed(
|
@events.on_value_changed(
|
||||||
## Trigger
|
# Trigger
|
||||||
socket_name={'Center', 'Size'},
|
socket_name={'Center', 'Size'},
|
||||||
run_on_init=True,
|
run_on_init=True,
|
||||||
# Loaded
|
# Loaded
|
||||||
input_sockets={'Center', 'Size'},
|
input_sockets={'Center', 'Size'},
|
||||||
managed_objs={'modifier'},
|
managed_objs={'modifier'},
|
||||||
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
|
output_sockets={'Domain'},
|
||||||
scale_input_sockets={
|
output_socket_kinds={'Domain': ct.FlowKind.Params},
|
||||||
'Center': 'BlenderUnits',
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
def on_input_changed(
|
def on_input_changed(self, managed_objs, input_sockets, output_sockets) -> None:
|
||||||
self,
|
"""Preview the simulation domain based on input parameters, so long as they are not dependent on unrealized symbols."""
|
||||||
managed_objs,
|
output_params = output_sockets['Domain']
|
||||||
input_sockets,
|
center = input_sockets['Center']
|
||||||
unit_systems,
|
|
||||||
):
|
has_output_params = not ct.FlowSignal.check(output_params)
|
||||||
# Push Loose Input Values to GeoNodes Modifier
|
has_center = not ct.FlowSignal.check(center)
|
||||||
managed_objs['modifier'].bl_modifier(
|
|
||||||
'NODES',
|
if has_center and has_output_params and not output_params.symbols:
|
||||||
{
|
# Push Loose Input Values to GeoNodes Modifier
|
||||||
'node_group': import_geonodes(GeoNodes.SimulationSimDomain),
|
managed_objs['modifier'].bl_modifier(
|
||||||
'unit_system': unit_systems['BlenderUnits'],
|
'NODES',
|
||||||
'inputs': {
|
{
|
||||||
'Size': input_sockets['Size'],
|
'node_group': import_geonodes(GeoNodes.SimulationSimDomain),
|
||||||
|
'unit_system': ct.UNITS_BLENDER,
|
||||||
|
'inputs': {
|
||||||
|
'Size': input_sockets['Size'],
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
|
||||||
location=input_sockets['Center'],
|
)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
|
@ -71,35 +71,129 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
|
||||||
layout.prop(self, self.blfields['pol_axis'], expand=True)
|
layout.prop(self, self.blfields['pol_axis'], expand=True)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Outputs
|
# - FlowKind.Value
|
||||||
####################
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Source',
|
'Source',
|
||||||
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
|
# Loaded
|
||||||
props={'pol_axis'},
|
props={'pol_axis'},
|
||||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
|
||||||
scale_input_sockets={
|
output_sockets={'Source'},
|
||||||
'Center': 'Tidy3DUnits',
|
output_socket_kinds={'Source': ct.FlowKind.Params},
|
||||||
|
)
|
||||||
|
def compute_source_value(
|
||||||
|
self, input_sockets, props, output_sockets
|
||||||
|
) -> td.PointDipole | ct.FlowSignal:
|
||||||
|
"""Compute the point dipole source, given that all inputs are non-symbolic."""
|
||||||
|
temporal_shape = input_sockets['Temporal Shape']
|
||||||
|
center = input_sockets['Center']
|
||||||
|
interpolate = input_sockets['Interpolate']
|
||||||
|
output_params = output_sockets['Source']
|
||||||
|
|
||||||
|
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
|
||||||
|
has_center = not ct.FlowSignal.check(center)
|
||||||
|
has_interpolate = not ct.FlowSignal.check(interpolate)
|
||||||
|
has_output_params = not ct.FlowSignal.check(output_params)
|
||||||
|
|
||||||
|
if (
|
||||||
|
has_temporal_shape
|
||||||
|
and has_center
|
||||||
|
and has_interpolate
|
||||||
|
and has_output_params
|
||||||
|
and not output_params.symbols
|
||||||
|
):
|
||||||
|
pol_axis = {
|
||||||
|
ct.SimSpaceAxis.X: 'Ex',
|
||||||
|
ct.SimSpaceAxis.Y: 'Ey',
|
||||||
|
ct.SimSpaceAxis.Z: 'Ez',
|
||||||
|
}[props['pol_axis']]
|
||||||
|
## TODO: Need Hx, Hy, Hz too?
|
||||||
|
|
||||||
|
return td.PointDipole(
|
||||||
|
center=spux.convert_to_unit_system(center, ct.UNITS_TIDY3D),
|
||||||
|
source_time=temporal_shape,
|
||||||
|
interpolate=interpolate,
|
||||||
|
polarization=pol_axis,
|
||||||
|
)
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - FlowKind.Func
|
||||||
|
####################
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Source',
|
||||||
|
kind=ct.FlowKind.Func,
|
||||||
|
# Loaded
|
||||||
|
props={'pol_axis'},
|
||||||
|
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
|
||||||
|
input_socket_kinds={
|
||||||
|
'Temporal Shape': ct.FlowKind.Func,
|
||||||
|
'Center': ct.FlowKind.Func,
|
||||||
|
'Interpolate': ct.FlowKind.Func,
|
||||||
|
},
|
||||||
|
output_sockets={'Source'},
|
||||||
|
output_socket_kinds={'Source': ct.FlowKind.Params},
|
||||||
|
)
|
||||||
|
def compute_source_func(self, props, input_sockets, output_sockets) -> td.Box:
|
||||||
|
"""Compute a lazy function for the point dipole source."""
|
||||||
|
center = input_sockets['Center']
|
||||||
|
temporal_shape = input_sockets['Temporal Shape']
|
||||||
|
interpolate = input_sockets['Interpolate']
|
||||||
|
output_params = output_sockets['Source']
|
||||||
|
|
||||||
|
has_center = not ct.FlowSignal.check(center)
|
||||||
|
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
|
||||||
|
has_interpolate = not ct.FlowSignal.check(interpolate)
|
||||||
|
has_output_params = not ct.FlowSignal.check(output_params)
|
||||||
|
|
||||||
|
if has_temporal_shape and has_center and has_interpolate and has_output_params:
|
||||||
|
pol_axis = {
|
||||||
|
ct.SimSpaceAxis.X: 'Ex',
|
||||||
|
ct.SimSpaceAxis.Y: 'Ey',
|
||||||
|
ct.SimSpaceAxis.Z: 'Ez',
|
||||||
|
}[props['pol_axis']]
|
||||||
|
## TODO: Need Hx, Hy, Hz too?
|
||||||
|
|
||||||
|
return (center | temporal_shape | interpolate).compose_within(
|
||||||
|
enclosing_func=lambda els: td.PointDipole(
|
||||||
|
center=els[0],
|
||||||
|
source_time=els[1],
|
||||||
|
interpolate=els[2],
|
||||||
|
polarization=pol_axis,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - FlowKind.Params
|
||||||
|
####################
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Source',
|
||||||
|
kind=ct.FlowKind.Params,
|
||||||
|
# Loaded
|
||||||
|
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
|
||||||
|
input_socket_kinds={
|
||||||
|
'Temporal Shape': ct.FlowKind.Params,
|
||||||
|
'Center': ct.FlowKind.Params,
|
||||||
|
'Interpolate': ct.FlowKind.Params,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
def compute_source(
|
def compute_params(
|
||||||
self,
|
self,
|
||||||
input_sockets: dict[str, typ.Any],
|
input_sockets,
|
||||||
props: dict[str, typ.Any],
|
) -> td.PointDipole | ct.FlowSignal:
|
||||||
unit_systems: dict,
|
"""Compute the point dipole source, given that all inputs are non-symbolic."""
|
||||||
) -> td.PointDipole:
|
temporal_shape = input_sockets['Temporal Shape']
|
||||||
pol_axis = {
|
center = input_sockets['Center']
|
||||||
ct.SimSpaceAxis.X: 'Ex',
|
interpolate = input_sockets['Interpolate']
|
||||||
ct.SimSpaceAxis.Y: 'Ey',
|
|
||||||
ct.SimSpaceAxis.Z: 'Ez',
|
|
||||||
}[props['pol_axis']]
|
|
||||||
|
|
||||||
return td.PointDipole(
|
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
|
||||||
center=input_sockets['Center'],
|
has_center = not ct.FlowSignal.check(center)
|
||||||
source_time=input_sockets['Temporal Shape'],
|
has_interpolate = not ct.FlowSignal.check(interpolate)
|
||||||
interpolate=input_sockets['Interpolate'],
|
|
||||||
polarization=pol_axis,
|
if has_temporal_shape and has_center and has_interpolate:
|
||||||
)
|
return temporal_shape | center | interpolate
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Preview
|
# - Preview
|
||||||
|
|
|
@ -16,15 +16,19 @@
|
||||||
|
|
||||||
"""Implements the `TemporalShapeNode`."""
|
"""Implements the `TemporalShapeNode`."""
|
||||||
|
|
||||||
|
import enum
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
import bpy
|
import bpy
|
||||||
|
import numpy as np
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
import sympy.physics.units as spu
|
import sympy.physics.units as spu
|
||||||
import tidy3d as td
|
import tidy3d as td
|
||||||
|
from tidy3d.components.data.data_array import TimeDataArray as td_TimeDataArray
|
||||||
|
from tidy3d.components.data.dataset import TimeDataset as td_TimeDataset
|
||||||
|
|
||||||
|
from blender_maxwell.utils import bl_cache, logger, sim_symbols
|
||||||
from blender_maxwell.utils import extra_sympy_units as spux
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
from blender_maxwell.utils import logger, sim_symbols
|
|
||||||
|
|
||||||
from ... import contracts as ct
|
from ... import contracts as ct
|
||||||
from ... import managed_objs, sockets
|
from ... import managed_objs, sockets
|
||||||
|
@ -33,14 +37,10 @@ from .. import base, events
|
||||||
log = logger.get(__name__)
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
_max_e_socket_def = sockets.ExprSocketDef(
|
# Select Default Time Unit for Envelope
|
||||||
mathtype=spux.MathType.Complex,
|
## -> Chosen to align with the default envelope_time_unit.
|
||||||
physical_type=spux.PhysicalType.EField,
|
## -> This causes it to be correct from the start.
|
||||||
default_value=1 + 0j,
|
t_def = sim_symbols.t(spux.PhysicalType.Time.valid_units[0])
|
||||||
)
|
|
||||||
_offset_socket_def = sockets.ExprSocketDef(default_value=5, abs_min=2.5)
|
|
||||||
|
|
||||||
t_ps = sim_symbols.t(spu.picosecond)
|
|
||||||
|
|
||||||
|
|
||||||
class TemporalShapeNode(base.MaxwellSimNode):
|
class TemporalShapeNode(base.MaxwellSimNode):
|
||||||
|
@ -63,17 +63,18 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
||||||
default_unit=spux.THz,
|
default_unit=spux.THz,
|
||||||
default_value=200,
|
default_value=200,
|
||||||
),
|
),
|
||||||
|
'max E': sockets.ExprSocketDef(
|
||||||
|
mathtype=spux.MathType.Complex,
|
||||||
|
physical_type=spux.PhysicalType.EField,
|
||||||
|
default_value=1 + 0j,
|
||||||
|
),
|
||||||
|
'Offset Time': sockets.ExprSocketDef(default_value=5, abs_min=2.5),
|
||||||
}
|
}
|
||||||
input_socket_sets: typ.ClassVar = {
|
input_socket_sets: typ.ClassVar = {
|
||||||
'Pulse': {
|
'Pulse': {
|
||||||
'max E': _max_e_socket_def,
|
|
||||||
'Offset Time': _offset_socket_def,
|
|
||||||
'Remove DC': sockets.BoolSocketDef(default_value=True),
|
'Remove DC': sockets.BoolSocketDef(default_value=True),
|
||||||
},
|
},
|
||||||
'Constant': {
|
'Constant': {},
|
||||||
'max E': _max_e_socket_def,
|
|
||||||
'Offset Time': _offset_socket_def,
|
|
||||||
},
|
|
||||||
'Symbolic': {
|
'Symbolic': {
|
||||||
't Range': sockets.ExprSocketDef(
|
't Range': sockets.ExprSocketDef(
|
||||||
active_kind=ct.FlowKind.Range,
|
active_kind=ct.FlowKind.Range,
|
||||||
|
@ -84,8 +85,8 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
||||||
default_steps=100,
|
default_steps=100,
|
||||||
),
|
),
|
||||||
'Envelope': sockets.ExprSocketDef(
|
'Envelope': sockets.ExprSocketDef(
|
||||||
default_symbols=[t_ps],
|
default_symbols=[t_def],
|
||||||
default_value=10 * t_ps.sp_symbol,
|
default_value=10 * t_def.sp_symbol,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -98,6 +99,55 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
||||||
'plot': managed_objs.ManagedBLImage,
|
'plot': managed_objs.ManagedBLImage,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Properties
|
||||||
|
####################
|
||||||
|
active_envelope_time_unit: enum.StrEnum = bl_cache.BLField(
|
||||||
|
enum_cb=lambda self, _: self.search_time_units(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def search_time_units(self) -> list[ct.BLEnumElement]:
|
||||||
|
"""Compute all valid time units."""
|
||||||
|
return [
|
||||||
|
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
|
||||||
|
for i, unit in enumerate(spux.PhysicalType.Time.valid_units)
|
||||||
|
]
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property(depends_on={'active_envelope_time_unit'})
|
||||||
|
def envelope_time_unit(self) -> spux.Unit | None:
|
||||||
|
"""Gets the current active unit for the envelope time symbol.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The current active `sympy` unit.
|
||||||
|
|
||||||
|
If the socket expression is unitless, this returns `None`.
|
||||||
|
"""
|
||||||
|
if self.active_envelope_time_unit is not None:
|
||||||
|
return spux.unit_str_to_unit(self.active_envelope_time_unit)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - UI
|
||||||
|
####################
|
||||||
|
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
|
||||||
|
if (
|
||||||
|
self.active_socket_set == 'Symbolic'
|
||||||
|
and self.inputs.get('Envelope')
|
||||||
|
and not self.inputs['Envelope'].is_linked
|
||||||
|
):
|
||||||
|
row = layout.row()
|
||||||
|
row.alignment = 'CENTER'
|
||||||
|
row.label(text='Envelope Time Unit')
|
||||||
|
|
||||||
|
row = layout.row()
|
||||||
|
row.prop(
|
||||||
|
self,
|
||||||
|
self.blfields['active_envelope_time_unit'],
|
||||||
|
text='',
|
||||||
|
toggle=True,
|
||||||
|
)
|
||||||
|
|
||||||
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||||
if self.active_socket_set != 'Symbolic':
|
if self.active_socket_set != 'Symbolic':
|
||||||
box = layout.box()
|
box = layout.box()
|
||||||
|
@ -118,10 +168,53 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
||||||
col.label(text='1 / 2π·σ(𝑓)')
|
col.label(text='1 / 2π·σ(𝑓)')
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - FlowKind: Value
|
# - Events
|
||||||
|
####################
|
||||||
|
@events.on_value_changed(
|
||||||
|
# Trigger
|
||||||
|
prop_name={'active_socket_set', 'envelope_time_unit'},
|
||||||
|
# Loaded
|
||||||
|
props={'active_socket_set', 'envelope_time_unit'},
|
||||||
|
)
|
||||||
|
def on_envelope_time_unit_changed(self, props) -> None:
|
||||||
|
"""Ensure the envelope expression's time symbol has the time unit defined by the node."""
|
||||||
|
active_socket_set = props['active_socket_set']
|
||||||
|
envelope_time_unit = props['envelope_time_unit']
|
||||||
|
if active_socket_set == 'Symbolic':
|
||||||
|
bl_socket = self.inputs['Envelope']
|
||||||
|
wanted_t_sym = sim_symbols.t(envelope_time_unit)
|
||||||
|
|
||||||
|
if not bl_socket.symbols or bl_socket.symbols[0] != wanted_t_sym:
|
||||||
|
bl_socket.symbols = [wanted_t_sym]
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - FlowKind.Value
|
||||||
####################
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Temporal Shape',
|
'Temporal Shape',
|
||||||
|
kind=ct.FlowKind.Value,
|
||||||
|
# Loaded
|
||||||
|
output_sockets={'Temporal Shape'},
|
||||||
|
output_socket_kinds={'Temporal Shape': {ct.FlowKind.Func, ct.FlowKind.Params}},
|
||||||
|
)
|
||||||
|
def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||||
|
"""Compute a single temporal shape."""
|
||||||
|
output_func = output_sockets['Temporal Shape'][ct.FlowKind.Func]
|
||||||
|
output_params = output_sockets['Temporal Shape'][ct.FlowKind.Params]
|
||||||
|
|
||||||
|
has_output_func = not ct.FlowSignal.check(output_func)
|
||||||
|
has_output_params = not ct.FlowSignal.check(output_params)
|
||||||
|
|
||||||
|
if has_output_func and has_output_params and not output_params.symbols:
|
||||||
|
return output_func.realize(output_params)
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - FlowKind: Func
|
||||||
|
####################
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Temporal Shape',
|
||||||
|
kind=ct.FlowKind.Func,
|
||||||
# Loaded
|
# Loaded
|
||||||
props={'active_socket_set'},
|
props={'active_socket_set'},
|
||||||
input_sockets={
|
input_sockets={
|
||||||
|
@ -134,60 +227,178 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
||||||
'Envelope',
|
'Envelope',
|
||||||
},
|
},
|
||||||
input_socket_kinds={
|
input_socket_kinds={
|
||||||
't Range': ct.FlowKind.Range,
|
'max E': ct.FlowKind.Func,
|
||||||
'Envelope': ct.FlowKind.Func,
|
'μ Freq': ct.FlowKind.Func,
|
||||||
},
|
'σ Freq': ct.FlowKind.Func,
|
||||||
input_sockets_optional={
|
'Offset Time': ct.FlowKind.Func,
|
||||||
'max E': True,
|
'Remove DC': ct.FlowKind.Value,
|
||||||
'Offset Time': True,
|
't Range': ct.FlowKind.Func,
|
||||||
'Remove DC': True,
|
'Envelope': {ct.FlowKind.Func, ct.FlowKind.Params},
|
||||||
't Range': True,
|
|
||||||
'Envelope': True,
|
|
||||||
},
|
|
||||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
|
||||||
scale_input_sockets={
|
|
||||||
'max E': 'Tidy3DUnits',
|
|
||||||
'μ Freq': 'Tidy3DUnits',
|
|
||||||
'σ Freq': 'Tidy3DUnits',
|
|
||||||
't Range': 'Tidy3DUnits',
|
|
||||||
'Offset Time': 'Tidy3DUnits',
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
def compute_temporal_shape(
|
def compute_temporal_shape_func(
|
||||||
self, props, input_sockets, unit_systems
|
self,
|
||||||
|
props,
|
||||||
|
input_sockets,
|
||||||
) -> td.GaussianPulse:
|
) -> td.GaussianPulse:
|
||||||
match props['active_socket_set']:
|
"""Compute a single temporal shape from non-parameterized inputs."""
|
||||||
case 'Pulse':
|
mean_freq = input_sockets['μ Freq']
|
||||||
return td.GaussianPulse(
|
std_freq = input_sockets['σ Freq']
|
||||||
amplitude=sp.re(input_sockets['max E']),
|
max_e = input_sockets['max E']
|
||||||
phase=sp.im(input_sockets['max E']),
|
offset = input_sockets['Offset Time']
|
||||||
freq0=input_sockets['μ Freq'],
|
|
||||||
fwidth=input_sockets['σ Freq'],
|
|
||||||
offset=input_sockets['Offset Time'],
|
|
||||||
remove_dc_component=input_sockets['Remove DC'],
|
|
||||||
)
|
|
||||||
|
|
||||||
case 'Constant':
|
has_mean_freq = not ct.FlowSignal.check(mean_freq)
|
||||||
return td.ContinuousWave(
|
has_std_freq = not ct.FlowSignal.check(std_freq)
|
||||||
amplitude=sp.re(input_sockets['max E']),
|
has_max_e = not ct.FlowSignal.check(max_e)
|
||||||
phase=sp.im(input_sockets['max E']),
|
has_offset = not ct.FlowSignal.check(offset)
|
||||||
freq0=input_sockets['μ Freq'],
|
|
||||||
fwidth=input_sockets['σ Freq'],
|
|
||||||
offset=input_sockets['Offset Time'],
|
|
||||||
)
|
|
||||||
|
|
||||||
case 'Symbolic':
|
if has_mean_freq and has_std_freq and has_max_e and has_offset:
|
||||||
lzrange = input_sockets['t Range']
|
common_func = (
|
||||||
envelope_ps = input_sockets['Envelope'].func_jax
|
max_e.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||||
|
| mean_freq.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||||
|
| std_freq.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||||
|
| offset ## Already unitless
|
||||||
|
)
|
||||||
|
match props['active_socket_set']:
|
||||||
|
case 'Pulse':
|
||||||
|
remove_dc = input_sockets['Remove DC']
|
||||||
|
|
||||||
return td.CustomSourceTime.from_values(
|
has_remove_dc = not ct.FlowSignal.check(remove_dc)
|
||||||
freq0=input_sockets['μ Freq'],
|
|
||||||
fwidth=input_sockets['σ Freq'],
|
if has_remove_dc:
|
||||||
values=envelope_ps(
|
return common_func.compose_within(
|
||||||
lzrange.rescale_to_unit(spu.ps).realize_array.values
|
lambda els: td.GaussianPulse(
|
||||||
),
|
amplitude=complex(els[0]).real,
|
||||||
dt=input_sockets['t Range'].realize_step_size(),
|
phase=complex(els[0]).imag,
|
||||||
)
|
freq0=els[1],
|
||||||
|
fwidth=els[2],
|
||||||
|
offset=els[3],
|
||||||
|
remove_dc_component=remove_dc,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
case 'Constant':
|
||||||
|
return common_func.compose_within(
|
||||||
|
lambda els: td.GaussianPulse(
|
||||||
|
amplitude=complex(els[0]).real,
|
||||||
|
phase=complex(els[0]).imag,
|
||||||
|
freq0=els[1],
|
||||||
|
fwidth=els[2],
|
||||||
|
offset=els[3],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
case 'Symbolic':
|
||||||
|
t_range = input_sockets['t Range']
|
||||||
|
envelope = input_sockets['Envelope'][ct.FlowKind.Func]
|
||||||
|
envelope_params = input_sockets['Envelope'][ct.FlowKind.Params]
|
||||||
|
|
||||||
|
has_t_range = not ct.FlowSignal.check(t_range)
|
||||||
|
has_envelope = not ct.FlowSignal.check(envelope)
|
||||||
|
has_envelope_params = not ct.FlowSignal.check(envelope_params)
|
||||||
|
|
||||||
|
if (
|
||||||
|
has_t_range
|
||||||
|
and has_envelope
|
||||||
|
and has_envelope_params
|
||||||
|
and len(envelope_params.symbols) == 1
|
||||||
|
## TODO: Allow unrealized envelope symbols
|
||||||
|
and any(
|
||||||
|
sym.physical_type is spux.PhysicalType.Time
|
||||||
|
for sym in envelope_params.symbols
|
||||||
|
)
|
||||||
|
):
|
||||||
|
envelope_time_unit = next(
|
||||||
|
sym.unit
|
||||||
|
for sym in envelope_params.symbols
|
||||||
|
if sym.physical_type is spux.PhysicalType.Time
|
||||||
|
)
|
||||||
|
|
||||||
|
# Deduce Partially Realized Envelope Function
|
||||||
|
## -> We need a pure-numerical function w/pre-realized stuff baked in.
|
||||||
|
## -> 'realize_partial' does this for us.
|
||||||
|
envelope_realizer = envelope.realize_partial(envelope_params)
|
||||||
|
|
||||||
|
# Compose w/Envelope Function
|
||||||
|
## -> First, the numerical time values must be converted.
|
||||||
|
## -> This ensures that the raw array is compatible w/the envelope.
|
||||||
|
## -> Then, we can compose w/the purely numerical 'envelope_realizer'.
|
||||||
|
## -> Because of the checks, we've guaranteed that all this is correct.
|
||||||
|
return (
|
||||||
|
common_func ## 1 | freq0, 2 | fwidth, 3 | offset
|
||||||
|
| t_range.scale_to_unit_system(ct.UNITS_TIDY3D) ## 4
|
||||||
|
| t_range.scale_to_unit(envelope_time_unit).compose_within(
|
||||||
|
lambda t: envelope_realizer(t)
|
||||||
|
) ## 5
|
||||||
|
).compose_within(
|
||||||
|
lambda els: td.CustomSourceTime(
|
||||||
|
amplitude=complex(els[0]).real,
|
||||||
|
phase=complex(els[0]).imag,
|
||||||
|
freq0=els[1],
|
||||||
|
fwidth=els[2],
|
||||||
|
offset=els[3],
|
||||||
|
source_time_dataset=td_TimeDataset(
|
||||||
|
values=td_TimeDataArray(
|
||||||
|
els[5], coords={'t': np.array(els[4])}
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - FlowKind: Params
|
||||||
|
####################
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Temporal Shape',
|
||||||
|
kind=ct.FlowKind.Params,
|
||||||
|
# Loaded
|
||||||
|
props={'active_socket_set', 'envelope_time_unit'},
|
||||||
|
input_sockets={
|
||||||
|
'max E',
|
||||||
|
'μ Freq',
|
||||||
|
'σ Freq',
|
||||||
|
'Offset Time',
|
||||||
|
't Range',
|
||||||
|
},
|
||||||
|
input_socket_kinds={
|
||||||
|
'max E': ct.FlowKind.Params,
|
||||||
|
'μ Freq': ct.FlowKind.Params,
|
||||||
|
'σ Freq': ct.FlowKind.Params,
|
||||||
|
'Offset Time': ct.FlowKind.Params,
|
||||||
|
't Range': ct.FlowKind.Params,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def compute_temporal_shape_params(
|
||||||
|
self,
|
||||||
|
props,
|
||||||
|
input_sockets,
|
||||||
|
) -> td.GaussianPulse:
|
||||||
|
"""Compute a single temporal shape from non-parameterized inputs."""
|
||||||
|
mean_freq = input_sockets['μ Freq']
|
||||||
|
std_freq = input_sockets['σ Freq']
|
||||||
|
max_e = input_sockets['max E']
|
||||||
|
offset = input_sockets['Offset Time']
|
||||||
|
|
||||||
|
has_mean_freq = not ct.FlowSignal.check(mean_freq)
|
||||||
|
has_std_freq = not ct.FlowSignal.check(std_freq)
|
||||||
|
has_max_e = not ct.FlowSignal.check(max_e)
|
||||||
|
has_offset = not ct.FlowSignal.check(offset)
|
||||||
|
|
||||||
|
if has_mean_freq and has_std_freq and has_max_e and has_offset:
|
||||||
|
common_params = max_e | mean_freq | std_freq | offset
|
||||||
|
match props['active_socket_set']:
|
||||||
|
case 'Pulse' | 'Constant':
|
||||||
|
return common_params
|
||||||
|
|
||||||
|
case 'Symbolic':
|
||||||
|
t_range = input_sockets['t Range']
|
||||||
|
has_t_range = not ct.FlowSignal.check(t_range)
|
||||||
|
|
||||||
|
if has_t_range:
|
||||||
|
return common_params | t_range | t_range
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
|
@ -88,28 +88,27 @@ class BoxStructureNode(base.MaxwellSimNode):
|
||||||
'Structure',
|
'Structure',
|
||||||
kind=ct.FlowKind.Value,
|
kind=ct.FlowKind.Value,
|
||||||
# Loaded
|
# Loaded
|
||||||
props={'differentiable'},
|
|
||||||
input_sockets={'Medium', 'Center', 'Size'},
|
input_sockets={'Medium', 'Center', 'Size'},
|
||||||
output_sockets={'Structure'},
|
output_sockets={'Structure'},
|
||||||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
||||||
)
|
)
|
||||||
def compute_value(self, props, input_sockets, output_sockets) -> td.Box:
|
def compute_value(self, input_sockets, output_sockets) -> td.Box:
|
||||||
output_params = output_sockets['Structure']
|
"""Compute a single box structure object, given that all inputs are non-symbolic."""
|
||||||
center = input_sockets['Center']
|
center = input_sockets['Center']
|
||||||
size = input_sockets['Size']
|
size = input_sockets['Size']
|
||||||
medium = input_sockets['Medium']
|
medium = input_sockets['Medium']
|
||||||
|
output_params = output_sockets['Structure']
|
||||||
|
|
||||||
has_output_params = not ct.FlowSignal.check(output_params)
|
|
||||||
has_center = not ct.FlowSignal.check(center)
|
has_center = not ct.FlowSignal.check(center)
|
||||||
has_size = not ct.FlowSignal.check(size)
|
has_size = not ct.FlowSignal.check(size)
|
||||||
has_medium = not ct.FlowSignal.check(medium)
|
has_medium = not ct.FlowSignal.check(medium)
|
||||||
|
has_output_params = not ct.FlowSignal.check(output_params)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
has_center
|
has_center
|
||||||
and has_size
|
and has_size
|
||||||
and has_medium
|
and has_medium
|
||||||
and has_output_params
|
and has_output_params
|
||||||
and not props['differentiable']
|
|
||||||
and not output_params.symbols
|
and not output_params.symbols
|
||||||
):
|
):
|
||||||
return td.Structure(
|
return td.Structure(
|
||||||
|
@ -138,7 +137,8 @@ class BoxStructureNode(base.MaxwellSimNode):
|
||||||
output_sockets={'Structure'},
|
output_sockets={'Structure'},
|
||||||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
||||||
)
|
)
|
||||||
def compute_lazy_structure(self, props, input_sockets, output_sockets) -> td.Box:
|
def compute_structure_func(self, props, input_sockets, output_sockets) -> td.Box:
|
||||||
|
"""Compute a possibly-differentiable function, producing a box structure from the input parameters."""
|
||||||
output_params = output_sockets['Structure']
|
output_params = output_sockets['Structure']
|
||||||
center = input_sockets['Center']
|
center = input_sockets['Center']
|
||||||
size = input_sockets['Size']
|
size = input_sockets['Size']
|
||||||
|
@ -149,14 +149,8 @@ class BoxStructureNode(base.MaxwellSimNode):
|
||||||
has_size = not ct.FlowSignal.check(size)
|
has_size = not ct.FlowSignal.check(size)
|
||||||
has_medium = not ct.FlowSignal.check(medium)
|
has_medium = not ct.FlowSignal.check(medium)
|
||||||
|
|
||||||
differentiable = props['differentiable']
|
if has_output_params and has_center and has_size and has_medium:
|
||||||
if (
|
differentiable = props['differentiable']
|
||||||
has_output_params
|
|
||||||
and has_center
|
|
||||||
and has_size
|
|
||||||
and has_medium
|
|
||||||
and differentiable == output_params.is_differentiable
|
|
||||||
):
|
|
||||||
if differentiable:
|
if differentiable:
|
||||||
return (center | size | medium).compose_within(
|
return (center | size | medium).compose_within(
|
||||||
enclosing_func=lambda els: tdadj.JaxStructure(
|
enclosing_func=lambda els: tdadj.JaxStructure(
|
||||||
|
@ -169,6 +163,12 @@ class BoxStructureNode(base.MaxwellSimNode):
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
return (center | size | medium).compose_within(
|
return (center | size | medium).compose_within(
|
||||||
|
## TODO: Unit conversion within the composed function??
|
||||||
|
## -- We do need Tidy3D to be given ex. micrometers in particular.
|
||||||
|
## -- But the previous numerical output might not be micrometers.
|
||||||
|
## -- There must be a way to add a conversion in, without strangeness.
|
||||||
|
## -- Ex. can compose_within() take a unit system?
|
||||||
|
## -- This would require
|
||||||
enclosing_func=lambda els: td.Structure(
|
enclosing_func=lambda els: td.Structure(
|
||||||
geometry=td.Box(
|
geometry=td.Box(
|
||||||
center=tuple(els[0].flatten()),
|
center=tuple(els[0].flatten()),
|
||||||
|
@ -205,13 +205,7 @@ class BoxStructureNode(base.MaxwellSimNode):
|
||||||
has_medium = not ct.FlowSignal.check(medium)
|
has_medium = not ct.FlowSignal.check(medium)
|
||||||
|
|
||||||
if has_center and has_size and has_medium:
|
if has_center and has_size and has_medium:
|
||||||
if props['differentiable'] == (
|
return center | size | medium
|
||||||
center.is_differentiable
|
|
||||||
and size.is_differentiable
|
|
||||||
and medium.is_differentiable
|
|
||||||
):
|
|
||||||
return center | size | medium
|
|
||||||
return ct.FlowSignal.FlowPending
|
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -226,6 +220,7 @@ class BoxStructureNode(base.MaxwellSimNode):
|
||||||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
||||||
)
|
)
|
||||||
def compute_previews(self, props, output_sockets):
|
def compute_previews(self, props, output_sockets):
|
||||||
|
"""Mark the managed preview object when recursively linked to a viewer."""
|
||||||
output_params = output_sockets['Structure']
|
output_params = output_sockets['Structure']
|
||||||
has_output_params = not ct.FlowSignal.check(output_params)
|
has_output_params = not ct.FlowSignal.check(output_params)
|
||||||
|
|
||||||
|
@ -245,10 +240,14 @@ class BoxStructureNode(base.MaxwellSimNode):
|
||||||
)
|
)
|
||||||
def on_inputs_changed(self, managed_objs, input_sockets, output_sockets):
|
def on_inputs_changed(self, managed_objs, input_sockets, output_sockets):
|
||||||
output_params = output_sockets['Structure']
|
output_params = output_sockets['Structure']
|
||||||
|
center = input_sockets['Center']
|
||||||
|
|
||||||
has_output_params = not ct.FlowSignal.check(output_params)
|
has_output_params = not ct.FlowSignal.check(output_params)
|
||||||
if has_output_params and not output_params.symbols:
|
has_center = not ct.FlowSignal.check(center)
|
||||||
|
if has_center and has_output_params and not output_params.symbols:
|
||||||
|
## TODO: There are strategies for handling examples of symbol values.
|
||||||
|
|
||||||
# Push Loose Input Values to GeoNodes Modifier
|
# Push Loose Input Values to GeoNodes Modifier
|
||||||
center = input_sockets['Center']
|
|
||||||
managed_objs['modifier'].bl_modifier(
|
managed_objs['modifier'].bl_modifier(
|
||||||
'NODES',
|
'NODES',
|
||||||
{
|
{
|
||||||
|
|
|
@ -43,17 +43,28 @@ class SocketDef(pyd.BaseModel, abc.ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
socket_type: ct.SocketType
|
socket_type: ct.SocketType
|
||||||
|
active_kind: typ.Literal[
|
||||||
|
ct.FlowKind.Value,
|
||||||
|
ct.FlowKind.Array,
|
||||||
|
ct.FlowKind.Range,
|
||||||
|
ct.FlowKind.Func,
|
||||||
|
] = ct.FlowKind.Value
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Socket Interaction
|
||||||
|
####################
|
||||||
def preinit(self, bl_socket: bpy.types.NodeSocket) -> None:
|
def preinit(self, bl_socket: bpy.types.NodeSocket) -> None:
|
||||||
"""Pre-initialize a real Blender node socket from this socket definition.
|
"""Pre-initialize a real Blender node socket from this socket definition.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
bl_socket: The Blender node socket to alter using data from this SocketDef.
|
bl_socket: The Blender node socket to alter using data from this SocketDef.
|
||||||
"""
|
"""
|
||||||
log.debug('%s: Start Socket Preinit', bl_socket.bl_label)
|
# log.debug('%s: Start Socket Preinit', bl_socket.bl_label)
|
||||||
bl_socket.reset_instance_id()
|
bl_socket.reset_instance_id()
|
||||||
bl_socket.regenerate_dynamic_field_persistance()
|
bl_socket.regenerate_dynamic_field_persistance()
|
||||||
log.debug('%s: End Socket Preinit', bl_socket.bl_label)
|
|
||||||
|
bl_socket.active_kind = self.active_kind
|
||||||
|
# log.debug('%s: End Socket Preinit', bl_socket.bl_label)
|
||||||
|
|
||||||
def postinit(self, bl_socket: bpy.types.NodeSocket) -> None:
|
def postinit(self, bl_socket: bpy.types.NodeSocket) -> None:
|
||||||
"""Pre-initialize a real Blender node socket from this socket definition.
|
"""Pre-initialize a real Blender node socket from this socket definition.
|
||||||
|
@ -61,12 +72,12 @@ class SocketDef(pyd.BaseModel, abc.ABC):
|
||||||
Parameters:
|
Parameters:
|
||||||
bl_socket: The Blender node socket to alter using data from this SocketDef.
|
bl_socket: The Blender node socket to alter using data from this SocketDef.
|
||||||
"""
|
"""
|
||||||
log.debug('%s: Start Socket Postinit', bl_socket.bl_label)
|
# log.debug('%s: Start Socket Postinit', bl_socket.bl_label)
|
||||||
bl_socket.is_initializing = False
|
bl_socket.is_initializing = False
|
||||||
bl_socket.on_active_kind_changed()
|
bl_socket.on_active_kind_changed()
|
||||||
bl_socket.on_socket_props_changed(set(bl_socket.blfields))
|
bl_socket.on_socket_props_changed(set(bl_socket.blfields))
|
||||||
bl_socket.on_data_changed(set(ct.FlowKind))
|
bl_socket.on_data_changed(set(ct.FlowKind))
|
||||||
log.debug('%s: End Socket Postinit', bl_socket.bl_label)
|
# log.debug('%s: End Socket Postinit', bl_socket.bl_label)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def init(self, bl_socket: bpy.types.NodeSocket) -> None:
|
def init(self, bl_socket: bpy.types.NodeSocket) -> None:
|
||||||
|
@ -76,6 +87,43 @@ class SocketDef(pyd.BaseModel, abc.ABC):
|
||||||
bl_socket: The Blender node socket to alter using data from this SocketDef.
|
bl_socket: The Blender node socket to alter using data from this SocketDef.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Comparison
|
||||||
|
####################
|
||||||
|
def compare(self, bl_socket: bpy.types.NodeSocket) -> bool:
|
||||||
|
"""Whether this `SocketDef` can be considered to uniquely define the given `bl_socket`.
|
||||||
|
|
||||||
|
The general criteria for "uniquely defines" is whether **the same `bl_socket`** could be created using this `SocketDef`.
|
||||||
|
The extent to which user-altered properties are considered in this regard is a matter of taste, encapsulated entirely within `self.local_compare()`.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Used when determining whether to replace sockets with newer variants when synchronizing changes.
|
||||||
|
|
||||||
|
**NOTE**: Removing/replacing loose input sockets
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
bl_socket: The Blender node socket to alter using data from this SocketDef.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
bl_socket.socket_type is self.socket_type
|
||||||
|
and bl_socket.active_kind is self.active_kind
|
||||||
|
and self.local_compare(bl_socket)
|
||||||
|
)
|
||||||
|
|
||||||
|
def local_compare(self, bl_socket: bpy.types.NodeSocket) -> None:
|
||||||
|
"""Compare this `SocketDef` to an established `bl_socket` in a manner specific to the node.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Run by `self.compare()`.
|
||||||
|
Optionally overriden by individual sockets.
|
||||||
|
|
||||||
|
When not overridden, it will always return `False`, indicating that the socket is _never_ uniquely defined by this `SocketDef`.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
bl_socket: The Blender node socket to alter using data from this SocketDef.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Serialization
|
# - Serialization
|
||||||
####################
|
####################
|
||||||
|
@ -426,8 +474,34 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
||||||
Parameters:
|
Parameters:
|
||||||
socket_kinds: The altered `ct.FlowKind`s flowing through.
|
socket_kinds: The altered `ct.FlowKind`s flowing through.
|
||||||
"""
|
"""
|
||||||
|
# Run Socket Callbacks
|
||||||
self.on_socket_data_changed(socket_kinds)
|
self.on_socket_data_changed(socket_kinds)
|
||||||
|
|
||||||
|
# Mark Active FlowKind Links as Invalid
|
||||||
|
## -> Mark link as invalid (very red) if a FlowSignal is traveling.
|
||||||
|
## -> This helps explain why whatever isn't working isn't working.
|
||||||
|
## -> TODO: We need a different approach.
|
||||||
|
# log.debug(
|
||||||
|
# '[%s] Checking FlowKind Validity (socket_kinds=%s)',
|
||||||
|
# self.name,
|
||||||
|
# str(socket_kinds),
|
||||||
|
# )
|
||||||
|
# if self.is_linked and not self.is_output:
|
||||||
|
# link = self.links[0]
|
||||||
|
# linked_flow = self.compute_data(kind=self.active_kind)
|
||||||
|
|
||||||
|
# if (
|
||||||
|
# link.is_valid
|
||||||
|
# and self.active_kind in socket_kinds
|
||||||
|
# and ct.FlowSignal.check_single(linked_flow, ct.FlowSignal.FlowPending)
|
||||||
|
# ):
|
||||||
|
# node_tree = self.id_data
|
||||||
|
# node_tree.report_link_validity(link, False)
|
||||||
|
|
||||||
|
# elif not link.is_valid:
|
||||||
|
# node_tree = self.id_data
|
||||||
|
# node_tree.report_link_validity(link, True)
|
||||||
|
|
||||||
def on_socket_data_changed(self, socket_kinds: set[ct.FlowKind]) -> None:
|
def on_socket_data_changed(self, socket_kinds: set[ct.FlowKind]) -> None:
|
||||||
"""Called when `ct.FlowEvent.DataChanged` flows through this socket.
|
"""Called when `ct.FlowEvent.DataChanged` flows through this socket.
|
||||||
|
|
||||||
|
@ -479,7 +553,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
||||||
The value of `ct.FlowEvent.flow_direction[event]` (`input` or `output`) determines the direction that an event flows.
|
The value of `ct.FlowEvent.flow_direction[event]` (`input` or `output`) determines the direction that an event flows.
|
||||||
"""
|
"""
|
||||||
# log.debug(
|
# log.debug(
|
||||||
# '[%s] [%s] Triggered (socket_kinds=%s)',
|
# '[%s] [%s] Socket Triggered (socket_kinds=%s)',
|
||||||
# self.name,
|
# self.name,
|
||||||
# event,
|
# event,
|
||||||
# str(socket_kinds),
|
# str(socket_kinds),
|
||||||
|
@ -757,7 +831,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
||||||
linked_values = [link.from_socket.compute_data(kind) for link in self.links]
|
linked_values = [link.from_socket.compute_data(kind) for link in self.links]
|
||||||
|
|
||||||
# Return Single Value / List of Values
|
# Return Single Value / List of Values
|
||||||
## -> Multi-input sockets are not yet supported.
|
## -> Multi-input sockets are not (yet) supported.
|
||||||
if linked_values:
|
if linked_values:
|
||||||
return linked_values[0]
|
return linked_values[0]
|
||||||
|
|
||||||
|
@ -891,10 +965,14 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
||||||
# FlowKind Draw Row
|
# FlowKind Draw Row
|
||||||
col = row.column(align=True)
|
col = row.column(align=True)
|
||||||
{
|
{
|
||||||
|
ct.FlowKind.Capabilities: lambda *_: None,
|
||||||
|
ct.FlowKind.Previews: lambda *_: None,
|
||||||
ct.FlowKind.Value: self.draw_value,
|
ct.FlowKind.Value: self.draw_value,
|
||||||
ct.FlowKind.Array: self.draw_array,
|
ct.FlowKind.Array: self.draw_array,
|
||||||
ct.FlowKind.Range: self.draw_lazy_range,
|
ct.FlowKind.Range: self.draw_lazy_range,
|
||||||
ct.FlowKind.Func: self.draw_lazy_func,
|
ct.FlowKind.Func: self.draw_lazy_func,
|
||||||
|
ct.FlowKind.Params: lambda *_: None,
|
||||||
|
ct.FlowKind.Info: lambda *_: None,
|
||||||
}[self.active_kind](col)
|
}[self.active_kind](col)
|
||||||
|
|
||||||
# Info Drawing
|
# Info Drawing
|
||||||
|
|
|
@ -51,6 +51,16 @@ class BoolBLSocket(base.MaxwellSimSocket):
|
||||||
def value(self, value: bool) -> None:
|
def value(self, value: bool) -> None:
|
||||||
self.raw_value = value
|
self.raw_value = value
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property(depends_on={'value'})
|
||||||
|
def lazy_func(self) -> ct.FuncFlow:
|
||||||
|
return ct.FuncFlow(
|
||||||
|
func=lambda: self.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property()
|
||||||
|
def params(self) -> ct.FuncFlow:
|
||||||
|
return ct.ParamsFlow()
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Socket Configuration
|
# - Socket Configuration
|
||||||
|
|
|
@ -130,6 +130,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
'physical_type',
|
'physical_type',
|
||||||
'unit',
|
'unit',
|
||||||
'size',
|
'size',
|
||||||
|
'value',
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def output_sym(self) -> sim_symbols.SimSymbol | None:
|
def output_sym(self) -> sim_symbols.SimSymbol | None:
|
||||||
|
@ -140,13 +141,29 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
Raises:
|
Raises:
|
||||||
NotImplementedError: When `active_kind` is neither `Value`, `Func`, or `Range`.
|
NotImplementedError: When `active_kind` is neither `Value`, `Func`, or `Range`.
|
||||||
"""
|
"""
|
||||||
if self.symbols:
|
match self.active_kind:
|
||||||
if self.active_kind in [ct.FlowKind.Value, ct.FlowKind.Func]:
|
case ct.FlowKind.Value | ct.FlowKind.Func if self.symbols:
|
||||||
return self._parse_expr_symbol(
|
return self._parse_expr_symbol(
|
||||||
self._parse_expr_str(self.raw_value_spstr)
|
self._parse_expr_str(self.raw_value_spstr)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.active_kind is ct.FlowKind.Range:
|
case ct.FlowKind.Value | ct.FlowKind.Func if not self.symbols:
|
||||||
|
return sim_symbols.SimSymbol(
|
||||||
|
sym_name=self.output_name,
|
||||||
|
mathtype=self.mathtype,
|
||||||
|
physical_type=self.physical_type,
|
||||||
|
unit=self.unit,
|
||||||
|
rows=self.size.rows,
|
||||||
|
cols=self.size.cols,
|
||||||
|
exclude_zero=(
|
||||||
|
not self.value.is_zero
|
||||||
|
if self.value.is_zero is not None
|
||||||
|
else False
|
||||||
|
),
|
||||||
|
## TODO: Does this work for matrix elements?
|
||||||
|
)
|
||||||
|
|
||||||
|
case ct.FlowKind.Range if self.symbols:
|
||||||
## TODO: Support RangeFlow
|
## TODO: Support RangeFlow
|
||||||
## -- It's hard; we need a min-span set over bound domains.
|
## -- It's hard; we need a min-span set over bound domains.
|
||||||
## -- We... Don't use this anywhere. Yet?
|
## -- We... Don't use this anywhere. Yet?
|
||||||
|
@ -159,20 +176,37 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
msg = 'RangeFlow support not yet implemented for when self.symbols is not empty'
|
msg = 'RangeFlow support not yet implemented for when self.symbols is not empty'
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
raise NotImplementedError
|
case ct.FlowKind.Range if not self.symbols:
|
||||||
|
return sim_symbols.SimSymbol(
|
||||||
|
sym_name=self.output_name,
|
||||||
|
mathtype=self.mathtype,
|
||||||
|
physical_type=self.physical_type,
|
||||||
|
unit=self.unit,
|
||||||
|
rows=self.lazy_range.steps,
|
||||||
|
cols=1,
|
||||||
|
exclude_zero=not self.lazy_range.is_always_nonzero,
|
||||||
|
)
|
||||||
|
|
||||||
return sim_symbols.SimSymbol(
|
####################
|
||||||
sym_name=self.output_name,
|
# - Value|Range Swapper
|
||||||
mathtype=self.mathtype,
|
####################
|
||||||
physical_type=self.physical_type,
|
use_value_range_swapper: bool = bl_cache.BLField(False)
|
||||||
unit=self.unit,
|
selected_value_range: ct.FlowKind = bl_cache.BLField(
|
||||||
rows=self.size.rows,
|
enum_cb=lambda self, _: self._value_or_range(),
|
||||||
cols=self.size.cols,
|
)
|
||||||
)
|
|
||||||
|
def _value_or_range(self):
|
||||||
|
return [
|
||||||
|
flow_kind.bl_enum_element(i)
|
||||||
|
for i, flow_kind in enumerate([ct.FlowKind.Value, ct.FlowKind.Range])
|
||||||
|
]
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Symbols
|
# - Symbols
|
||||||
####################
|
####################
|
||||||
|
lazy_range_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||||
|
sim_symbols.SimSymbolName.Expr
|
||||||
|
)
|
||||||
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||||
sim_symbols.SimSymbolName.Expr
|
sim_symbols.SimSymbolName.Expr
|
||||||
)
|
)
|
||||||
|
@ -343,7 +377,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
|
|
||||||
See `MaxwellSimTree` for more detail on the link callbacks.
|
See `MaxwellSimTree` for more detail on the link callbacks.
|
||||||
"""
|
"""
|
||||||
## NODE: Depends on suppressed on_prop_changed
|
## NOTE: Depends on suppressed on_prop_changed
|
||||||
|
|
||||||
if ct.FlowKind.Info in socket_kinds:
|
if ct.FlowKind.Info in socket_kinds:
|
||||||
info = self.compute_data(kind=ct.FlowKind.Info)
|
info = self.compute_data(kind=ct.FlowKind.Info)
|
||||||
|
@ -371,7 +405,10 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
|
|
||||||
See `MaxwellSimTree` for more detail on the link callbacks.
|
See `MaxwellSimTree` for more detail on the link callbacks.
|
||||||
"""
|
"""
|
||||||
## NODE: Depends on suppressed on_prop_changed
|
## NOTE: Depends on suppressed on_prop_changed
|
||||||
|
if ('selected_value_range', 'invalidate') in cleared_blfields:
|
||||||
|
self.active_kind = self.selected_value_range
|
||||||
|
self.on_active_kind_changed()
|
||||||
|
|
||||||
# Conditional Unit-Conversion
|
# Conditional Unit-Conversion
|
||||||
## -> This is niche functionality, but the only way to convert units.
|
## -> This is niche functionality, but the only way to convert units.
|
||||||
|
@ -757,7 +794,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
@bl_cache.cached_bl_property(
|
@bl_cache.cached_bl_property(
|
||||||
depends_on={
|
depends_on={
|
||||||
'value',
|
'value',
|
||||||
'symbols',
|
|
||||||
'sorted_sp_symbols',
|
'sorted_sp_symbols',
|
||||||
'sorted_symbols',
|
'sorted_symbols',
|
||||||
'output_sym',
|
'output_sym',
|
||||||
|
@ -769,82 +805,87 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be the arguments of the `FuncFlow`.
|
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be the arguments of the `FuncFlow`.
|
||||||
Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`.
|
Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`.
|
||||||
"""
|
"""
|
||||||
# Symbolic
|
if self.output_sym is not None:
|
||||||
## -> `self.value` is guaranteed to be an expression with unknowns.
|
match self.active_kind:
|
||||||
## -> The function computes `self.value` with unknowns as arguments.
|
case ct.FlowKind.Value | ct.FlowKind.Func if (
|
||||||
if self.symbols:
|
self.sorted_symbols and not ct.FlowSignal.check(self.value)
|
||||||
value = self.value
|
):
|
||||||
has_value = not ct.FlowSignal.check(value)
|
return ct.FuncFlow(
|
||||||
|
func=sp.lambdify(
|
||||||
|
self.sorted_sp_symbols,
|
||||||
|
self.output_sym.conform(self.value, strip_unit=True),
|
||||||
|
'jax',
|
||||||
|
),
|
||||||
|
func_args=list(self.sorted_symbols),
|
||||||
|
func_output=self.output_sym,
|
||||||
|
supports_jax=True,
|
||||||
|
)
|
||||||
|
|
||||||
output_sym = self.output_sym
|
case ct.FlowKind.Value | ct.FlowKind.Func if not self.sorted_symbols:
|
||||||
if output_sym is not None and has_value:
|
return ct.FuncFlow(
|
||||||
return ct.FuncFlow(
|
func=lambda v: v,
|
||||||
func=sp.lambdify(
|
func_args=[self.output_sym],
|
||||||
self.sorted_sp_symbols,
|
func_output=self.output_sym,
|
||||||
output_sym.conform(value, strip_unit=True),
|
supports_jax=True,
|
||||||
'jax',
|
)
|
||||||
),
|
|
||||||
func_args=list(self.sorted_symbols),
|
|
||||||
supports_jax=True,
|
|
||||||
)
|
|
||||||
return ct.FlowSignal.FlowPending
|
|
||||||
|
|
||||||
# Constant
|
case ct.FlowKind.Range if self.sorted_symbols:
|
||||||
## -> When a `self.value` has no unknowns, use a dummy function.
|
msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty'
|
||||||
## -> ("Dummy" as in returns the same argument that it takes).
|
raise NotImplementedError(msg)
|
||||||
## -> This is an excuse to let `ParamsFlow` pass `self.value` verbatim.
|
|
||||||
## -> Generally only useful for operations with other expressions.
|
|
||||||
return ct.FuncFlow(
|
|
||||||
func=lambda v: v,
|
|
||||||
func_args=[self.output_sym],
|
|
||||||
supports_jax=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@bl_cache.cached_bl_property(depends_on={'sorted_symbols'})
|
case ct.FlowKind.Range if (
|
||||||
def is_differentiable(self) -> bool:
|
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
|
||||||
"""Whether all symbols are differentiable.
|
):
|
||||||
|
return ct.FuncFlow(
|
||||||
|
func=lambda v: v,
|
||||||
|
func_args=[self.output_sym],
|
||||||
|
func_output=self.output_sym,
|
||||||
|
supports_jax=True,
|
||||||
|
)
|
||||||
|
|
||||||
If there are no symbols, then there is nothing to differentiate, and thus the expression is differentiable.
|
return ct.FlowSignal.FlowPending
|
||||||
"""
|
|
||||||
if not self.sorted_symbols:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return all(
|
@bl_cache.cached_bl_property(
|
||||||
sym.mathtype in [spux.MathType.Real, spux.MathType.Complex]
|
depends_on={'sorted_symbols', 'output_sym', 'value', 'lazy_range'}
|
||||||
for sym in self.sorted_symbols
|
)
|
||||||
)
|
|
||||||
|
|
||||||
@bl_cache.cached_bl_property(depends_on={'sorted_symbols', 'output_sym', 'value'})
|
|
||||||
def params(self) -> ct.ParamsFlow:
|
def params(self) -> ct.ParamsFlow:
|
||||||
"""Returns parameter symbols/values to accompany `self.lazy_func`.
|
"""Returns parameter symbols/values to accompany `self.lazy_func`.
|
||||||
|
|
||||||
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use).
|
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use).
|
||||||
Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`.
|
Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`.
|
||||||
"""
|
"""
|
||||||
# Symbolic
|
output_sym = self.output_sym
|
||||||
## -> The Expr socket does not declare actual values for the symbols.
|
if output_sym is not None:
|
||||||
## -> They should be realized later, ex. in a Viz node.
|
match self.active_kind:
|
||||||
## -> Therefore, we just dump the symbols. Easy!
|
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
|
||||||
## -> NOTE: func_args must have the same symbol order as was lambdified.
|
return ct.ParamsFlow(
|
||||||
if self.sorted_symbols:
|
arg_targets=list(self.sorted_symbols),
|
||||||
output_sym = self.output_sym
|
func_args=[sym.sp_symbol for sym in self.sorted_symbols],
|
||||||
if output_sym is not None:
|
symbols=set(self.sorted_symbols),
|
||||||
return ct.ParamsFlow(
|
)
|
||||||
arg_targets=list(self.sorted_symbols),
|
|
||||||
func_args=[sym.sp_symbol for sym in self.sorted_symbols],
|
|
||||||
symbols=self.sorted_symbols,
|
|
||||||
is_differentiable=self.is_differentiable,
|
|
||||||
)
|
|
||||||
return ct.FlowSignal.FlowPending
|
|
||||||
|
|
||||||
# Constant
|
case ct.FlowKind.Value | ct.FlowKind.Func if (
|
||||||
## -> Simply pass self.value verbatim as a function argument.
|
not self.sorted_symbols and not ct.FlowSignal.check(self.value)
|
||||||
## -> Easy dice, easy life!
|
):
|
||||||
return ct.ParamsFlow(
|
return ct.ParamsFlow(
|
||||||
arg_targets=[self.output_sym],
|
arg_targets=[self.output_sym],
|
||||||
func_args=[self.value],
|
func_args=[self.value],
|
||||||
is_differentiable=self.is_differentiable,
|
)
|
||||||
)
|
|
||||||
|
case ct.FlowKind.Range if self.sorted_symbols:
|
||||||
|
msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty'
|
||||||
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
case ct.FlowKind.Range if (
|
||||||
|
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
|
||||||
|
):
|
||||||
|
return ct.ParamsFlow(
|
||||||
|
arg_targets=[self.output_sym],
|
||||||
|
func_args=[self.output_sym.sp_symbol_matsym],
|
||||||
|
symbols={self.output_sym},
|
||||||
|
).realize_partial({self.output_sym: self.lazy_range})
|
||||||
|
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
@bl_cache.cached_bl_property(depends_on={'sorted_symbols', 'output_sym'})
|
@bl_cache.cached_bl_property(depends_on={'sorted_symbols', 'output_sym'})
|
||||||
def info(self) -> ct.InfoFlow:
|
def info(self) -> ct.InfoFlow:
|
||||||
|
@ -858,21 +899,33 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
|
|
||||||
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
|
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
|
||||||
"""
|
"""
|
||||||
# Constant
|
output_sym = self.output_sym
|
||||||
## -> The input SimSymbols become continuous dimensional indices.
|
if output_sym is not None:
|
||||||
## -> All domain validity information is defined on the SimSymbol keys.
|
match self.active_kind:
|
||||||
if self.sorted_symbols:
|
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
|
||||||
output_sym = self.output_sym
|
return ct.InfoFlow(
|
||||||
if output_sym is not None:
|
dims={sym: None for sym in self.sorted_symbols},
|
||||||
return ct.InfoFlow(
|
output=self.output_sym,
|
||||||
dims={sym: None for sym in self.sorted_symbols},
|
)
|
||||||
output=self.output_sym,
|
|
||||||
)
|
|
||||||
return ct.FlowSignal.FlowPending
|
|
||||||
|
|
||||||
# Constant
|
case ct.FlowKind.Value | ct.FlowKind.Func if (
|
||||||
## -> We only need the output symbol to describe the raw data.
|
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
|
||||||
return ct.InfoFlow(output=self.output_sym)
|
):
|
||||||
|
return ct.InfoFlow(output=self.output_sym)
|
||||||
|
|
||||||
|
case ct.FlowKind.Range if self.sorted_symbols:
|
||||||
|
msg = 'InfoFlow support not yet implemented for when self.sorted_symbols is not empty'
|
||||||
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
case ct.FlowKind.Range if (
|
||||||
|
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
|
||||||
|
):
|
||||||
|
return ct.InfoFlow(
|
||||||
|
dims={self.output_sym: self.lazy_range},
|
||||||
|
output=self.output_sym.update(rows=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - FlowKind: Capabilities
|
# - FlowKind: Capabilities
|
||||||
|
@ -1039,6 +1092,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
|
|
||||||
However, `draw_value` may also be called by the `draw_*` methods of other `FlowKinds`, who may choose to layer more flexibility around this base UI.
|
However, `draw_value` may also be called by the `draw_*` methods of other `FlowKinds`, who may choose to layer more flexibility around this base UI.
|
||||||
"""
|
"""
|
||||||
|
if self.use_value_range_swapper:
|
||||||
|
col.prop(self, self.blfields['selected_value_range'], text='')
|
||||||
|
|
||||||
if self.symbols:
|
if self.symbols:
|
||||||
col.prop(self, self.blfields['raw_value_spstr'], text='')
|
col.prop(self, self.blfields['raw_value_spstr'], text='')
|
||||||
|
|
||||||
|
@ -1097,6 +1153,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
If `self.steps == 0`, then the `Range` is considered to have a to-be-determined number of steps.
|
If `self.steps == 0`, then the `Range` is considered to have a to-be-determined number of steps.
|
||||||
As such, `self.steps` won't be exposed in the UI.
|
As such, `self.steps` won't be exposed in the UI.
|
||||||
"""
|
"""
|
||||||
|
if self.use_value_range_swapper:
|
||||||
|
col.prop(self, self.blfields['selected_value_range'], text='')
|
||||||
|
|
||||||
if self.symbols:
|
if self.symbols:
|
||||||
col.prop(self, self.blfields['raw_min_spstr'], text='')
|
col.prop(self, self.blfields['raw_min_spstr'], text='')
|
||||||
col.prop(self, self.blfields['raw_max_spstr'], text='')
|
col.prop(self, self.blfields['raw_max_spstr'], text='')
|
||||||
|
@ -1198,13 +1257,11 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
# - Socket Configuration
|
# - Socket Configuration
|
||||||
####################
|
####################
|
||||||
class ExprSocketDef(base.SocketDef):
|
class ExprSocketDef(base.SocketDef):
|
||||||
|
"""Interface for defining an `ExprSocket`."""
|
||||||
|
|
||||||
socket_type: ct.SocketType = ct.SocketType.Expr
|
socket_type: ct.SocketType = ct.SocketType.Expr
|
||||||
active_kind: typ.Literal[
|
|
||||||
ct.FlowKind.Value,
|
|
||||||
ct.FlowKind.Range,
|
|
||||||
ct.FlowKind.Func,
|
|
||||||
] = ct.FlowKind.Value
|
|
||||||
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName.Expr
|
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName.Expr
|
||||||
|
use_value_range_swapper: bool = False
|
||||||
|
|
||||||
# Socket Interface
|
# Socket Interface
|
||||||
size: spux.NumberSize1D = spux.NumberSize1D.Scalar
|
size: spux.NumberSize1D = spux.NumberSize1D.Scalar
|
||||||
|
@ -1458,7 +1515,7 @@ class ExprSocketDef(base.SocketDef):
|
||||||
# Check ActiveKind and Size
|
# Check ActiveKind and Size
|
||||||
## -> NOTE: This doesn't protect against dynamic changes to either.
|
## -> NOTE: This doesn't protect against dynamic changes to either.
|
||||||
if (
|
if (
|
||||||
self.active_kind == ct.FlowKind.Range
|
self.active_kind is ct.FlowKind.Range
|
||||||
and self.size is not spux.NumberSize1D.Scalar
|
and self.size is not spux.NumberSize1D.Scalar
|
||||||
):
|
):
|
||||||
msg = "Can't have a non-Scalar size when Range is set as the active kind."
|
msg = "Can't have a non-Scalar size when Range is set as the active kind."
|
||||||
|
@ -1504,9 +1561,9 @@ class ExprSocketDef(base.SocketDef):
|
||||||
# - Initialization
|
# - Initialization
|
||||||
####################
|
####################
|
||||||
def init(self, bl_socket: ExprBLSocket) -> None:
|
def init(self, bl_socket: ExprBLSocket) -> None:
|
||||||
bl_socket.active_kind = self.active_kind
|
|
||||||
bl_socket.output_name = self.output_name
|
bl_socket.output_name = self.output_name
|
||||||
bl_socket.use_linked_capabilities = True
|
bl_socket.use_linked_capabilities = True
|
||||||
|
bl_socket.use_value_range_swapper = self.use_value_range_swapper
|
||||||
|
|
||||||
# Socket Interface
|
# Socket Interface
|
||||||
## -> Recall that auto-updates are turned off during init()
|
## -> Recall that auto-updates are turned off during init()
|
||||||
|
@ -1543,6 +1600,25 @@ class ExprSocketDef(base.SocketDef):
|
||||||
# Info Draw
|
# Info Draw
|
||||||
bl_socket.use_info_draw = True
|
bl_socket.use_info_draw = True
|
||||||
|
|
||||||
|
def local_compare(self, bl_socket: ExprBLSocket) -> None:
|
||||||
|
"""Determine whether an updateable socket should be re-initialized from this `SocketDef`."""
|
||||||
|
|
||||||
|
def cmp(attr: str):
|
||||||
|
return getattr(bl_socket, attr) == getattr(self, attr)
|
||||||
|
|
||||||
|
return (
|
||||||
|
bl_socket.use_linked_capabilities
|
||||||
|
and cmp('output_name')
|
||||||
|
and cmp('use_value_range_swapper')
|
||||||
|
and cmp('size')
|
||||||
|
and cmp('mathtype')
|
||||||
|
and cmp('physical_type')
|
||||||
|
and cmp('show_func_ui')
|
||||||
|
and cmp('show_info_columns')
|
||||||
|
and cmp('show_name_selector')
|
||||||
|
and bl_socket.use_info_draw
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Blender Registration
|
# - Blender Registration
|
||||||
|
|
|
@ -117,6 +117,16 @@ class MaxwellBoundCondsBLSocket(base.MaxwellSimSocket):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property(depends_on={'value'})
|
||||||
|
def lazy_func(self) -> ct.FuncFlow:
|
||||||
|
return ct.FuncFlow(
|
||||||
|
func=lambda: self.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property()
|
||||||
|
def params(self) -> ct.FuncFlow:
|
||||||
|
return ct.ParamsFlow()
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Socket Configuration
|
# - Socket Configuration
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import typing as typ
|
||||||
|
|
||||||
from ... import contracts as ct
|
from ... import contracts as ct
|
||||||
from .. import base
|
from .. import base
|
||||||
|
|
||||||
|
@ -32,6 +34,9 @@ class MaxwellFDTDSimSocketDef(base.SocketDef):
|
||||||
def init(self, bl_socket: MaxwellFDTDSimBLSocket) -> None:
|
def init(self, bl_socket: MaxwellFDTDSimBLSocket) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def local_compare(self, _: MaxwellFDTDSimBLSocket) -> None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Blender Registration
|
# - Blender Registration
|
||||||
|
|
|
@ -73,7 +73,7 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket):
|
||||||
def value(self, eps_rel: tuple[float, float]) -> None:
|
def value(self, eps_rel: tuple[float, float]) -> None:
|
||||||
self.eps_rel = eps_rel
|
self.eps_rel = eps_rel
|
||||||
|
|
||||||
@bl_cache.cached_bl_property(depends_on={'value', 'differentiable'})
|
@bl_cache.cached_bl_property(depends_on={'value'})
|
||||||
def lazy_func(self) -> ct.FuncFlow:
|
def lazy_func(self) -> ct.FuncFlow:
|
||||||
return ct.FuncFlow(
|
return ct.FuncFlow(
|
||||||
func=lambda: self.value,
|
func=lambda: self.value,
|
||||||
|
@ -82,7 +82,7 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket):
|
||||||
|
|
||||||
@bl_cache.cached_bl_property(depends_on={'differentiable'})
|
@bl_cache.cached_bl_property(depends_on={'differentiable'})
|
||||||
def params(self) -> ct.FuncFlow:
|
def params(self) -> ct.FuncFlow:
|
||||||
return ct.ParamsFlow(is_differentiable=self.differentiable)
|
return ct.ParamsFlow()
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - UI
|
# - UI
|
||||||
|
|
|
@ -29,11 +29,11 @@ class MaxwellMonitorBLSocket(base.MaxwellSimSocket):
|
||||||
class MaxwellMonitorSocketDef(base.SocketDef):
|
class MaxwellMonitorSocketDef(base.SocketDef):
|
||||||
socket_type: ct.SocketType = ct.SocketType.MaxwellMonitor
|
socket_type: ct.SocketType = ct.SocketType.MaxwellMonitor
|
||||||
|
|
||||||
is_list: bool = False
|
|
||||||
|
|
||||||
def init(self, bl_socket: MaxwellMonitorBLSocket) -> None:
|
def init(self, bl_socket: MaxwellMonitorBLSocket) -> None:
|
||||||
if self.is_list:
|
pass
|
||||||
bl_socket.active_kind = ct.FlowKind.Array
|
|
||||||
|
def local_compare(self, _: MaxwellMonitorBLSocket) -> None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
|
@ -55,6 +55,16 @@ class MaxwellSimGridBLSocket(base.MaxwellSimSocket):
|
||||||
min_steps_per_wvl=self.min_steps_per_wl,
|
min_steps_per_wvl=self.min_steps_per_wl,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property(depends_on={'value'})
|
||||||
|
def lazy_func(self) -> ct.FuncFlow:
|
||||||
|
return ct.FuncFlow(
|
||||||
|
func=lambda: self.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property()
|
||||||
|
def params(self) -> ct.FuncFlow:
|
||||||
|
return ct.ParamsFlow()
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Socket Configuration
|
# - Socket Configuration
|
||||||
|
|
|
@ -29,11 +29,11 @@ class MaxwellSourceBLSocket(base.MaxwellSimSocket):
|
||||||
class MaxwellSourceSocketDef(base.SocketDef):
|
class MaxwellSourceSocketDef(base.SocketDef):
|
||||||
socket_type: ct.SocketType = ct.SocketType.MaxwellSource
|
socket_type: ct.SocketType = ct.SocketType.MaxwellSource
|
||||||
|
|
||||||
is_list: bool = False
|
|
||||||
|
|
||||||
def init(self, bl_socket: MaxwellSourceBLSocket) -> None:
|
def init(self, bl_socket: MaxwellSourceBLSocket) -> None:
|
||||||
if self.is_list:
|
pass
|
||||||
bl_socket.active_kind = ct.FlowKind.Array
|
|
||||||
|
def local_compare(self, _: MaxwellSourceBLSocket) -> None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
|
@ -29,11 +29,11 @@ class MaxwellStructureBLSocket(base.MaxwellSimSocket):
|
||||||
class MaxwellStructureSocketDef(base.SocketDef):
|
class MaxwellStructureSocketDef(base.SocketDef):
|
||||||
socket_type: ct.SocketType = ct.SocketType.MaxwellStructure
|
socket_type: ct.SocketType = ct.SocketType.MaxwellStructure
|
||||||
|
|
||||||
is_list: bool = False
|
|
||||||
|
|
||||||
def init(self, bl_socket: MaxwellStructureBLSocket) -> None:
|
def init(self, bl_socket: MaxwellStructureBLSocket) -> None:
|
||||||
if self.is_list:
|
pass
|
||||||
bl_socket.active_kind = ct.FlowKind.Array
|
|
||||||
|
def local_compare(self, _: MaxwellStructureBLSocket) -> None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
|
@ -16,10 +16,10 @@
|
||||||
|
|
||||||
"""Package providing various tools to handle cached data on Blender objects, especially nodes and node socket classes."""
|
"""Package providing various tools to handle cached data on Blender objects, especially nodes and node socket classes."""
|
||||||
|
|
||||||
|
from ..keyed_cache import KeyedCache, keyed_cache
|
||||||
from .bl_field import BLField
|
from .bl_field import BLField
|
||||||
from .bl_prop import BLProp, BLPropType
|
from .bl_prop import BLProp, BLPropType
|
||||||
from .cached_bl_property import CachedBLProperty, cached_bl_property
|
from .cached_bl_property import CachedBLProperty, cached_bl_property
|
||||||
from .keyed_cache import KeyedCache, keyed_cache
|
|
||||||
from .managed_cache import invalidate_nonpersist_instance_id
|
from .managed_cache import invalidate_nonpersist_instance_id
|
||||||
from .signal import Signal
|
from .signal import Signal
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ from types import MappingProxyType
|
||||||
import bpy
|
import bpy
|
||||||
|
|
||||||
from blender_maxwell.utils import bl_cache, logger
|
from blender_maxwell.utils import bl_cache, logger
|
||||||
|
from blender_maxwell.utils.keyed_cache import keyed_cache
|
||||||
|
|
||||||
InstanceID: typ.TypeAlias = str ## Stringified UUID4
|
InstanceID: typ.TypeAlias = str ## Stringified UUID4
|
||||||
|
|
||||||
|
@ -220,11 +221,14 @@ class BLInstance:
|
||||||
for str_search_prop_name in self.blfields_str_search:
|
for str_search_prop_name in self.blfields_str_search:
|
||||||
setattr(self, str_search_prop_name, bl_cache.Signal.ResetStrSearch)
|
setattr(self, str_search_prop_name, bl_cache.Signal.ResetStrSearch)
|
||||||
|
|
||||||
|
@keyed_cache(
|
||||||
|
exclude={'self'}, ## No dynamic elements of 'self' can be used.
|
||||||
|
)
|
||||||
def trace_blfields_to_clear(
|
def trace_blfields_to_clear(
|
||||||
self,
|
self,
|
||||||
prop_name: str,
|
prop_name: str,
|
||||||
prev_blfields_to_clear: list[
|
prev_blfields_to_clear: tuple[
|
||||||
tuple[str, typ.Literal['invalidate', 'reset_enum', 'reset_strsearch']]
|
tuple[str, typ.Literal['invalidate', 'reset_enum', 'reset_strsearch']], ...
|
||||||
] = (),
|
] = (),
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Invalidates all properties that depend on `prop_name`.
|
"""Invalidates all properties that depend on `prop_name`.
|
||||||
|
@ -239,7 +243,7 @@ class BLInstance:
|
||||||
All of these are filled when creating the `BLInstance` subclass, using `self.declare_blfield_dep()`, generally via the `BLField` descriptor (which internally uses `BLProp`).
|
All of these are filled when creating the `BLInstance` subclass, using `self.declare_blfield_dep()`, generally via the `BLField` descriptor (which internally uses `BLProp`).
|
||||||
"""
|
"""
|
||||||
if prev_blfields_to_clear:
|
if prev_blfields_to_clear:
|
||||||
blfields_to_clear = prev_blfields_to_clear.copy()
|
blfields_to_clear = list(prev_blfields_to_clear)
|
||||||
else:
|
else:
|
||||||
blfields_to_clear = []
|
blfields_to_clear = []
|
||||||
|
|
||||||
|
@ -268,7 +272,7 @@ class BLInstance:
|
||||||
if dst_prop_name in self.blfields:
|
if dst_prop_name in self.blfields:
|
||||||
blfields_to_clear += self.trace_blfields_to_clear(
|
blfields_to_clear += self.trace_blfields_to_clear(
|
||||||
dst_prop_name,
|
dst_prop_name,
|
||||||
prev_blfields_to_clear=blfields_to_clear,
|
prev_blfields_to_clear=tuple(blfields_to_clear),
|
||||||
)
|
)
|
||||||
|
|
||||||
match (bool(prev_blfields_to_clear), bool(blfields_to_clear)):
|
match (bool(prev_blfields_to_clear), bool(blfields_to_clear)):
|
||||||
|
@ -297,7 +301,7 @@ class BLInstance:
|
||||||
## -> As such, deduplication would not be wrong, just extraneous.
|
## -> As such, deduplication would not be wrong, just extraneous.
|
||||||
## -> Since invalidation is in a hot-loop, don't do such things.
|
## -> Since invalidation is in a hot-loop, don't do such things.
|
||||||
case (True, True):
|
case (True, True):
|
||||||
return blfields_to_clear
|
return list(reversed(dict.fromkeys(reversed(blfields_to_clear))))
|
||||||
|
|
||||||
def clear_blfields_after(self, prop_name: str) -> list[str]:
|
def clear_blfields_after(self, prop_name: str) -> list[str]:
|
||||||
"""Clear (invalidate) all `BLField`s that have become invalid as a result of a change to `prop_name`.
|
"""Clear (invalidate) all `BLField`s that have become invalid as a result of a change to `prop_name`.
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
"""Useful image processing operations for use in the addon."""
|
"""Useful image processing operations for use in the addon."""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
import functools
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
|
@ -26,7 +27,6 @@ import matplotlib
|
||||||
import matplotlib.axis as mpl_ax
|
import matplotlib.axis as mpl_ax
|
||||||
import matplotlib.backends.backend_agg
|
import matplotlib.backends.backend_agg
|
||||||
import matplotlib.figure
|
import matplotlib.figure
|
||||||
import numpy as np
|
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
|
||||||
from blender_maxwell import contracts as ct
|
from blender_maxwell import contracts as ct
|
||||||
|
@ -138,7 +138,7 @@ def rgba_image_from_2d_map(
|
||||||
####################
|
####################
|
||||||
# - MPL Helpers
|
# - MPL Helpers
|
||||||
####################
|
####################
|
||||||
# @functools.lru_cache(maxsize=16)
|
@functools.lru_cache(maxsize=4)
|
||||||
def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
|
def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
|
||||||
fig = matplotlib.figure.Figure(
|
fig = matplotlib.figure.Figure(
|
||||||
figsize=[width_inches, height_inches], dpi=dpi, layout='tight'
|
figsize=[width_inches, height_inches], dpi=dpi, layout='tight'
|
||||||
|
|
|
@ -18,11 +18,15 @@ import functools
|
||||||
import inspect
|
import inspect
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
from blender_maxwell.utils import bl_instance, logger, serialize
|
from blender_maxwell.utils import logger, serialize
|
||||||
|
|
||||||
log = logger.get(__name__)
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BLInstance(typ.Protocol):
|
||||||
|
instance_id: str
|
||||||
|
|
||||||
|
|
||||||
class KeyedCache:
|
class KeyedCache:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -75,8 +79,8 @@ class KeyedCache:
|
||||||
|
|
||||||
def __get__(
|
def __get__(
|
||||||
self,
|
self,
|
||||||
bl_instance: bl_instance.BLInstance | None,
|
bl_instance: BLInstance | None,
|
||||||
owner: type[bl_instance.BLInstance],
|
owner: type[BLInstance],
|
||||||
) -> typ.Callable:
|
) -> typ.Callable:
|
||||||
_func = functools.partial(self, bl_instance)
|
_func = functools.partial(self, bl_instance)
|
||||||
_func.invalidate = functools.partial(
|
_func.invalidate = functools.partial(
|
||||||
|
@ -110,7 +114,7 @@ class KeyedCache:
|
||||||
|
|
||||||
def invalidate(
|
def invalidate(
|
||||||
self,
|
self,
|
||||||
bl_instance: bl_instance.BLInstance | None,
|
bl_instance: BLInstance | None,
|
||||||
**arguments: dict[str, typ.Any],
|
**arguments: dict[str, typ.Any],
|
||||||
) -> dict[str, typ.Any]:
|
) -> dict[str, typ.Any]:
|
||||||
# Determine Wildcard Arguments
|
# Determine Wildcard Arguments
|
|
@ -264,13 +264,16 @@ class SimSymbol(pyd.BaseModel):
|
||||||
interval_closed_im: tuple[bool, bool] = (False, False)
|
interval_closed_im: tuple[bool, bool] = (False, False)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Labels
|
# - Core
|
||||||
####################
|
####################
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
"""Usable name for the symbol."""
|
"""Usable name for the symbol."""
|
||||||
return self.sym_name.name
|
return self.sym_name.name
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Labels
|
||||||
|
####################
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def name_pretty(self) -> str:
|
def name_pretty(self) -> str:
|
||||||
"""Pretty (possibly unicode) name for the thing."""
|
"""Pretty (possibly unicode) name for the thing."""
|
||||||
|
@ -307,6 +310,8 @@ class SimSymbol(pyd.BaseModel):
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def plot_label(self) -> str:
|
def plot_label(self) -> str:
|
||||||
"""Pretty plot-oriented label."""
|
"""Pretty plot-oriented label."""
|
||||||
|
if self.unit is None:
|
||||||
|
return self.name_pretty
|
||||||
return f'{self.name_pretty} ({self.unit_label})'
|
return f'{self.name_pretty} ({self.unit_label})'
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -420,6 +425,11 @@ class SimSymbol(pyd.BaseModel):
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def is_nonzero(self) -> bool:
|
def is_nonzero(self) -> bool:
|
||||||
|
"""Whether or not the value of this symbol can ever be $0$.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Most notably, this symbol cannot be used as the right hand side of a division operation when this property is `False`.
|
||||||
|
"""
|
||||||
if self.exclude_zero:
|
if self.exclude_zero:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -441,6 +451,18 @@ class SimSymbol(pyd.BaseModel):
|
||||||
)
|
)
|
||||||
return check_real_domain(self.domain)
|
return check_real_domain(self.domain)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def can_diff(self) -> bool:
|
||||||
|
"""Whether this symbol can be used as the input / output variable when differentiating."""
|
||||||
|
# Check Constants
|
||||||
|
## -> Constants (w/pinned values) are never differentiable.
|
||||||
|
if self.is_constant:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# TODO: Discontinuities (especially across 0)?
|
||||||
|
|
||||||
|
return self.mathtype in [spux.MathType.Real, spux.MathType.Complex]
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Properties
|
# - Properties
|
||||||
####################
|
####################
|
||||||
|
@ -664,8 +686,10 @@ class SimSymbol(pyd.BaseModel):
|
||||||
res = spux.strip_unit_system(sp_obj)
|
res = spux.strip_unit_system(sp_obj)
|
||||||
|
|
||||||
# Broadcast Expansion
|
# Broadcast Expansion
|
||||||
if self.rows > 1 or self.cols > 1 and not isinstance(res, spux.MatrixBase):
|
if (self.rows > 1 or self.cols > 1) and not isinstance(
|
||||||
res = sp_obj * sp.ImmutableMatrix.ones(self.rows, self.cols)
|
res, sp.MatrixBase | sp.MatrixSymbol
|
||||||
|
):
|
||||||
|
res = res * sp.ImmutableMatrix.ones(self.rows, self.cols)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -753,7 +777,9 @@ class SimSymbol(pyd.BaseModel):
|
||||||
unit = None
|
unit = None
|
||||||
|
|
||||||
# Rows/Cols from Expr (if Matrix)
|
# Rows/Cols from Expr (if Matrix)
|
||||||
rows, cols = expr.shape if isinstance(expr, sp.MatrixBase) else (1, 1)
|
rows, cols = (
|
||||||
|
expr.shape if isinstance(expr, sp.MatrixBase | sp.MatrixSymbol) else (1, 1)
|
||||||
|
)
|
||||||
|
|
||||||
return SimSymbol(
|
return SimSymbol(
|
||||||
sym_name=sym_name,
|
sym_name=sym_name,
|
||||||
|
|
Loading…
Reference in New Issue