diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py index 0303908..a6c2b54 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py @@ -14,12 +14,12 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import dataclasses import functools import typing as typ import jaxtyping as jtyp import numpy as np +import pydantic as pyd import sympy as sp from blender_maxwell.utils import extra_sympy_units as spux @@ -29,8 +29,7 @@ log = logger.get(__name__) # TODO: Our handling of 'is_sorted' is sloppy and probably wrong. -@dataclasses.dataclass(frozen=True, kw_only=True) -class ArrayFlow: +class ArrayFlow(pyd.BaseModel): """A homogeneous, realized array of numerical values with an optionally-attached unit and sort-tracking. While the principle is simple, arrays-with-units ends up being a powerful basis for derived and computed features/methods/processing. @@ -41,7 +40,9 @@ class ArrayFlow: None if unitless. """ - values: jtyp.Shaped[jtyp.Array, '...'] + model_config = pyd.ConfigDict(frozen=True, arbitrary_types_allowed=True) + + values: jtyp.Inexact[jtyp.Array, '...'] ## TODO: Custom field type unit: spux.Unit | None = None is_sorted: bool = False diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py index 8db2fc9..4ff917e 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py @@ -18,6 +18,7 @@ import enum import functools import typing as typ +from blender_maxwell.contracts import BLEnumElement from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import logger from blender_maxwell.utils.staticproperty import staticproperty @@ -99,6 +100,17 @@ class FlowKind(enum.StrEnum): def to_icon(_: typ.Self) -> str: return '' + @property + def name(self) -> str: + return FlowKind.to_name(self) + + @property + def icon(self) -> str: + return FlowKind.to_icon(self) + + def bl_enum_element(self, i) -> BLEnumElement: + return (str(self), self.name, self.name, self.icon, i) + #################### # - Static Properties #################### @@ -162,7 +174,7 @@ class FlowKind(enum.StrEnum): def socket_shape(self) -> str: """Return the socket shape associated with this `FlowKind`. - **ONLY** valid for `FlowKind`s that can be considered "active". + Should generally only be used with `active_kinds`. Raises: ValueError: If this `FlowKind` cannot ever be considered "active". @@ -172,7 +184,7 @@ class FlowKind(enum.StrEnum): FlowKind.Array: 'SQUARE', FlowKind.Range: 'SQUARE', FlowKind.Func: 'DIAMOND', - }[self] + }.get(self, 'CIRCLE') #################### # - Class Methods diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py index ccf798f..c8b32f0 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py @@ -14,13 +14,217 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import dataclasses +r"""Implements the core of the math system via `FuncFlow`, which allows high-performance, fully-expressive workflows with data that can be "very large", and/or whose input parameters are not yet fully known. + +# Introduction +When using nodes to do math, it becomes immediately obvious to express **flows of data as composed function chains**. +Doing so has several advantages: + +- **Interactive**: Since no large-array math is being done, the UI can be designed to feel fast and snappy, greatly boosting the will to experiment and ultimate productivity. +- **Symbolic**: Since no numerical math is being done yet, we can inject symbolic variables at-will, enabling effortless ex. parameter sweeping, band-structure generation, differentiable can choose to keep our input parameters as symbolic variables with no performance impact. +- **Performant**: Since the data pipeline is built as a single function w/o side effects, that function can be often be JIT-optimized for highly-optimized execution on the instruction sets used by modern massively-parallel devices, like modern CPUs (SSE/AVX), GPUs (ex. PTX), and HPC clusters (network-sharding to ARM/x86). + +The result is a math system optimized for the analysis typically needed in electrodynamic contexts, prioritizing clarity and flexibility at soft-real-time, even with gigabytes of data on relatively weak hardware. + +## Strongly Related FlowKinds +For doing math, `Func` relies on two other `FlowKind`s, which must run in parallel: + +- `FlowKind.Info`: Tracks the name, `spux.MathType`, unit (if any), length, and index coordinates for the raw data object produced by `Func`. +- `FlowKind.Params`: Tracks the particular values of input parameters to the lazy function, each of which can also be symbolic. + +For more, please see the documentation for each. + +## Non-Mathematical Use +Of course, there are many interesting uses of incremental function composition that aren't mathematical. + +For such cases, the usage is identical, but the complexity is lessened; for example, `Info` no longer effectively needs to flow in parallel. + + + +# Lazy Math: Theoretical Foundation +This `FlowKind` is the critical component of a functional-inspired system for lazy multilinear math. +Thus, it makes sense to describe the math system here. + +## `depth=0`: Root Function +To start a composition chain, a function with no inputs must be defined as the "root", or "bottom". + +$$ + f_0:\ \ \ \ \biggl( + \underbrace{a_1, a_2, ..., a_p}_{\texttt{args}},\ + \underbrace{ + \begin{bmatrix} k_1 \\ v_1\end{bmatrix}, + \begin{bmatrix} k_2 \\ v_2\end{bmatrix}, + ..., + \begin{bmatrix} k_q \\ v_q\end{bmatrix} + }_{\texttt{kwargs}} + \biggr) \to \text{output}_0 +$$ + +In Python, such a construction would look like this: + +```python +# Presume 'A0', 'KV0' contain only the args/kwargs for f_0 +## 'A0', 'KV0' are of length 'p' and 'q' +def f_0(*args, **kwargs): ... + +lazy_func_0 = FuncFlow( + func=f_0, + func_args=[(a_i, type(a_i)) for a_i in A0], + func_kwargs={k: v for k,v in KV0}, +) +output_0 = lazy_func.func(*A0_computed, **KV0_computed) +``` + +## `depth>0`: Composition Chaining +So far, so easy. +Now, let's add a function that uses the result of $f_0$, without yet computing it. + +$$ + f_1:\ \ \ \ \biggl( + f_0(...),\ \ + \underbrace{\{a_i\}_p^{p+r}}_{\texttt{args[p:]}},\ + \underbrace{\biggl\{ + \begin{bmatrix} k_i \\ v_i\end{bmatrix} + \biggr\}_q^{q+s}}_{\texttt{kwargs[p:]}} + \biggr) \to \text{output}_1 +$$ + +Note: +- $f_1$ must take the arguments of both $f_0$ and $f_1$. +- The complexity is getting notationally complex; we already have to use `...` to represent "the last function's arguments". + +In other words, **there's suddenly a lot to manage**. +Even worse, the bigger the $n$, the more complexity we must real with. + +This is where the Python version starts to show its purpose: + +```python +# Presume 'A1', 'K1' contain only the args/kwarg names for f_1 +## 'A1', 'KV1' are therefore of length 'r' and 's' +def f_1(output_0, *args, **kwargs): ... + +lazy_func_1 = lazy_func_0.compose_within( + enclosing_func=f_1, + enclosing_func_args=[(a_i, type(a_i)) for a_i in A1], + enclosing_func_kwargs={k: type(v) for k,v in K1}, +) + +A_computed = A0_computed + A1_computed +KW_computed = KV0_computed + KV1_computed +output_1 = lazy_func_1.func(*A_computed, **KW_computed) +``` + +By using `Func`, we've guaranteed that even hugely deep $n$s won't ever look more complicated than this. + +## `max depth`: "Realization" +So, we've composed a bunch of functions of functions of ... +We've also tracked their arguments, either manually (as above), or with the help of a handy `ParamsFlow` object. + +But it'd be pointless to just compose away forever. +We do actually need the data that they claim to compute now: + +```python +# A_all and KW_all must be tracked on the side. +output_n = lazy_func_n.func(*A_all, **KW_all) +``` + +Of course, this comes with enormous overhead. +Aside from the function calls themselves (which can be non-trivial), we must also contend with the enormous inefficiency of performing array operations sequentially. + +That brings us to the killer feature of `FuncFlow`, and the motivating reason for doing any of this at all: + +```python +output_n = lazy_func_n.func_jax(*A_all, **KW_all) +``` + +What happened was, **the entire pipeline** was compiled, optimized, and computed with bare-metal performance on either a CPU, GPU, or TPU. +With the help of the `jax` library (and its underlying OpenXLA bytecode), all of that inefficiency has been optimized based on _what we're trying to do_, not _exactly how we're doing it_, in order to maximize the use of modern massively-parallel devices. + +See the documentation of `Func.func_jax()` for more information on this process. + + + +# Lazy Math: Practical Considerations +By using nodes to express a lazily-composed chain of mathematical operations on tensor-like data, we strike a difficult balance between UX, flexibility, and performance. + +## UX +UX is often more a matter of art/taste than science, so don't trust these philosophies too much - a lot of the analysis is entirely personal and subjective. + +The goal for our UX is to minimize the "frictions" that cause cascading, small-scale _user anxiety_. + +Especially of concern in a visual math system on large data volumes is **UX latency** - also known as **lag**. +In particular, the most important facet to minimize is _emotional burden_ rather than quantitative milliseconds. +Any repeated moment-to-moment friction can be very damaging to a user's ability to be productive in a piece of software. + +Unfortunately, in a node-based architecture, data must generally be computed one step at a time, whenever any part of it is needed, and it must do so before any feedback can be provided. +In a math system like this, that data is presumed "big", and as such we're left with the unfortunate experience of even the most well-cached, high-performance operations causing _just about anything_ to **feel** like a highly unpleasant slog as soon as the data gets big enough. +**This discomfort scales with the size of data**, by the way, which might just cause users to never even attempt working with the data volume that they actually need. + +For electrodynamic field analysis, it's not uncommon for toy examples to expend hundreds of megabytes of memory, all of which needs all manner of interesting things done to it. +It can therefore be very easy to stumble across that feeling of "slogging through" any program that does real-world EM field analysis. +This has consequences: The user tries fewer ideas, becomes more easily frustrated, and might ultimately accomplish less. + +Lazy evaluation allows _delaying_ a computation to a point in time where the user both expects and understands the time that the computation takes. +For example, the user experience of pressing a button clearly marked with terminology like "load", "save", "compute", "run", seems to be paired to a greatly increased emotional tolerance towards the latency introduced by pressing that button (so long as it is only clickable when it works). +To a lesser degree, attaching a node link also seems to have this property, though that tolerance seems to fall as proficiency with the node-based tool rises. +As a more nuanced example, when lag occurs due to the computing an image-based plot based on live-computed math, then the visual feedback of _the plot actually changing_ seems to have a similar effect, not least because it's emotionally well-understood that detaching the `Viewer` node would also remove the lag. + +In short: Even if lazy evaluation didn't make any math faster, it will still _feel_ faster (to a point - raw performance obviously still matters). +Without `FuncFlow`, the point of evaluation cannot be chosen at all, which is a huge issue for all the named reasons. +With `FuncFlow`, better-chosen evaluation points can be chosen to cause the _user experience_ of high performance, simply because we were able to shift the exact same computation to a point in time where the user either understands or tolerates the delay better. + +## Flexibility +Large-scale math is done on tensors, whether one knows (or likes!) it or not. +To this end, the indexed arrays produced by `FuncFlow.func_jax` aren't quite sufficient for most operations we want to do: + +- **Naming**: What _is_ each axis? + Unnamed index axes are sometimes easy to decode, but in general, names have an unexpectedly critical function when operating on arrays. + Lack of names is a huge part of why perfectly elegant array math in ex. `MATLAB` or `numpy` can so easily feel so incredibly convoluted. + _Sometimes arrays with named axes are called "structured arrays". + +- **Coordinates**: What do the indices of each axis really _mean_? + For example, an array of $500$ by-wavelength observations of power (watts) can't be limited to between $200nm$ to $700nm$. + But they can be limited to between index `23` to `298`. + I'm **just supposed to know** that `23` means $200nm$, and that `298` indicates the observation just after $700nm$, and _hope_ that this is exact enough. + +Not only do we endeavor to track these, but we also introduce unit-awareness to the coordinates, and design the entire math system to visually communicate the state of arrays before/after every single computation, as well as only expose operations that this tracked data indicates possible. + +In practice, this happens in `FlowKind.Info`, which due to having its own `FlowKind` "lane" can be adjusted without triggering changes to (and therefore recompilation of) the `FlowKind.Func` chain. +**Please consult the `InfoFlow` documentation for more**. + +## Performance +All values introduced while processing are kept in a seperate `FlowKind` lane, with its own incremental caching: `FlowKind.Params`. + +It's a simple mechanism, but for the cost of introducing an extra `FlowKind` "lane", all of the values used to process data can be live-adjusted without the overhead of recompiling the entire `Func` every time anything changes. +Moreover, values used to process data don't even have to be numbers yet: They can be expressions of symbolic variables, complete with units, which are only realized at the very end of the chain, by the node that absolutely cannot function without the actual numerical data. + +See the `ParamFlow` documentation for more information. + + + +# Conclusion +There is, of course, a lot more to say about the math system in general. +A few teasers of what nodes can do with this system: + +**Auto-Differentiation**: `jax.jit` isn't even really the killer feature of `jax`. + `jax` can automatically differentiate `FuncFlow.func_jax` with respect to any input parameter, including for fwd/bck jacobians/hessians, with robust numerical stability. + When used in +**Symbolic Interop**: Any `sympy` expression containing symbolic variables can be compiled, by `sympy`, into a `jax`-compatible function which takes + We make use of this in the `Expr` socket, enabling true symbolic math to be used in high-performance lazy `jax` computations. +**Tidy3D Interop**: For some parameters of some simulation objects, `tidy3d` actually supports adjoint-driven differentiation _through the cloud simulation_. + This enables our humble interface to implement fully functional **inverse design** of parameterized structures, using only nodes. + +But above all, we hope that this math system is fun, practical, and maybe even interesting. +""" + import functools import typing as typ from types import MappingProxyType import jax import jaxtyping as jtyp +import pydantic as pyd +import sympy as sp from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import logger, sim_symbols @@ -35,210 +239,12 @@ log = logger.get(__name__) LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any] -@dataclasses.dataclass(frozen=True, kw_only=True) -class FuncFlow: +class FuncFlow(pyd.BaseModel): r"""Defines a flow of data as incremental function composition. + For theoretical information, please see the documentation of this module. For specific math system usage instructions, please consult the documentation of relevant nodes. - # Introduction - When using nodes to do math, it becomes immediately obvious to express **flows of data as composed function chains**. - Doing so has several advantages: - - - **Interactive**: Since no large-array math is being done, the UI can be designed to feel fast and snappy. - - **Symbolic**: Since no numerical math is being done yet, we can choose to keep our input parameters as symbolic variables with no performance impact. - - **Performant**: Since no operations are happening, the UI feels fast and snappy. - - ## Strongly Related FlowKinds - For doing math, `Func` relies on two other `FlowKind`s, which must run in parallel: - - - `FlowKind.Info`: Tracks the name, `spux.MathType`, unit (if any), length, and index coordinates for the raw data object produced by `Func`. - - `FlowKind.Params`: Tracks the particular values of input parameters to the lazy function, each of which can also be symbolic. - - For more, please see the documentation for each. - - ## Non-Mathematical Use - Of course, there are many interesting uses of incremental function composition that aren't mathematical. - - For such cases, the usage is identical, but the complexity is lessened; for example, `Info` no longer effectively needs to flow in parallel. - - - - # Lazy Math: Theoretical Foundation - This `FlowKind` is the critical component of a functional-inspired system for lazy multilinear math. - Thus, it makes sense to describe the math system here. - - ## `depth=0`: Root Function - To start a composition chain, a function with no inputs must be defined as the "root", or "bottom". - - $$ - f_0:\ \ \ \ \biggl( - \underbrace{a_1, a_2, ..., a_p}_{\texttt{args}},\ - \underbrace{ - \begin{bmatrix} k_1 \\ v_1\end{bmatrix}, - \begin{bmatrix} k_2 \\ v_2\end{bmatrix}, - ..., - \begin{bmatrix} k_q \\ v_q\end{bmatrix} - }_{\texttt{kwargs}} - \biggr) \to \text{output}_0 - $$ - - In Python, such a construction would look like this: - - ```python - # Presume 'A0', 'KV0' contain only the args/kwargs for f_0 - ## 'A0', 'KV0' are of length 'p' and 'q' - def f_0(*args, **kwargs): ... - - lazy_func_0 = FuncFlow( - func=f_0, - func_args=[(a_i, type(a_i)) for a_i in A0], - func_kwargs={k: v for k,v in KV0}, - ) - output_0 = lazy_func.func(*A0_computed, **KV0_computed) - ``` - - ## `depth>0`: Composition Chaining - So far, so easy. - Now, let's add a function that uses the result of $f_0$, without yet computing it. - - $$ - f_1:\ \ \ \ \biggl( - f_0(...),\ \ - \underbrace{\{a_i\}_p^{p+r}}_{\texttt{args[p:]}},\ - \underbrace{\biggl\{ - \begin{bmatrix} k_i \\ v_i\end{bmatrix} - \biggr\}_q^{q+s}}_{\texttt{kwargs[p:]}} - \biggr) \to \text{output}_1 - $$ - - Note: - - $f_1$ must take the arguments of both $f_0$ and $f_1$. - - The complexity is getting notationally complex; we already have to use `...` to represent "the last function's arguments". - - In other words, **there's suddenly a lot to manage**. - Even worse, the bigger the $n$, the more complexity we must real with. - - This is where the Python version starts to show its purpose: - - ```python - # Presume 'A1', 'K1' contain only the args/kwarg names for f_1 - ## 'A1', 'KV1' are therefore of length 'r' and 's' - def f_1(output_0, *args, **kwargs): ... - - lazy_func_1 = lazy_func_0.compose_within( - enclosing_func=f_1, - enclosing_func_args=[(a_i, type(a_i)) for a_i in A1], - enclosing_func_kwargs={k: type(v) for k,v in K1}, - ) - - A_computed = A0_computed + A1_computed - KW_computed = KV0_computed + KV1_computed - output_1 = lazy_func_1.func(*A_computed, **KW_computed) - ``` - - By using `Func`, we've guaranteed that even hugely deep $n$s won't ever look more complicated than this. - - ## `max depth`: "Realization" - So, we've composed a bunch of functions of functions of ... - We've also tracked their arguments, either manually (as above), or with the help of a handy `ParamsFlow` object. - - But it'd be pointless to just compose away forever. - We do actually need the data that they claim to compute now: - - ```python - # A_all and KW_all must be tracked on the side. - output_n = lazy_func_n.func(*A_all, **KW_all) - ``` - - Of course, this comes with enormous overhead. - Aside from the function calls themselves (which can be non-trivial), we must also contend with the enormous inefficiency of performing array operations sequentially. - - That brings us to the killer feature of `FuncFlow`, and the motivating reason for doing any of this at all: - - ```python - output_n = lazy_func_n.func_jax(*A_all, **KW_all) - ``` - - What happened was, **the entire pipeline** was compiled, optimized, and computed with bare-metal performance on either a CPU, GPU, or TPU. - With the help of the `jax` library (and its underlying OpenXLA bytecode), all of that inefficiency has been optimized based on _what we're trying to do_, not _exactly how we're doing it_, in order to maximize the use of modern massively-parallel devices. - - See the documentation of `Func.func_jax()` for more information on this process. - - - - # Lazy Math: Practical Considerations - By using nodes to express a lazily-composed chain of mathematical operations on tensor-like data, we strike a difficult balance between UX, flexibility, and performance. - - ## UX - UX is often more a matter of art/taste than science, so don't trust these philosophies too much - a lot of the analysis is entirely personal and subjective. - - The goal for our UX is to minimize the "frictions" that cause cascading, small-scale _user anxiety_. - - Especially of concern in a visual math system on large data volumes is **UX latency** - also known as **lag**. - In particular, the most important facet to minimize is _emotional burden_ rather than quantitative milliseconds. - Any repeated moment-to-moment friction can be very damaging to a user's ability to be productive in a piece of software. - - Unfortunately, in a node-based architecture, data must generally be computed one step at a time, whenever any part of it is needed, and it must do so before any feedback can be provided. - In a math system like this, that data is presumed "big", and as such we're left with the unfortunate experience of even the most well-cached, high-performance operations causing _just about anything_ to **feel** like a highly unpleasant slog as soon as the data gets big enough. - **This discomfort scales with the size of data**, by the way, which might just cause users to never even attempt working with the data volume that they actually need. - - For electrodynamic field analysis, it's not uncommon for toy examples to expend hundreds of megabytes of memory, all of which needs all manner of interesting things done to it. - It can therefore be very easy to stumble across that feeling of "slogging through" any program that does real-world EM field analysis. - This has consequences: The user tries fewer ideas, becomes more easily frustrated, and might ultimately accomplish less. - - Lazy evaluation allows _delaying_ a computation to a point in time where the user both expects and understands the time that the computation takes. - For example, the user experience of pressing a button clearly marked with terminology like "load", "save", "compute", "run", seems to be paired to a greatly increased emotional tolerance towards the latency introduced by pressing that button (so long as it is only clickable when it works). - To a lesser degree, attaching a node link also seems to have this property, though that tolerance seems to fall as proficiency with the node-based tool rises. - As a more nuanced example, when lag occurs due to the computing an image-based plot based on live-computed math, then the visual feedback of _the plot actually changing_ seems to have a similar effect, not least because it's emotionally well-understood that detaching the `Viewer` node would also remove the lag. - - In short: Even if lazy evaluation didn't make any math faster, it will still _feel_ faster (to a point - raw performance obviously still matters). - Without `FuncFlow`, the point of evaluation cannot be chosen at all, which is a huge issue for all the named reasons. - With `FuncFlow`, better-chosen evaluation points can be chosen to cause the _user experience_ of high performance, simply because we were able to shift the exact same computation to a point in time where the user either understands or tolerates the delay better. - - ## Flexibility - Large-scale math is done on tensors, whether one knows (or likes!) it or not. - To this end, the indexed arrays produced by `FuncFlow.func_jax` aren't quite sufficient for most operations we want to do: - - - **Naming**: What _is_ each axis? - Unnamed index axes are sometimes easy to decode, but in general, names have an unexpectedly critical function when operating on arrays. - Lack of names is a huge part of why perfectly elegant array math in ex. `MATLAB` or `numpy` can so easily feel so incredibly convoluted. - _Sometimes arrays with named axes are called "structured arrays". - - - **Coordinates**: What do the indices of each axis really _mean_? - For example, an array of $500$ by-wavelength observations of power (watts) can't be limited to between $200nm$ to $700nm$. - But they can be limited to between index `23` to `298`. - I'm **just supposed to know** that `23` means $200nm$, and that `298` indicates the observation just after $700nm$, and _hope_ that this is exact enough. - - Not only do we endeavor to track these, but we also introduce unit-awareness to the coordinates, and design the entire math system to visually communicate the state of arrays before/after every single computation, as well as only expose operations that this tracked data indicates possible. - - In practice, this happens in `FlowKind.Info`, which due to having its own `FlowKind` "lane" can be adjusted without triggering changes to (and therefore recompilation of) the `FlowKind.Func` chain. - **Please consult the `InfoFlow` documentation for more**. - - ## Performance - All values introduced while processing are kept in a seperate `FlowKind` lane, with its own incremental caching: `FlowKind.Params`. - - It's a simple mechanism, but for the cost of introducing an extra `FlowKind` "lane", all of the values used to process data can be live-adjusted without the overhead of recompiling the entire `Func` every time anything changes. - Moreover, values used to process data don't even have to be numbers yet: They can be expressions of symbolic variables, complete with units, which are only realized at the very end of the chain, by the node that absolutely cannot function without the actual numerical data. - - See the `ParamFlow` documentation for more information. - - - - # Conclusion - There is, of course, a lot more to say about the math system in general. - A few teasers of what nodes can do with this system: - - **Auto-Differentiation**: `jax.jit` isn't even really the killer feature of `jax`. - `jax` can automatically differentiate `FuncFlow.func_jax` with respect to any input parameter, including for fwd/bck jacobians/hessians, with robust numerical stability. - When used in - **Symbolic Interop**: Any `sympy` expression containing symbolic variables can be compiled, by `sympy`, into a `jax`-compatible function which takes - We make use of this in the `Expr` socket, enabling true symbolic math to be used in high-performance lazy `jax` computations. - **Tidy3D Interop**: For some parameters of some simulation objects, `tidy3d` actually supports adjoint-driven differentiation _through the cloud simulation_. - This enables our humble interface to implement fully functional **inverse design** of parameterized structures, using only nodes. - - But above all, we hope that this math system is fun, practical, and maybe even interesting. - Attributes: func: The function that generates the represented value. func_args: The constrained identity of all positional arguments to the function. @@ -247,14 +253,16 @@ class FuncFlow: See the documentation of `self.func_jax()`. """ + model_config = pyd.ConfigDict(frozen=True) + func: LazyFunction - func_args: list[sim_symbols.SimSymbol] = dataclasses.field(default_factory=list) - func_kwargs: dict[str, sim_symbols.SimSymbol] = dataclasses.field( - default_factory=dict - ) + func_args: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list) + func_kwargs: dict[str, sim_symbols.SimSymbol] = pyd.Field(default_factory=dict) + func_output: sim_symbols.SimSymbol | None = None + supports_jax: bool = False - concatenated: bool = False + is_concatenated: bool = False #################### # - Functions @@ -318,6 +326,7 @@ class FuncFlow: {} ), ) -> typ.Self: + """Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols.""" if self.supports_jax: return self.func_jax( *params.scaled_func_args(symbol_values), @@ -371,14 +380,55 @@ class FuncFlow: return data | {info.output: self.realize(params, symbol_values=symbol_values)} + def realize_partial( + self, params: ParamsFlow + ) -> typ.Callable[ + [int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...], + jtyp.Inexact[jtyp.Array, '...'], + ]: + """Create a purely-numerical function, which takes only numerical. + + The units/types/shape/etc. of the returned numerical type conforms to the `SimSymbol` specification of relevant `self.func_args` entries and `self.func_output`. + + This function should be used whenever the unrealized result of a `FuncFlow` needs to be used as the argument to another `FuncFlow`. + By using `realize_partial()`, two things are ensured: + + - Since the function defined in `.compose_within()` must be purely numerical, the usual `.realize()` mechanism can't be used to sweep away the pre-realized symbol values. + - Since this `FuncFlow` is completely consumed, with no symbols / arguments / etc. explicitly surviving, its impact on the data flow can be considered to have been effectively terminated after using this function. + + Notes: + Be **very careful about units**. + Ideally, the bottom function should use `.scale_to_unit()` before invoking `.compose_within()` with the output of this function. + """ + pre_realized_syms = list( + params.realize_symbols(params.realized_symbols, allow_partial=True).values() + ) + + def realizer( + *sym_args: int | float | complex | jtyp.Inexact[jtyp.Array, '...'], + ) -> jtyp.Inexact[jtyp.Array, '...']: + return self.func( + *[ + func_arg_n(*sym_args, *pre_realized_syms) + for func_arg_n in params.func_args_n + ], + **{ + func_arg_name: func_kwarg_n(*sym_args, *pre_realized_syms) + for func_arg_name, func_kwarg_n in params.func_kwargs_n.items() + }, + ) + + return realizer + #################### - # - Composition Operations + # - Operations #################### def compose_within( self, enclosing_func: LazyFunction, - enclosing_func_args: list[type] = (), - enclosing_func_kwargs: dict[str, type] = MappingProxyType({}), + enclosing_func_args: list[sim_symbols.SimSymbol] = (), + enclosing_func_kwargs: dict[str, sim_symbols.SimSymbol] = MappingProxyType({}), + enclosing_func_output: sim_symbols.SimSymbol | None = None, supports_jax: bool = False, ) -> typ.Self: """Compose `self.func` within the given enclosing function, which itself takes arguments, and create a new `FuncFlow` to contain it. @@ -415,6 +465,10 @@ class FuncFlow: Returns: A lazy function that takes both the enclosed and enclosing arguments, and returns the value of the enclosing function (whose first argument is the output value of the enclosed function). """ + ## TODO: Support unit system conversion at the point of composition. + ## -- This may require us to track the units of the function output. + ## TODO: Support JAX-evaluation when jax support changes from True to False. + ## -- This would allow big data flows to compose performantly as arguments into non-JAX functions. return FuncFlow( func=lambda *args, **kwargs: enclosing_func( self.func( @@ -426,6 +480,7 @@ class FuncFlow: ), func_args=self.func_args + list(enclosing_func_args), func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs), + func_output=enclosing_func_output, supports_jax=self.supports_jax and supports_jax, ) @@ -472,7 +527,7 @@ class FuncFlow: *list(args[: len(self.func_args)]), **{k: v for k, v in kwargs.items() if k in self.func_kwargs}, ) - if not self.concatenated: + if not self.is_concatenated: return (ret,) return ret @@ -487,5 +542,83 @@ class FuncFlow: func_args=self.func_args + other.func_args, func_kwargs=self.func_kwargs | other.func_kwargs, supports_jax=self.supports_jax and other.supports_jax, - concatenated=True, + is_concatenated=True, ) + + def scale_to_unit(self, unit: spux.Unit | None = None) -> typ.Self: + """Encloses this function in a unit-converting function, whose output is a converted, unitless scalar. + + `unit` must be manually guaranteed to be compatible with `self.unit`. + """ + if self.func_output is not None: + # Retrieve Output Unit + output_unit = self.func_output.unit + + # Compile Efficient Unit-Conversion Function + a = self.func_output.mathtype.sp_symbol_a + unit_convert_expr = ( + spux.scale_to_unit(a * output_unit, unit) + if self.func_output.unit is not None + else a + ) + unit_convert_func = sp.lambdify(a, unit_convert_expr.n(), 'jax') + + # Compose Unit-Converted FuncFlow + return self.compose_within( + enclosing_func=unit_convert_func, + supports_jax=True, + enclosing_func_output=self.func_output.update(unit=unit), + ) + + msg = f'Tried to scale a FuncFlow to a unit system, but it has no tracked output SimSymbol. ({self})' + raise ValueError(msg) + + def scale_to_unit_system( + self, unit_system: spux.UnitSystem | None = None + ) -> typ.Self: + """Encloses this function in a unit-converting function, whose output is a converted, unitless scalar. + + Using `self.output_symbol`, which tracks the units of the output, we can determine a scaling factor to multiply the (numerical) function output by in order to conform it to the given unit system. + + In general, **don't use this**. + Any superfluous numerical operations in a data pipeline can enhance instabilities and interfere with JIT-optimization (floating-point arithmetic isn't commutative, for example). + However, occasionally, we need to "intercept" a lazy data flow, for example when realizing a `FlowKind.Value` that doesn't understand symbols or units - but which only accepts a float/complex scalar/array with pre-determined unit convention. + + For this purpose alone, this method is provided to pre-scale a `FuncFlow`, just before using `realize()` / `__or__` and then `realize()`. + **To encourage proper usage** (and ease implementation), the output unit in `self.func_output` of the output will be reset to `None` - indicating that the output can only be handled as a unitless scalar w/semantic meaning tracked elsewhere. + + Notes: + **ONLY** use with output types that support meaningful arbitrary multiplication. + + A scale-only sympy expression will be used to produce an optimized JAX function of a single variable, which will then be composed onto the existing `FuncFlow`. + + Parameters: + unit_system: The unit system to conform the function output to. + + Returns: + A new `FuncFlow` that conforms to the new unit, but is itself now considered unitless. + """ + if self.func_output is not None: + # Retrieve Output Unit + output_unit = self.func_output.unit + + # Compile Efficient Unit-Conversion Function + a = self.func_output.mathtype.sp_symbol_a + unit_convert_expr = ( + spux.strip_unit_system( + spux.convert_to_unit_system(a * output_unit, unit_system) + ) + if self.func_output.unit is not None + else a + ) + unit_convert_func = sp.lambdify(a, unit_convert_expr.n(), 'jax') + + # Compose Unit-Converted FuncFlow + return self.compose_within( + enclosing_func=unit_convert_func, + supports_jax=True, + enclosing_func_output=self.func_output.update(unit=None), + ) + + msg = f'Tried to scale a FuncFlow to a unit system, but it has no tracked output SimSymbol. ({self})' + raise ValueError(msg) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py index 55cc3b5..a204eef 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py @@ -14,17 +14,15 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import dataclasses import enum import functools import typing as typ -from fractions import Fraction from types import MappingProxyType import jax.numpy as jnp import jaxtyping as jtyp +import pydantic as pyd import sympy as sp -import sympy.physics.units as spu from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import logger, sim_symbols @@ -61,8 +59,7 @@ class ScalingMode(enum.StrEnum): return '' -@dataclasses.dataclass(frozen=True, kw_only=True) -class RangeFlow: +class RangeFlow(pyd.BaseModel): r"""Represents a finite spaced array using symbolic boundary expressions. Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous. @@ -92,8 +89,10 @@ class RangeFlow: symbols: Set of variables from which `start` and/or `stop` are determined. """ - start: spux.ScalarUnitlessComplexExpr - stop: spux.ScalarUnitlessComplexExpr + model_config = pyd.ConfigDict(frozen=True) + + start: spux.ScalarUnitlessRealExpr + stop: spux.ScalarUnitlessRealExpr steps: int = 0 scaling: ScalingMode = ScalingMode.Lin @@ -102,7 +101,7 @@ class RangeFlow: symbols: frozenset[sim_symbols.SimSymbol] = frozenset() # Helper Attributes - pre_fourier_ideal_midpoint: spux.ScalarUnitlessComplexExpr | None = None + pre_fourier_ideal_midpoint: spux.ScalarUnitlessRealExpr | None = None #################### # - SimSymbol Interop @@ -218,14 +217,26 @@ class RangeFlow: ) return combined_mathtype - @property + @functools.cached_property def ideal_midpoint(self) -> spux.SympyExpr: return (self.stop + self.start) / 2 - @property + @functools.cached_property def ideal_range(self) -> spux.SympyExpr: return self.stop - self.start + @functools.cached_property + def ideal_step_size(self) -> spux.SympyExpr: + return self.ideal_range / (self.steps - 1) + + @functools.cached_property + def is_always_nonzero(self) -> spux.SympyExpr: + if self.start > 0 or self.stop < 0: + return True + + is_zero = (self.start % self.ideal_step_size).is_zero + return is_zero if is_zero is not None else False + #################### # - Methods #################### @@ -452,7 +463,7 @@ class RangeFlow: symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType( {} ), - ) -> dict[sp.Symbol, spux.ScalarUnitlessComplexExpr]: + ) -> dict[sp.Symbol, spux.ScalarUnitlessRealExpr]: """Realize **all** input symbols to the `RangeFlow`. Parameters: @@ -480,7 +491,7 @@ class RangeFlow: raise NotImplementedError(msg) realized_syms |= {sym: v} - + return realized_syms msg = f'RangeFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})' raise ValueError(msg) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py index fb445a6..20d6562 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py @@ -14,13 +14,13 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import dataclasses import functools import typing as typ from fractions import Fraction from types import MappingProxyType import jaxtyping as jtyp +import pydantic as pyd import sympy as sp from blender_maxwell.utils import extra_sympy_units as spux @@ -28,31 +28,35 @@ from blender_maxwell.utils import logger, sim_symbols from .array import ArrayFlow from .expr_info import ExprInfo -from .flow_kinds import FlowKind from .lazy_range import RangeFlow log = logger.get(__name__) -@dataclasses.dataclass(frozen=True, kw_only=True) -class ParamsFlow: +class ParamsFlow(pyd.BaseModel): """Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name. Returns: All symbols valid for use in the expression. """ - arg_targets: list[sim_symbols.SimSymbol] = dataclasses.field(default_factory=list) - kwarg_targets: list[str, sim_symbols.SimSymbol] = dataclasses.field( - default_factory=dict - ) + model_config = pyd.ConfigDict(frozen=True) - func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list) - func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict) + arg_targets: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list) + kwarg_targets: dict[str, sim_symbols.SimSymbol] = pyd.Field(default_factory=dict) + + func_args: list[spux.SympyExpr] = pyd.Field(default_factory=list) + func_kwargs: dict[str, spux.SympyExpr] = pyd.Field(default_factory=dict) symbols: frozenset[sim_symbols.SimSymbol] = frozenset() + realized_symbols: dict[ + sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow + ] = pyd.Field(default_factory=dict) - is_differentiable: bool = False + @functools.cached_property + def diff_symbols(self) -> set[sim_symbols.SimSymbol]: + """Set of all unrealized `SimSymbol`s that can act as inputs when differentiating the function for which this `ParamsFlow` tracks arguments.""" + return {sym for sym in self.symbols if sym.can_diff} #################### # - Symbols @@ -78,6 +82,27 @@ class ParamsFlow: """ return [sym.sp_symbol_matsym for sym in self.sorted_symbols] + @functools.cached_property + def all_sorted_symbols(self) -> list[sim_symbols.SimSymbol]: + """Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name. + + Returns: + All symbols valid for use in the expression. + """ + key_func = lambda sym: sym.name # noqa: E731 + return sorted(self.symbols, key=key_func) + sorted( + self.realized_symbols.keys(), key=key_func + ) + + @functools.cached_property + def all_sorted_sp_symbols(self) -> list[sim_symbols.SimSymbol]: + """Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name. + + Returns: + All symbols valid for use in the expression. + """ + return [sym.sp_symbol_matsym for sym in self.all_sorted_symbols] + #################### # - JIT'ed Callables for Numerical Function Arguments #################### @@ -101,7 +126,7 @@ class ParamsFlow: """ return [ sp.lambdify( - self.sorted_sp_symbols, + self.all_sorted_sp_symbols, target_sym.conform(func_arg, strip_unit=True), 'jax', ) @@ -127,7 +152,7 @@ class ParamsFlow: """ return { key: sp.lambdify( - self.sorted_sp_symbols, + self.all_sorted_sp_symbols, self.kwarg_targets[key].conform(func_arg, strip_unit=True), 'jax', ) @@ -142,8 +167,9 @@ class ParamsFlow: symbol_values: dict[ sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow ] = MappingProxyType({}), + allow_partial: bool = False, ) -> dict[ - sp.Symbol, + sim_symbols.SimSymbol, int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...'] :, ]: """Fully realize all symbols by assigning them a value. @@ -160,10 +186,12 @@ class ParamsFlow: Returns: A dictionary almost with `.subs()`, other than `jax` arrays. """ - if set(self.symbols) == set(symbol_values.keys()): + if allow_partial or set(self.all_sorted_symbols) == set(symbol_values.keys()): realized_syms = {} - for sym in self.sorted_symbols: - sym_value = symbol_values[sym] + for sym in self.all_sorted_symbols: + sym_value = symbol_values.get(sym) + if sym_value is None and allow_partial: + continue if isinstance(sym_value, spux.SympyType): v = sym.scale(sym_value) @@ -214,7 +242,9 @@ class ParamsFlow: Parameters: symbol_values: Particular values for all symbols in `self.symbols`, which will be conformed and used to compute the function arguments (before they are conformed to `self.target_syms`). """ - realized_symbols = list(self.realize_symbols(symbol_values).values()) + realized_symbols = list( + self.realize_symbols(symbol_values | self.realized_symbols).values() + ) return [func_arg_n(*realized_symbols) for func_arg_n in self.func_args_n] def scaled_func_kwargs( @@ -227,10 +257,11 @@ class ParamsFlow: Other than the `dict[str, ...]` key, the semantics are identical to `self.scaled_func_args()`. """ - realized_symbols = self.realize_symbols(symbol_values) + realized_symbols = self.realize_symbols(symbol_values | self.realized_symbols) + return { - func_arg_name: func_arg_n(**realized_symbols) - for func_arg_name, func_arg_n in self.func_kwargs_n.items() + func_arg_name: func_kwarg_n(**realized_symbols) + for func_arg_name, func_kwarg_n in self.func_kwargs_n.items() } #################### @@ -251,7 +282,7 @@ class ParamsFlow: func_args=self.func_args + other.func_args, func_kwargs=self.func_kwargs | other.func_kwargs, symbols=self.symbols | other.symbols, - is_differentiable=self.is_differentiable and other.is_differentiable, + realized_symbols=self.realized_symbols | other.realized_symbols, ) def compose_within( @@ -261,7 +292,6 @@ class ParamsFlow: enclosing_func_args: list[spux.SympyExpr] = (), enclosing_func_kwargs: dict[str, spux.SympyExpr] = MappingProxyType({}), enclosing_symbols: frozenset[sim_symbols.SimSymbol] = frozenset(), - enclosing_is_differentiable: bool = False, ) -> typ.Self: return ParamsFlow( arg_targets=self.arg_targets + list(enclosing_arg_targets), @@ -269,13 +299,41 @@ class ParamsFlow: func_args=self.func_args + list(enclosing_func_args), func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs), symbols=self.symbols | enclosing_symbols, - is_differentiable=( - self.is_differentiable - if not enclosing_symbols - else (self.is_differentiable & enclosing_is_differentiable) - ), + realized_symbols=self.realized_symbols, ) + def realize_partial( + self, + symbol_values: dict[ + sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow + ], + ) -> typ.Self: + """Provide a particular expression/range/array to realize some symbols. + + Essentially removes symbols from `self.symbols`, and adds the symbol w/value to `self.realized_symbols`. + As a result, only the still-unrealized symbols need to be passed at the time of realization (using ex. `self.scaled_func_args()`). + + Parameters: + symbol_values: The value to realize for each `SimSymbol`. + **All keys** must be identically matched to a single element of `self.symbol`. + Can be empty, in which case an identical new `ParamsFlow` will be returned. + + Raises: + ValueError: If any symbol in `symbol_values` + """ + syms = set(symbol_values.keys()) + if syms.issubset(self.symbols) or not syms: + return ParamsFlow( + arg_targets=self.arg_targets, + kwarg_targets=self.kwarg_targets, + func_args=self.func_args, + func_kwargs=self.func_kwargs, + symbols=self.symbols - syms, + realized_symbols=self.realized_symbols | symbol_values, + ) + msg = f'ParamsFlow: Not all partially realized symbols are defined on the ParamsFlow (symbols={self.symbols}, symbol_values={symbol_values})' + raise ValueError(msg) + #################### # - Generate ExprSocketDef #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py index 3fbb87b..f796a63 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py @@ -16,7 +16,7 @@ """Declares `ManagedBLImage`.""" -# import time +import time import typing as typ import bpy @@ -261,7 +261,7 @@ class ManagedBLImage(base.ManagedObj): dpi: int | None = None, bl_select: bool = False, ): - # times = [time.perf_counter()] + times = ['START', time.perf_counter()] # Compute Plot Dimensions # aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = ( @@ -277,22 +277,22 @@ class ManagedBLImage(base.ManagedObj): # _width_inches, _height_inches, _dpi # ) fig, canvas, ax = image_ops.mpl_fig_canvas_ax(width_inches, height_inches, dpi) - # times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]]) + times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]]) # fig.clear() ax.clear() - # times.append(['Clear Axis', time.perf_counter() - times[0]]) + times.append(['Clear Axis', time.perf_counter() - times[0]]) # Plot w/User Parameter func_plotter(ax) - # times.append(['Plot!', time.perf_counter() - times[0]]) + times.append(['Plot!', time.perf_counter() - times[0]]) # Save Figure to BytesIO canvas.draw() - # times.append(['Draw Pixels', time.perf_counter() - times[0]]) + times.append(['Draw Pixels', time.perf_counter() - times[0]]) canvas_width_px, canvas_height_px = fig.canvas.get_width_height() - # times.append(['Get Canvas Dims', time.perf_counter() - times[0]]) + times.append(['Get Canvas Dims', time.perf_counter() - times[0]]) image_data = ( np.float32( np.flipud( @@ -303,7 +303,7 @@ class ManagedBLImage(base.ManagedObj): ) / 255 ) - # times.append(['Load Data from Canvas', time.perf_counter() - times[0]]) + times.append(['Load Data from Canvas', time.perf_counter() - times[0]]) # Optimized Write to Blender Image bl_image = self.bl_image(canvas_width_px, canvas_height_px, 'RGBA', 'uint8') diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/node_tree.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/node_tree.py index 6243641..af9327c 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/node_tree.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/node_tree.py @@ -15,6 +15,8 @@ # along with this program. If not, see . import contextlib +import functools +import queue import typing as typ import bpy @@ -26,6 +28,14 @@ from .managed_objs.managed_bl_image import ManagedBLImage log = logger.get(__name__) +link_action_queue = queue.Queue() + + +def set_link_validity(link: bpy.types.NodeLink, validity: bool) -> None: + log.critical('Set %s validity to %s', str(link), str(validity)) + link.is_valid = validity + + #################### # - Cache Management #################### @@ -45,58 +55,93 @@ class DeltaNodeLinkCache(typ.TypedDict): class NodeLinkCache: - """A pointer-based cache of node links in a node tree. + """A volatile pointer-based cache of node links in a node tree. + + Warnings: + Everything here is **extremely** unsafe. + Even a single mistake **will** cause a use-after-free crash of Blender. + + Used perfectly, it allows for powerful features; anything less, and it's an epic liability. Attributes: - _node_tree: Reference to the owning node tree. - link_ptrs_as_links: - link_ptrs: Pointers (as in integer memory adresses) to `NodeLink`s. - link_ptrs_as_links: Map from pointers to actual `NodeLink`s. - link_ptrs_from_sockets: Map from pointers to `NodeSocket`s, representing the source of each `NodeLink`. - link_ptrs_from_sockets: Map from pointers to `NodeSocket`s, representing the destination of each `NodeLink`. + _node_tree: Reference to the node tree for which this cache is valid. + link_ptrs: Memory-address identifiers for all node links that currently exist in `_node_tree`. + link_ptrs_as_links: Mapping from pointers (integers) to actual `NodeLink` objects. + **WARNING**: If the pointer-referenced object no longer exists, then Blender **will crash immediately** upon attempting to use it. There is no way to mitigate this. + socket_ptrs: Memory-address identifiers for all sockets that currently exist in `_node_tree`. + socket_ptrs_as_sockets: Mapping from pointers (integers) to actual `NodeSocket` objects. + **WARNING**: If the pointer-referenced object no longer exists, then Blender **will crash immediately** upon attempting to use it. There is no way to mitigate this. + socket_ptr_refcount: The amount of links currently connected to a given socket pointer. + Used to drive the deletion of socket pointers using only knowledge about `link_ptr` removal. + link_ptrs_as_from_socket_ptrs: The pointer of the source socket, defined for every node link pointer. + link_ptrs_as_to_socket_ptrs: The pointer of the destination socket, defined for every node link pointer. """ def __init__(self, node_tree: bpy.types.NodeTree): - """Initialize the cache from a node tree. - - Parameters: - node_tree: The Blender node tree whose `NodeLink`s will be cached. - """ + """Defines and fills the cache from a live node tree.""" self._node_tree = node_tree - # Link PTR and PTR->REF self.link_ptrs: set[MemAddr] = set() self.link_ptrs_as_links: dict[MemAddr, bpy.types.NodeLink] = {} - # Socket PTR and PTR->REF self.socket_ptrs: set[MemAddr] = set() self.socket_ptrs_as_sockets: dict[MemAddr, bpy.types.NodeSocket] = {} self.socket_ptr_refcount: dict[MemAddr, int] = {} - # Link PTR -> Socket PTR self.link_ptrs_as_from_socket_ptrs: dict[MemAddr, MemAddr] = {} self.link_ptrs_as_to_socket_ptrs: dict[MemAddr, MemAddr] = {} + self.link_ptrs_invalid: set[MemAddr] = set() + # Fill Cache self.regenerate() def remove_link(self, link_ptr: MemAddr) -> None: - """Removes a link pointer from the cache, indicating that the link doesn't exist anymore. + """Reports a link as removed, causing it to be removed from the cache. + + This **must** be run whenever a node link is deleted. + **Failure to to so WILL result in segmentation fault** at an unknown future time. + + In particular, the following actions are taken: + - The entry in `self.link_ptrs_as_links` is deleted. + - Any entry in `self.link_ptrs_invalid` is deleted (if exists). Notes: - - **DOES NOT** remove PTR->REF dictionary entries - - Invoking this method directly causes the removed node links to not be reported as "removed" by `NodeLinkCache.regenerate()`. - - This **must** be done whenever a node link is deleted. - - Failure to do so may result in a segmentation fault at arbitrary future time. + Invoking this method directly causes the removed node links to not be reported as "removed" by `NodeLinkCache.regenerate()`. + In some cases, this may be desirable, ex. for internal methods that shouldn't trip a `DataChanged` flow event. Parameters: - link_ptr: Pointer to remove from the cache. + link_ptr: The pointer (integer) to remove from the cache. + + Raises: + KeyError: If `link_ptr` is not a member of either `self.link_ptrs`, or of `self.link_ptrs_as_links`. """ self.link_ptrs.remove(link_ptr) self.link_ptrs_as_links.pop(link_ptr) + if link_ptr in self.link_ptrs_invalid: + self.link_ptrs_invalid.remove(link_ptr) + def remove_sockets_by_link_ptr(self, link_ptr: MemAddr) -> None: - """Removes a single pointer's reference to its from/to sockets.""" + """Deassociate from all sockets referenced by a link, respecting the socket pointer reference-count. + + The `NodeLinkCache` stores references to all socket pointers referenced by any link. + Since several links can be associated with each socket, we must keep a "reference count" per-socket. + When the "reference count" drops to zero, then there are no longer any `NodeLink`s that refer to it, and therefore it should be removed from the `NodeLinkCache`. + + This method facilitates that process by: + - Extracting (with removal) the from / to socket pointers associated with `link_ptr`. + - If the socket pointer has a reference count of `1`, then it is **completely removed**. + - If the socket pointer has a reference count of `>1`, then the reference count is decremented by `1`. + + Notes: + In general, this should be called together with `remove_link`. + However, in certain cases, this process also needs to happen by itself. + + Parameters: + link_ptr: The pointer (integer) to remove from the cache. + """ + # Remove Socket Pointers from_socket_ptr = self.link_ptrs_as_from_socket_ptrs.pop(link_ptr, None) to_socket_ptr = self.link_ptrs_as_to_socket_ptrs.pop(link_ptr, None) @@ -113,31 +158,40 @@ class NodeLinkCache: self.socket_ptr_refcount[socket_ptr] -= 1 def regenerate(self) -> DeltaNodeLinkCache: - """Regenerates the cache from the internally-linked node tree. + """Efficiently scans the internally referenced node tree to thoroughly update all attributes of this `NodeLinkCache`. Notes: - - This is designed to run within the `update()` invocation of the node tree. - - This should be a very fast function, since it is called so much. + This runs in a **very** hot loop, within the `update()` function of the node tree. + Anytime anything happens in the node tree, `update()` (and therefore this method) is called. + + Thus, performance is of the utmost importance. + Just a few microseconds too much may be amplified dozens of times over in practice, causing big stutters. """ # Compute All NodeLink Pointers + ## -> It can be very inefficient to do any full-scan of the node tree. + ## -> However, simply extracting the pointer: link ends up being fast. + ## -> This pattern seems to be the best we can do, efficiency-wise. all_link_ptrs_as_links = { link.as_pointer(): link for link in self._node_tree.links } all_link_ptrs = set(all_link_ptrs_as_links.keys()) # Compute Added/Removed Links + ## -> In essence, we've created a 'diff' here. + ## -> Set operations are fast, and expressive! added_link_ptrs = all_link_ptrs - self.link_ptrs removed_link_ptrs = self.link_ptrs - all_link_ptrs # Edge Case: 'from_socket' Reassignment - ## (Reverse engineered) When all: - ## - Created a new link between the same two nodes. - ## - Matching 'to_socket'. - ## - Non-matching 'from_socket' on the same node. - ## -> THEN the link_ptr will not change, but the from_socket ptr should. - if len(added_link_ptrs) == 0 and len(removed_link_ptrs) == 0: + ## (Reverse Engineered) When all are true: + ## - Created a new link between the same nodes as previous link. + ## - Matching 'to_socket' as the previous link. + ## - Non-matching 'from_socket', but on the same node. + ## -> THEN the link_ptr will not change, but the from_socket ptr does. + if not added_link_ptrs and not removed_link_ptrs: # Find the Link w/Reassigned 'from_socket' PTR - ## A bit of a performance hit from the search, but it's an edge case. + ## -> This isn't very fast, but the edge case isn't so common. + ## -> Comprehensions are still quite optimized. _link_ptr_as_from_socket_ptrs = { link_ptr: ( from_socket_ptr, @@ -149,9 +203,9 @@ class NodeLinkCache: } # Completely Remove the Old Link (w/Reassigned 'from_socket') - ## This effectively reclassifies the edge case as a normal 're-add'. + ## -> Casts the edge case to look like a typical 're-add'. for link_ptr in _link_ptr_as_from_socket_ptrs: - log.info( + log.debug( 'Edge-Case - "from_socket" Reassigned in NodeLink w/o New NodeLink Pointer: %s', link_ptr, ) @@ -159,21 +213,25 @@ class NodeLinkCache: self.remove_sockets_by_link_ptr(link_ptr) # Recompute Added/Removed Links - ## The algorithm will now detect an "added link". + ## -> Guide the usual algorithm to detect an "added link". added_link_ptrs = all_link_ptrs - self.link_ptrs removed_link_ptrs = self.link_ptrs - all_link_ptrs - # Shuffle Cache based on Change in Links - ## Remove Entries for Removed Pointers + # Delete Removed Links + ## -> NOTE: We leave dangling socket information on purpose. + ## -> This information will be used to ask for 'removal consent'. + ## -> To truly remove, must call 'remove_socket_by_link_ptr' later. for removed_link_ptr in removed_link_ptrs: self.remove_link(removed_link_ptr) - ## User must manually call 'remove_socket_by_link_ptr' later. - ## For now, leave dangling socket information by-link. - # Add New Link Pointers + # Create Added Links + ## -> First, simply concatenate the added link pointers. self.link_ptrs |= added_link_ptrs for link_ptr in added_link_ptrs: - # Add Link PTR->REF + # Create Pointer -> Reference Entry + ## -> This allows us to efficiently access the link by-pointer. + ## -> Doing so otherwise requires a full search. + ## -> **If link is deleted w/o report, access will cause crash**. new_link = all_link_ptrs_as_links[link_ptr] self.link_ptrs_as_links[link_ptr] = new_link @@ -183,34 +241,69 @@ class NodeLinkCache: to_socket = new_link.to_socket to_socket_ptr = to_socket.as_pointer() - # Add Socket PTR, PTR -> REF + # Add Socket Information for socket_ptr, bl_socket in zip( # noqa: B905 [from_socket_ptr, to_socket_ptr], [from_socket, to_socket], ): - # Increment RefCount of Socket PTR + # RefCount > 0: Increment RefCount of Socket PTR ## This happens if another link also uses the same socket. ## 1. An output socket links to several inputs. ## 2. A multi-input socket links from several inputs. if socket_ptr in self.socket_ptr_refcount: self.socket_ptr_refcount[socket_ptr] += 1 + + # RefCount == 0: Create Socket Pointer w/Reference + ## -> Also initialize the refcount for the socket pointer. else: - ## RefCount == 0: Add PTR, PTR -> REF self.socket_ptrs.add(socket_ptr) self.socket_ptrs_as_sockets[socket_ptr] = bl_socket self.socket_ptr_refcount[socket_ptr] = 1 - # Add Link PTR -> Socket PTR + # Add Entry from Link Pointer -> Socket Pointer self.link_ptrs_as_from_socket_ptrs[link_ptr] = from_socket_ptr self.link_ptrs_as_to_socket_ptrs[link_ptr] = to_socket_ptr return {'added': added_link_ptrs, 'removed': removed_link_ptrs} + def update_validity(self) -> DeltaNodeLinkCache: + """Query all cached links to determine whether they are valid.""" + self.link_ptrs_invalid = { + link_ptr for link_ptr, link in self.link_ptrs_as_links if not link.is_valid + } + + def report_validity(self, link_ptr: MemAddr, validity: bool) -> None: + """Report a link as invalid.""" + if validity and link_ptr in self.link_ptrs_invalid: + self.link_ptrs_invalid.remove(link_ptr) + elif not validity and link_ptr not in self.link_ptrs_invalid: + self.link_ptrs_invalid.add(link_ptr) + + def set_validities(self) -> None: + """Set the validity of links in the node tree according to the internal cache. + + Validity doesn't need to be removed, as update() automatically cleans up by default. + """ + for link in [ + link + for link_ptr, link in self.link_ptrs_as_links.items() + if link_ptr in self.link_ptrs_invalid + ]: + if link.is_valid: + link.is_valid = False + #################### # - Node Tree Definition #################### class MaxwellSimTree(bpy.types.NodeTree): + """Node tree containing a node-based program for design and analysis of Maxwell PDE simulations. + + Attributes: + is_active: Whether the node tree should be considered to be in a usable state, capable of updating Blender data. + In general, only one `MaxwellSimTree` should be active at a time. + """ + bl_idname = ct.TreeType.MaxwellSim.value bl_label = 'Maxwell Sim Editor' bl_icon = ct.Icon.SimNodeEditor @@ -219,63 +312,6 @@ class MaxwellSimTree(bpy.types.NodeTree): default=True, ) - #################### - # - Lock Methods - #################### - def unlock_all(self) -> None: - """Unlock all nodes in the node tree, making them editable.""" - log.info('Unlocking All Nodes in NodeTree "%s"', self.bl_label) - for node in self.nodes: - if node.type in ['REROUTE', 'FRAME']: - continue - node.locked = False - for bl_socket in [*node.inputs, *node.outputs]: - bl_socket.locked = False - - @contextlib.contextmanager - def replot(self) -> None: - self.is_currently_replotting = True - self.something_plotted = False - - try: - yield - finally: - self.is_currently_replotting = False - if not self.something_plotted: - ManagedBLImage.hide_preview() - - def report_show_plot(self, node: bpy.types.Node) -> None: - if hasattr(self, 'is_currently_replotting') and self.is_currently_replotting: - self.something_plotted = True - - @contextlib.contextmanager - def repreview_all(self) -> None: - all_nodes_with_preview_active = { - node.instance_id: node - for node in self.nodes - if node.type not in ['REROUTE', 'FRAME'] and node.preview_active - } - self.is_currently_repreviewing = True - self.newly_previewed_nodes = {} - - try: - yield - finally: - self.is_currently_repreviewing = False - for dangling_previewed_node in [ - node - for node_instance_id, node in all_nodes_with_preview_active.items() - if node_instance_id not in self.newly_previewed_nodes - ]: - dangling_previewed_node.preview_active = False - - def report_show_preview(self, node: bpy.types.Node) -> None: - if ( - hasattr(self, 'is_currently_repreviewing') - and self.is_currently_repreviewing - ): - self.newly_previewed_nodes[node.instance_id] = node - #################### # - Init Methods #################### @@ -290,7 +326,54 @@ class MaxwellSimTree(bpy.types.NodeTree): self.node_link_cache = NodeLinkCache(self) #################### - # - Update Methods + # - Lock Methods + #################### + def unlock_all(self) -> None: + """Unlock all nodes in the node tree, making them editable. + + Notes: + All `MaxwellSimNode`s have a `.locked` attribute, which prevents the entire UI from being modified. + + This method simply sets the `locked` attribute to `False` on all nodes. + """ + log.info('Unlocking All Nodes in NodeTree "%s"', self.bl_label) + for node in self.nodes: + if node.type in ['REROUTE', 'FRAME']: + continue + + # Unlock Node + if node.locked: + node.locked = False + + # Unlock Node Sockets + for bl_socket in [*node.inputs, *node.outputs]: + if bl_socket.locked: + bl_socket.locked = False + + #################### + # - Link Update Methods + #################### + def report_link_validity(self, link: bpy.types.NodeLink, validity: bool) -> None: + """Report that a particular `NodeLink` should be considered to be either valid or invalid. + + The `NodeLink.is_valid` attribute is generally (and automatically) used to indicate the detection of cycles in the node tree. + However, visually, it causes a very clear "error red" highlight to appear on the node link, which can extremely useful when determining the reasons behind unexpected outout. + + Notes: + Run by `MaxwellSimSocket` when a link should be shown to be "invalid". + """ + ## TODO: Doesn't quite work. + # log.debug( + # 'Reported Link Validity %s (is_valid=%s, from_socket=%s, to_socket=%s)', + # validity, + # link.is_valid, + # link.from_socket, + # link.to_socket, + # ) + # self.node_link_cache.report_validity(link.as_pointer(), validity) + + #################### + # - Node Update Methods #################### def on_node_removed(self, node: bpy.types.Node): """Run by `MaxwellSimNode.free()` when a node is being removed. @@ -327,32 +410,36 @@ class MaxwellSimTree(bpy.types.NodeTree): self.node_link_cache.remove_link(link_ptr) self.node_link_cache.remove_sockets_by_link_ptr(link_ptr) - def update(self) -> None: + def update(self) -> None: # noqa: PLR0912, C901 """Monitors all changes to the node tree, potentially responding with appropriate callbacks. Notes: - Run by Blender when "anything" changes in the node tree. - Responds to node link changes with callbacks, with the help of a performant node link cache. """ + # Perform Initial Load + ## -> Presume update() is run before the first link is altered. + ## -> Else, the first link of the session will not update caches. + ## -> We still remain slightly unsure of the exact semantics. + ## -> Therefore, self.on_load() is also called as a load_post handler. + if not hasattr(self, 'node_link_cache'): + self.on_load() + return + + # Register Validity Updater + ## -> They will be run after the update() method. + ## -> Between update() and set_validities, all is_valid=True are cleared. + ## -> Therefore, 'set_validities' only needs to set all is_valid=False. + bpy.app.timers.register(self.node_link_cache.set_validities) + + # Ignore Updates + ## -> Certain corrective processes require suppressing the next update. + ## -> Otherwise, link corrections may trigger some nasty recursions. if not hasattr(self, 'ignore_update'): self.ignore_update = False - if not hasattr(self, 'node_link_cache'): - self.on_load() - ## We presume update() is run before the first link is altered. - ## - Else, the first link of the session will not update caches. - ## - We remain slightly unsure of the semantics. - ## - Therefore, self.on_load() is also called as a load_post handler. - return - - # Ignore Update - ## Manually set to implement link corrections w/o recursion. - if self.ignore_update: - return - - # Compute Changes to Node Links + # Regenerate NodeLinkCache delta_links = self.node_link_cache.regenerate() - link_corrections = { 'to_remove': [], 'to_add': [], diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py index 0c9c8d0..e036857 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py @@ -358,6 +358,11 @@ class ExtractDataNode(base.MaxwellSimNode): ## -> Those string labels explain the integer as ex. Ex, Ey, Hy. idx_labels = valid_monitor_attrs(sim_data, monitor_name) + # Extract Info + ## -> We only need the output symbol. + ## -> All labelled outputs have the same output SimSymbol. + info = extract_info(monitor_data, idx_labels[0]) + # Generate FuncFlow Per Index Label ## -> We extract each XArray as an attribute of monitor_data. ## -> We then bind its values into a unique func_flow. @@ -377,7 +382,8 @@ class ExtractDataNode(base.MaxwellSimNode): ## -> Then, 'compose_within' lets us stack them along axis=0. ## -> The "new" axis=0 is int-indexed axis w/idx_labels labels! return functools.reduce(lambda a, b: a | b, func_flows).compose_within( - enclosing_func=lambda data: jnp.stack(data, axis=0) + lambda data: jnp.stack(data, axis=0), + func_output=info.output, ) return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py index c90dbe5..eddeff4 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py @@ -65,12 +65,12 @@ class FilterOperation(enum.StrEnum): FO = FilterOperation return { # Slice - FO.Slice: '=a[i:j]', - FO.SliceIdx: '≈a[v₁:v₂]', + FO.Slice: '≈a[v₁:v₂]', + FO.SliceIdx: '=a[i:j]', # Pin - FO.PinLen1: 'pinₐ', - FO.Pin: 'pinₐ ≈v', - FO.PinIdx: 'pinₐ =i', + FO.PinLen1: 'a[0] → a', + FO.Pin: 'a[v] ⇝ a', + FO.PinIdx: 'a[i] → a', # Reinterpret FO.Swap: 'a₁ ↔ a₂', }[value] @@ -517,6 +517,7 @@ class FilterMathNode(base.MaxwellSimNode): return lazy_func.compose_within( operation.jax_func(axis_0, axis_1, slice_tuple=slice_tuple), enclosing_func_args=operation.func_args, + enclosing_func_output=info.output, supports_jax=True, ) return ct.FlowSignal.FlowPending diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py index 7529a18..04765a8 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py @@ -547,22 +547,30 @@ class MapMathNode(base.MaxwellSimNode): #################### @events.computes_output_socket( 'Expr', + # Loaded kind=ct.FlowKind.Func, props={'operation'}, input_sockets={'Expr'}, input_socket_kinds={ 'Expr': ct.FlowKind.Func, }, + output_sockets={'Expr'}, + output_socket_kinds={'Expr': ct.FlowKind.Info}, ) - def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal: - operation = props['operation'] + def compute_func( + self, props, input_sockets, output_sockets + ) -> ct.FuncFlow | ct.FlowSignal: expr = input_sockets['Expr'] + output_info = output_sockets['Expr'] has_expr = not ct.FlowSignal.check(expr) + has_output_info = not ct.FlowSignal.check(output_info) + operation = props['operation'] if has_expr and operation is not None: return expr.compose_within( operation.jax_func, + enclosing_func_output=output_info.output, supports_jax=True, ) return ct.FlowSignal.FlowPending diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py index 578be8f..f59bde4 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py @@ -146,7 +146,6 @@ class BinaryOperation(enum.StrEnum): outl = info_l.output outr = info_r.output match (outl.shape_len, outr.shape_len): - # match (ol.shape_len, info_r.output.shape_len): # Number | * ## Number | Number case (0, 0): @@ -154,15 +153,25 @@ class BinaryOperation(enum.StrEnum): BO.Add, BO.Sub, BO.Mul, - BO.Div, - BO.Pow, ] + + # Check Non-Zero Right Hand Side + ## -> Obviously, we can't ever divide by zero. + ## -> Sympy's assumptions system must always guarantee rhs != 0. + ## -> If it can't, then we simply don't expose division. + ## -> The is_zero assumption must be provided elsewhere. + ## -> NOTE: This may prevent some valid uses of division. + ## -> Watch out for "division is missing" bugs. + if info_r.output.is_nonzero: + ops.append(BO.Div) + if ( info_l.output.physical_type == spux.PhysicalType.Length and info_l.output.unit == info_r.output.unit ): ops += [BO.Atan2] - return ops + + return [*ops, BO.Pow] ## Number | Vector case (0, 1): @@ -336,7 +345,13 @@ class BinaryOperation(enum.StrEnum): # - InfoFlow Transform #################### def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow): - """Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression.""" + """Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression. + + Warnings: + `self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r). + + If not, bad things will happen. + """ return info_l.operate_output( info_r, lambda a, b: self.sp_func([a, b]), @@ -479,29 +494,35 @@ class OperateMathNode(base.MaxwellSimNode): @events.computes_output_socket( 'Expr', kind=ct.FlowKind.Func, + # Loaded props={'operation'}, input_sockets={'Expr L', 'Expr R'}, input_socket_kinds={ 'Expr L': ct.FlowKind.Func, 'Expr R': ct.FlowKind.Func, }, + output_sockets={'Expr'}, + output_socket_kinds={'Expr': ct.FlowKind.Info}, ) - def compose_func(self, props: dict, input_sockets: dict): + def compute_func(self, props, input_sockets, output_sockets): operation = props['operation'] if operation is None: return ct.FlowSignal.FlowPending expr_l = input_sockets['Expr L'] expr_r = input_sockets['Expr R'] + output_info = output_sockets['Expr'] has_expr_l = not ct.FlowSignal.check(expr_l) has_expr_r = not ct.FlowSignal.check(expr_r) + has_output_info = not ct.FlowSignal.check(output_info) # Compute Jax Function ## -> The operation enum directly provides the appropriate function. - if has_expr_l and has_expr_r: + if has_expr_l and has_expr_r and has_output_info: return (expr_l | expr_r).compose_within( - enclosing_func=operation.jax_func, + operation.jax_func, + enclosing_func_output=output_info.output, supports_jax=True, ) return ct.FlowSignal.FlowPending @@ -520,6 +541,8 @@ class OperateMathNode(base.MaxwellSimNode): }, ) def compute_info(self, props, input_sockets) -> ct.InfoFlow: + BO = BinaryOperation + operation = props['operation'] info_l = input_sockets['Expr L'] info_r = input_sockets['Expr R'] @@ -533,7 +556,7 @@ class OperateMathNode(base.MaxwellSimNode): has_info_l and has_info_r and operation is not None - and operation in BinaryOperation.by_infos(info_l, info_r) + and operation in BO.by_infos(info_l, info_r) ): return operation.transform_infos(info_l, info_r) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py index f869326..6744f7d 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py @@ -606,27 +606,32 @@ class TransformMathNode(base.MaxwellSimNode): input_socket_kinds={ 'Expr': {ct.FlowKind.Func, ct.FlowKind.Info}, }, + output_sockets={'Expr'}, + output_socket_kinds={'Expr': ct.FlowKind.Info}, ) - def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal: + def compute_func( + self, props, input_sockets, output_sockets + ) -> ct.FuncFlow | ct.FlowSignal: """Transform the input `InfoFlow` depending on the transform operation.""" TO = TransformOperation - operation = props['operation'] + lazy_func = input_sockets['Expr'][ct.FlowKind.Func] info = input_sockets['Expr'][ct.FlowKind.Info] + output_info = output_sockets['Expr'] has_info = not ct.FlowSignal.check(info) has_lazy_func = not ct.FlowSignal.check(lazy_func) + has_output_info = not ct.FlowSignal.check(output_info) - if operation is not None and has_lazy_func and has_info: - # Retrieve Properties + operation = props['operation'] + if operation is not None and has_lazy_func and has_info and has_output_info: dim = props['dim'] - - # Match Pattern by Operation match operation: case TO.FreqToVacWL | TO.VacWLToFreq | TO.FT1D | TO.InvFT1D: if dim is not None and info.has_idx_discrete(dim): return lazy_func.compose_within( operation.jax_func(axis=info.dim_axis(dim)), + enclosing_func_output=output_info.output, supports_jax=True, ) return ct.FlowSignal.FlowPending @@ -634,6 +639,7 @@ class TransformMathNode(base.MaxwellSimNode): case _: return lazy_func.compose_within( operation.jax_func(), + enclosing_func_output=output_info.output, supports_jax=True, ) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py index b901275..ad44ef8 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py @@ -406,7 +406,7 @@ class VizNode(base.MaxwellSimNode): }, all_loose_input_sockets=True, ) - def compute_dummy_value(self, props, input_sockets, loose_input_sockets): + def compute_previews(self, props, input_sockets, loose_input_sockets): """Needed for the plot to regenerate in the viewer.""" return ct.PreviewsFlow(bl_image_name=props['sim_node_name']) @@ -433,7 +433,7 @@ class VizNode(base.MaxwellSimNode): def on_show_plot( self, managed_objs, props, input_sockets, loose_input_sockets ) -> None: - log.critical('Show Plot (too many times)') + log.debug('Show Plot') lazy_func = input_sockets['Expr'][ct.FlowKind.Func] info = input_sockets['Expr'][ct.FlowKind.Info] params = input_sockets['Expr'][ct.FlowKind.Params] @@ -456,6 +456,7 @@ class VizNode(base.MaxwellSimNode): sym: loose_input_sockets[sym.name] for sym in params.sorted_symbols }, ) + ## TODO: CACHE entries that don't change, PLEASEEE # Match Viz Type & Perform Visualization ## -> Viz Target determines how to plot. diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py index 6450d01..538cbe0 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py @@ -207,12 +207,12 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance): stop_propagation=True, ) def _on_sim_node_name_changed(self, props): - log.debug( - 'Changed Sim Node Name of a "%s" to "%s" (self=%s)', - self.bl_idname, - props['sim_node_name'], - str(self), - ) + # log.debug( + # 'Changed Sim Node Name of a "%s" to "%s" (self=%s)', + # self.bl_idname, + # props['sim_node_name'], + # str(self), + # ) # (Re)Construct Managed Objects ## -> Due to 'prev_name', the new MObjs will be renamed on construction @@ -360,27 +360,48 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance): #################### # - Socket Management #################### - ## TODO: Check for namespace collisions in sockets to prevent silent errors def _prune_inactive_sockets(self): - """Remove all "inactive" sockets from the node. + """Remove all inactive sockets from the node, while only updating sockets that can be non-destructively updated. - A socket is considered "inactive" when it shouldn't be defined (per `self.active_socket_defs), but is present nonetheless. + The first step is easy: We determine, by-name, which sockets should no longer be defined, then remove them correctly. + + The second step is harder: When new sockets have overlapping names, should they be removed, or should they merely have some properties updated? + Removing and re-adding the same socket is an accurate, generally robust approach, but it comes with a big caveat: **Existing node links will be cut**, even when it might semantically make sense to simply alter the socket's properties, keeping the links. + + Different `bl_socket.socket_type`s can never be updated - they must be removed. + Otherwise, `SocketDef.compare(bl_socket)` allows us to granularly determine whether a particular `bl_socket` has changed with respect to the desired specification. + When the comparison is `False`, we can carefully utilize `SocketDef.init()` to re-initialize the socket, guaranteeing that the altered socket is up to the new specification. """ node_tree = self.id_data for direc in ['input', 'output']: - all_bl_sockets = self._bl_sockets(direc) - active_bl_socket_defs = self.active_socket_defs(direc) + bl_sockets = self._bl_sockets(direc) + active_socket_defs = self.active_socket_defs(direc) # Determine Sockets to Remove + ## -> Name: If the existing socket name isn't "active". + ## -> Type: If the existing socket_type != "active" SocketDef. bl_sockets_to_remove = [ bl_socket - for socket_name, bl_socket in all_bl_sockets.items() - if socket_name not in active_bl_socket_defs - or socket_name - in ( - self.loose_input_sockets - if direc == 'input' - else self.loose_output_sockets + for socket_name, bl_socket in bl_sockets.items() + if ( + socket_name not in active_socket_defs + or bl_socket.socket_type + is not active_socket_defs[socket_name].socket_type + ) + ] + + # Determine Sockets to Update + ## -> Name: If the existing socket name is "active". + ## -> Type: If the existing socket_type == "active" SocketDef. + ## -> Compare: If the existing socket differs from the SocketDef. + bl_sockets_to_update = [ + bl_socket + for socket_name, bl_socket in bl_sockets.items() + if ( + socket_name in active_socket_defs + and bl_socket.socket_type + is active_socket_defs[socket_name].socket_type + and not active_socket_defs[socket_name].compare(bl_socket) ) ] @@ -392,24 +413,25 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance): ## -> The NodeLinkCache needs to be adjusted manually. node_tree.on_node_socket_removed(bl_socket) - # 2. Invalidate the input socket cache across all kinds. + # 2. Perform the removal using Blender's API. + ## -> Actually removes the socket. + bl_sockets.remove(bl_socket) + + # 3. Invalidate the input socket cache across all kinds. ## -> Prevents phantom values from remaining available. + ## -> Done after socket removal to protect from race condition. self._compute_input.invalidate( input_socket_name=bl_socket_name, kind=..., unit_system=..., ) - # 3. Perform the removal using Blender's API. - ## -> Actually removes the socket. - all_bl_sockets.remove(bl_socket) - if direc == 'input': # 4. Run all trigger-only `on_value_changed` callbacks. ## -> Runs any event methods that relied on the socket. ## -> Only methods that don't **require** the socket. - ## Trigger-Only: If method loads no socket data, it runs. - ## `optional`: If method optional-loads socket, it runs. + ## Only Trigger: If method loads no socket data, it runs. + ## Optional: If method optional-loads socket, it runs. triggered_event_methods = [ event_method for event_method in self.filtered_event_methods_by_event( @@ -419,32 +441,52 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance): not in event_method.callback_info.must_load_sockets ] for event_method in triggered_event_methods: - log.critical( - '%s: Running %s', - self.sim_node_name, - str(event_method), - ) event_method(self) + # Update Sockets + for bl_socket in bl_sockets_to_update: + bl_socket_name = bl_socket.name + socket_def = active_socket_defs[bl_socket_name] + + # 1. Pretend to Initialize for the First Time + ## -> NOTE: The socket's caches will be completely regenerated. + ## -> NOTE: A full FlowKind update will occur, but only one. + bl_socket.is_initializing = True + socket_def.preinit(bl_socket) + socket_def.init(bl_socket) + socket_def.postinit(bl_socket) + + # 2. Re-Test Socket Capabilities + ## -> Factors influencing CapabilitiesFlow may have changed. + ## -> Therefore, we must re-test all link capabilities. + bl_socket.remove_invalidated_links() + + # 3. Invalidate the input socket cache across all kinds. + ## -> Prevents phantom values from remaining available. + self._compute_input.invalidate( + input_socket_name=bl_socket_name, + kind=..., + unit_system=..., + ) + def _add_new_active_sockets(self): """Add and initialize all "active" sockets that aren't on the node. Existing sockets within the given direction are not re-created. """ for direc in ['input', 'output']: - all_bl_sockets = self._bl_sockets(direc) - active_bl_socket_defs = self.active_socket_defs(direc) + bl_sockets = self._bl_sockets(direc) + active_socket_defs = self.active_socket_defs(direc) # Define BL Sockets created_sockets = {} - for socket_name, socket_def in active_bl_socket_defs.items(): + for socket_name, socket_def in active_socket_defs.items(): # Skip Existing Sockets - if socket_name in all_bl_sockets: + if socket_name in bl_sockets: continue # Create BL Socket from Socket - ## Set 'display_shape' from 'socket_shape' - all_bl_sockets.new( + bl_sockets.new( str(socket_def.socket_type.value), socket_name, ) @@ -454,9 +496,9 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance): # Initialize Just-Created BL Sockets for bl_socket_name, socket_def in created_sockets.items(): - socket_def.preinit(all_bl_sockets[bl_socket_name]) - socket_def.init(all_bl_sockets[bl_socket_name]) - socket_def.postinit(all_bl_sockets[bl_socket_name]) + socket_def.preinit(bl_sockets[bl_socket_name]) + socket_def.init(bl_sockets[bl_socket_name]) + socket_def.postinit(bl_sockets[bl_socket_name]) # Invalidate Cached NoFlows self._compute_input.invalidate( @@ -637,9 +679,10 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance): lambda a, b: a | b, [ self._compute_input( - socket, kind=ct.FlowKind.Previews, unit_system=None + socket_name, + kind=ct.FlowKind.Previews, ) - for socket in [bl_socket.name for bl_socket in self.inputs] + for socket_name in [bl_socket.name for bl_socket in self.inputs] ], ct.PreviewsFlow(), ) @@ -897,9 +940,19 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance): ) altered_socket_kinds[dep_out_sckname].add(dep_out_kind) + # Clear Output Socket Cache(s) + ## -> We aggregate it manually, so it needs a special invl. + ## -> See self.compute_output() + if socket_kinds is not None and ct.FlowKind.Previews in socket_kinds: + for out_sckname in self.outputs.keys(): # noqa: SIM118 + self.compute_output.invalidate( + output_socket_name=out_sckname, + kind=ct.FlowKind.Previews, + ) + altered_socket_kinds[out_sckname].add(ct.FlowKind.Previews) + # Run Triggered Event Methods ## -> A triggered event method may request to stop propagation. - ## -> A triggered event method may request to stop propagation. stop_propagation = False triggered_event_methods = self.filtered_event_methods_by_event( event, (socket_name, prop_names, None) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py index fa56228..aa94c09 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py @@ -266,7 +266,7 @@ def event_decorator( # noqa: PLR0913 ) # Loose Sockets - ## Compute All Loose Input Sockets + ## -> Determined by the active_kind of each loose input socket. method_kw_args |= ( { 'loose_input_sockets': { diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py index 8aac265..dba473a 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py @@ -29,6 +29,8 @@ from ... import base, events class ScientificConstantNode(base.MaxwellSimNode): + """A well-known constant usable as itself, or as a symbol.""" + node_type = ct.NodeType.ScientificConstant bl_label = 'Scientific Constant' @@ -88,6 +90,11 @@ class ScientificConstantNode(base.MaxwellSimNode): #################### # - UI #################### + def draw_label(self): + if self.sci_constant_str: + return f'Const: {self.sci_constant_str}' + return self.bl_label + def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None: col.prop(self, self.blfields['sci_constant_str'], text='') @@ -156,6 +163,7 @@ class ScientificConstantNode(base.MaxwellSimNode): props={'sci_constant', 'sci_constant_sym'}, ) def compute_lazy_func(self, props) -> typ.Any: + """Simple `FuncFlow` that computes the symbol value, with output units tracked correctly.""" sci_constant = props['sci_constant'] sci_constant_sym = props['sci_constant_sym'] @@ -165,6 +173,7 @@ class ScientificConstantNode(base.MaxwellSimNode): [sci_constant_sym.sp_symbol], sci_constant_sym.sp_symbol, 'jax' ), func_args=[sci_constant_sym], + func_output=sci_constant_sym, supports_jax=True, ) return ct.FlowSignal.FlowPending @@ -175,6 +184,7 @@ class ScientificConstantNode(base.MaxwellSimNode): props={'sci_constant_sym'}, ) def compute_info(self, props: dict) -> typ.Any: + """Simple `FuncFlow` that computes the symbol value, with output units tracked correctly.""" sci_constant_sym = props['sci_constant_sym'] if sci_constant_sym is not None: @@ -193,8 +203,12 @@ class ScientificConstantNode(base.MaxwellSimNode): if sci_constant is not None and sci_constant_sym is not None: return ct.ParamsFlow( arg_targets=[sci_constant_sym], - func_args=[sci_constant], - is_differentiable=True, + func_args=[sci_constant_sym.sp_symbol], + symbols={sci_constant_sym}, + ).realize_partial( + { + sci_constant_sym: sci_constant, + } ) return ct.FlowSignal.FlowPending diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py index 240f16e..9dc93bd 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py @@ -216,10 +216,11 @@ class SymbolConstantNode(base.MaxwellSimNode): props={'symbol'}, ) def compute_lazy_func(self, props) -> typ.Any: - sp_sym = props['symbol'].sp_symbol + sym = props['symbol'] return ct.FuncFlow( - func=sp.lambdify(sp_sym, sp_sym, 'jax'), - func_args=[sp_sym], + func=sp.lambdify(sym.sp_symbol_matsym, sym.sp_symbol_matsym, 'jax'), + func_args=[sym], + func_output=sym, supports_jax=True, ) @@ -235,6 +236,7 @@ class SymbolConstantNode(base.MaxwellSimNode): ) def compute_info(self, props) -> typ.Any: return ct.InfoFlow( + dims={props['symbol']: None}, output=props['symbol'], ) @@ -251,9 +253,6 @@ class SymbolConstantNode(base.MaxwellSimNode): arg_targets=[sym], func_args=[sym.sp_symbol], symbols={sym}, - is_differentiable=( - sym.mathtype in [spux.MathType.Real, spux.MathType.Complex] - ), ) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py index 3d5af11..252dbed 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py @@ -198,9 +198,10 @@ class DataFileImporterNode(base.MaxwellSimNode): 'Expr', kind=ct.FlowKind.Func, # Loaded + props={'output_name', 'output_mathtype', 'output_physical_type', 'output_unit'}, input_sockets={'File Path'}, ) - def compute_func(self, input_sockets) -> td.Simulation: + def compute_func(self, props, input_sockets) -> td.Simulation: """Declare a lazy, composable function that returns the loaded data. Returns: @@ -209,6 +210,12 @@ class DataFileImporterNode(base.MaxwellSimNode): file_path = input_sockets['File Path'] has_file_path = not ct.FlowSignal.check(file_path) + func_output = sim_symbols.SimSymbol( + sym_name=props['output_name'], + mathtype=props['output_mathtype'], + physical_type=props['output_physical_type'], + unit=props['output_unit'], + ) if has_file_path and file_path is not None: data_file_format = ct.DataFileFormat.from_path(file_path) if data_file_format is not None: @@ -217,13 +224,18 @@ class DataFileImporterNode(base.MaxwellSimNode): if data_file_format.loader_is_jax_compatible: return ct.FuncFlow( func=lambda: data_file_format.loader(file_path), + func_output=func_output, supports_jax=True, ) # No Jax Compatibility: Eager Data Loading ## -> Load the data now and bind it. data = data_file_format.loader(file_path) - return ct.FuncFlow(func=lambda: data, supports_jax=True) + return ct.FuncFlow( + func=lambda: data, + func_output=func_output, + supports_jax=True, + ) return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py index 9198f6a..585793b 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py @@ -86,24 +86,47 @@ class ViewerNode(base.MaxwellSimNode): # - Properties: Computed FlowKinds #################### @events.on_value_changed( - socket_name='Any', + # Trigger + prop_name='console_print_kind', + # Loaded + props={'auto_expr', 'console_print_kind'}, ) - def on_input_changed(self) -> None: + def on_print_kind_changed(self, props) -> None: + self.inputs['Any'].active_kind = props['console_print_kind'] + + if props['auto_expr']: + setattr( + self, + 'input_' + props['console_print_kind'].property_name, + bl_cache.Signal.InvalidateCache, + ) + + @events.on_value_changed( + # Trugger + socket_name='Any', + # Loaded + props={'auto_expr', 'console_print_kind'}, + ) + def on_input_changed(self, props) -> None: """Lightweight invalidator, which invalidates the more specific `cached_bl_property` used to determine when something ex. plot-related has changed. Calls `get_flow`, which will be called again when regenerating the `cached_bl_property`s. This **does not** call the flow twice, as `self._compute_input()` will be cached the first time. """ - for flow_kind in list(ct.FlowKind): - flow = self.get_flow( - flow_kind, always_load=flow_kind is ct.FlowKind.Previews + # Invalidate PreviewsFlow + setattr( + self, + 'input_' + ct.FlowKind.Previews.property_name, + bl_cache.Signal.InvalidateCache, + ) + + # Invalidate PreviewsFlow + if props['auto_expr']: + setattr( + self, + 'input_' + props['console_print_kind'].property_name, + bl_cache.Signal.InvalidateCache, ) - if flow is not None: - setattr( - self, - 'input_' + flow_kind.property_name, - bl_cache.Signal.InvalidateCache, - ) @bl_cache.cached_bl_property(depends_on={'auto_expr'}) def input_capabilities(self) -> ct.CapabilitiesFlow | None: diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/combine.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/combine.py index ef63e0a..a0a3134 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/combine.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/combine.py @@ -14,8 +14,10 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import functools import typing as typ +import bpy import sympy as sp from blender_maxwell.utils import bl_cache @@ -26,6 +28,8 @@ from .. import base, events class CombineNode(base.MaxwellSimNode): + """Combine single objects (ex. Source, Monitor, Structure) into a list.""" + node_type = ct.NodeType.Combine bl_label = 'Combine' @@ -33,112 +37,222 @@ class CombineNode(base.MaxwellSimNode): # - Sockets #################### input_socket_sets: typ.ClassVar = { - 'Maxwell Sources': {}, - 'Maxwell Structures': {}, - 'Maxwell Monitors': {}, + 'Sources': {}, + 'Structures': {}, + 'Monitors': {}, } output_socket_sets: typ.ClassVar = { - 'Maxwell Sources': { + 'Sources': { 'Sources': sockets.MaxwellSourceSocketDef( - is_list=True, + active_kind=ct.FlowKind.Array, ), }, - 'Maxwell Structures': { + 'Structures': { 'Structures': sockets.MaxwellStructureSocketDef( - is_list=True, + active_kind=ct.FlowKind.Array, ), }, - 'Maxwell Monitors': { + 'Monitors': { 'Monitors': sockets.MaxwellMonitorSocketDef( - is_list=True, + active_kind=ct.FlowKind.Array, ), }, } #################### - # - Draw + # - Properties #################### - amount: int = bl_cache.BLField(2, abs_min=1, prop_ui=True) + concatenate_first: bool = bl_cache.BLField(False) + value_or_func: ct.FlowKind = bl_cache.BLField( + enum_cb=lambda self, _: self._value_or_func(), + ) + + def _value_or_func(self): + return [ + flow_kind.bl_enum_element(i) + for i, flow_kind in enumerate([ct.FlowKind.Value, ct.FlowKind.Func]) + ] #################### # - Draw #################### - def draw_props(self, context, layout): - layout.prop(self, self.blfields['amount'], text='') + def draw_props(self, _, layout: bpy.types.UILayout): + layout.prop(self, self.blfields['value_or_func'], text='') + + if self.value_or_func is ct.FlowKind.Value: + layout.prop( + self, + self.blfields['concatenate_first'], + text='Concatenate', + toggle=True, + ) #################### # - Events #################### @events.on_value_changed( - # Trigger - prop_name={'active_socket_set', 'amount'}, - props={'active_socket_set', 'amount'}, + any_loose_input_socket=True, + prop_name={'active_socket_set', 'concatenate_first', 'value_or_func'}, run_on_init=True, + # Loaded + props={'active_socket_set', 'concatenate_first', 'value_or_func'}, ) - def on_inputs_changed(self, props): - if props['active_socket_set'] == 'Maxwell Sources': - if ( - not self.loose_input_sockets - or not next(iter(self.loose_input_sockets)).startswith('Source') - or len(self.loose_input_sockets) != props['amount'] - ): - self.loose_input_sockets = { - f'Source #{i}': sockets.MaxwellSourceSocketDef() - for i in range(props['amount']) - } + def on_inputs_changed(self, props) -> None: + """Always create one extra loose input socket.""" + active_socket_set = props['active_socket_set'] - elif props['active_socket_set'] == 'Maxwell Structures': - if ( - not self.loose_input_sockets - or not next(iter(self.loose_input_sockets)).startswith('Structure') - or len(self.loose_input_sockets) != props['amount'] - ): - self.loose_input_sockets = { - f'Structure #{i}': sockets.MaxwellStructureSocketDef() - for i in range(props['amount']) - } - elif props['active_socket_set'] == 'Maxwell Monitors': - if ( - not self.loose_input_sockets - or not next(iter(self.loose_input_sockets)).startswith('Monitor') - or len(self.loose_input_sockets) != props['amount'] - ): - self.loose_input_sockets = { - f'Monitor #{i}': sockets.MaxwellMonitorSocketDef() - for i in range(props['amount']) - } - elif self.loose_input_sockets: - self.loose_input_sockets = {} + # Deduce SocketDef + ## -> Cheat by retrieving the class from the output sockets. + SocketDef = self.output_socket_sets[active_socket_set][ + active_socket_set + ].__class__ + + # Deduce Current "Filled" + ## -> The first linked socket from the end bounds the "filled" region. + ## -> The length of that region, plus one, will be the new amount. + reverse_linked_idxs = [ + i + for i, bl_socket in enumerate(reversed(self.inputs.values())) + if bl_socket.is_linked + ] + current_filled = len(self.inputs) - ( + reverse_linked_idxs[0] if reverse_linked_idxs else len(self.inputs) + ) + new_amount = current_filled + 1 + + # Deduce SocketDef | Current Amount + concatenate_first = props['concatenate_first'] + flow_kind = props['value_or_func'] + + self.loose_input_sockets = { + '#0': SocketDef( + active_kind=flow_kind + if flow_kind is ct.FlowKind.Func or not concatenate_first + else ct.FlowKind.Array + ) + } | {f'#{i}': SocketDef(active_kind=flow_kind) for i in range(1, new_amount)} #################### - # - Output Socket Computation + # - FlowKind.Array|Func + #################### + def compute_combined( + self, + loose_input_sockets, + input_flow_kind: typ.Literal[ct.FlowKind.Value, ct.FlowKind.Func], + output_flow_kind: typ.Literal[ct.FlowKind.Array, ct.FlowKind.Func], + ) -> list[typ.Any] | ct.FuncFlow | ct.FlowSignal: + """Correctly compute the combined loose input sockets, given a valid combination of input and output `FlowKind`s. + + If there is no output, or the flows aren't compatible, return `FlowPending`. + """ + match (input_flow_kind, output_flow_kind): + case (ct.FlowKind.Value, ct.FlowKind.Array): + value_flows = [ + inp + for inp in loose_input_sockets.values() + if not ct.FlowSignal.check(inp) + ] + if value_flows: + return value_flows + return ct.FlowSignal.FlowPending + + case (ct.FlowKind.Func, ct.FlowKind.Func): + func_flows = [ + inp + for inp in loose_input_sockets.values() + if not ct.FlowSignal.check(inp) + ] + if func_flows: + return functools.reduce( + lambda a, b: a | b, + func_flows, + ) + return ct.FlowSignal.FlowPending + + return ct.FlowSignal.FlowPending + + #################### + # - Output: Sources #################### @events.computes_output_socket( 'Sources', kind=ct.FlowKind.Array, all_loose_input_sockets=True, - props={'amount'}, + props={'value_or_func'}, ) - def compute_sources(self, loose_input_sockets, props) -> sp.Expr: - return [loose_input_sockets[f'Source #{i}'] for i in range(props['amount'])] + def compute_sources_array( + self, props, loose_input_sockets + ) -> list[typ.Any] | ct.FlowSignal: + """Compute sources.""" + return self.compute_combined( + loose_input_sockets, props['value_or_func'], ct.FlowKind.Array + ) + @events.computes_output_socket( + 'Sources', + kind=ct.FlowKind.Func, + all_loose_input_sockets=True, + props={'value_or_func'}, + ) + def compute_sources_func(self, props, loose_input_sockets) -> list[typ.Any]: + """Compute (lazy) sources.""" + return self.compute_combined( + loose_input_sockets, props['value_or_func'], ct.FlowKind.Func + ) + + #################### + # - Output: Structures + #################### @events.computes_output_socket( 'Structures', kind=ct.FlowKind.Array, all_loose_input_sockets=True, - props={'amount'}, + props={'value_or_func'}, ) - def compute_structures(self, loose_input_sockets, props) -> sp.Expr: - return [loose_input_sockets[f'Structure #{i}'] for i in range(props['amount'])] + def compute_structures_array(self, props, loose_input_sockets) -> sp.Expr: + """Compute structures.""" + return self.compute_combined( + loose_input_sockets, props['value_or_func'], ct.FlowKind.Array + ) + @events.computes_output_socket( + 'Structures', + kind=ct.FlowKind.Func, + all_loose_input_sockets=True, + props={'value_or_func'}, + ) + def compute_structures_func(self, props, loose_input_sockets) -> list[typ.Any]: + """Compute (lazy) structures.""" + return self.compute_combined( + loose_input_sockets, props['value_or_func'], ct.FlowKind.Func + ) + + #################### + # - Output: Monitors + #################### @events.computes_output_socket( 'Monitors', kind=ct.FlowKind.Array, all_loose_input_sockets=True, - props={'amount'}, + props={'value_or_func'}, ) - def compute_monitors(self, loose_input_sockets, props) -> sp.Expr: - return [loose_input_sockets[f'Monitor #{i}'] for i in range(props['amount'])] + def compute_monitors_array(self, props, loose_input_sockets) -> sp.Expr: + """Compute monitors.""" + return self.compute_combined( + loose_input_sockets, props['value_or_func'], ct.FlowKind.Array + ) + + @events.computes_output_socket( + 'Monitors', + kind=ct.FlowKind.Func, + all_loose_input_sockets=True, + props={'value_or_func'}, + ) + def compute_monitors_func(self, props, loose_input_sockets) -> list[typ.Any]: + """Compute (lazy) monitors.""" + return self.compute_combined( + loose_input_sockets, props['value_or_func'], ct.FlowKind.Func + ) #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/fdtd_sim.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/fdtd_sim.py index 5a1901c..da49075 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/fdtd_sim.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/fdtd_sim.py @@ -14,17 +14,26 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +"""Implements `FDTDSimNode`.""" + import typing as typ -import sympy as sp +import bpy import tidy3d as td +import tidy3d.plugins.adjoint as tdadj + +from blender_maxwell.utils import bl_cache, logger from ... import contracts as ct from ... import sockets from .. import base, events +log = logger.get(__name__) + class FDTDSimNode(base.MaxwellSimNode): + """Definition of a complete FDTD simulation, including boundary conditions, domain, sources, structures, monitors, and other configuration.""" + node_type = ct.NodeType.FDTDSim bl_label = 'FDTD Simulation' @@ -35,51 +44,255 @@ class FDTDSimNode(base.MaxwellSimNode): 'BCs': sockets.MaxwellBoundCondsSocketDef(), 'Domain': sockets.MaxwellSimDomainSocketDef(), 'Sources': sockets.MaxwellSourceSocketDef( - is_list=True, + active_kind=ct.FlowKind.Array, ), 'Structures': sockets.MaxwellStructureSocketDef( - is_list=True, + active_kind=ct.FlowKind.Array, ), 'Monitors': sockets.MaxwellMonitorSocketDef( - is_list=True, + active_kind=ct.FlowKind.Array, ), } - output_sockets: typ.ClassVar = { - 'Sim': sockets.MaxwellFDTDSimSocketDef(), + output_socket_sets: typ.ClassVar = { + 'Single': { + 'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Value), + }, + 'Batch': { + 'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Array), + }, + 'Lazy': { + 'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Func), + }, } #################### - # - Output Socket Computation + # - Properties + #################### + differentiable: bool = bl_cache.BLField(False) + + #################### + # - UI + #################### + def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout): + layout.prop( + self, + self.blfields['differentiable'], + text='Differentiable', + toggle=True, + ) + + #################### + # - Events + #################### + @events.on_value_changed( + # Trigger + socket_name={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'}, + run_on_init=True, + # Loaded + props={'active_socket_set'}, + output_sockets={'Sim'}, + output_socket_kinds={'Sim': ct.FlowKind.Params}, + ) + def on_any_changed(self, props, output_sockets) -> None: + """Create loose input sockets.""" + params = output_sockets['Sim'] + has_params = not ct.FlowSignal.check(params) + + # Declare Loose Sockets that Realize Symbols + ## -> This happens if Params contains not-yet-realized symbols. + active_socket_set = props['active_socket_set'] + if active_socket_set in ['Value', 'Batch'] and has_params and params.symbols: + if set(self.loose_input_sockets) != {sym.name for sym in params.symbols}: + self.loose_input_sockets = { + sym.name: sockets.ExprSocketDef( + **( + expr_info + | { + 'active_kind': ct.FlowKind.Value, + 'use_value_range_swapper': ( + active_socket_set == 'Value' + ), + } + ) + ) + for sym, expr_info in params.sym_expr_infos.items() + } + + elif self.loose_input_sockets: + self.loose_input_sockets = {} + + #################### + # - FlowKind.Value #################### @events.computes_output_socket( 'Sim', kind=ct.FlowKind.Value, + # Loaded + props={'differentiable'}, input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'}, input_socket_kinds={ 'Sources': ct.FlowKind.Array, 'Structures': ct.FlowKind.Array, - 'Domain': ct.FlowKind.Value, - 'BCs': ct.FlowKind.Value, 'Monitors': ct.FlowKind.Array, }, + output_sockets={'Sim'}, + output_socket_kinds={'Sim': ct.FlowKind.Params}, ) - def compute_fdtd_sim(self, input_sockets: dict) -> sp.Expr: - if any(ct.FlowSignal.check(inp) for inp in input_sockets): - return ct.FlowSignal.FlowPending - + def compute_fdtd_sim_value( + self, props, input_sockets, output_sockets + ) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal: + """Compute a single FDTD simulation definition, so long as the inputs are neither symbolic or differentiable.""" sim_domain = input_sockets['Domain'] sources = input_sockets['Sources'] structures = input_sockets['Structures'] bounds = input_sockets['BCs'] monitors = input_sockets['Monitors'] - return td.Simulation( - **sim_domain, - structures=structures, - sources=sources, - monitors=monitors, - boundary_spec=bounds, - ) - ## TODO: Visualize the boundary conditions on top of the sim domain + output_params = output_sockets['Sim'] + + has_sim_domain = not ct.FlowSignal.check(sim_domain) + has_sources = not ct.FlowSignal.check(sources) + has_structures = not ct.FlowSignal.check(structures) + has_bounds = not ct.FlowSignal.check(bounds) + has_monitors = not ct.FlowSignal.check(monitors) + has_output_params = not ct.FlowSignal.check(output_params) + + differentiable = props['differentiable'] + if ( + has_sim_domain + and has_sources + and has_structures + and has_bounds + and has_monitors + and has_output_params + and not differentiable + ): + return td.Simulation( + **sim_domain, + sources=sources, + structures=structures, + boundary_spec=bounds, + monitors=monitors, + ) + return ct.FlowSignal.FlowPending + + #################### + # - FlowKind.Func + #################### + @events.computes_output_socket( + 'Sim', + kind=ct.FlowKind.Func, + # Loaded + props={'differentiable'}, + input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'}, + input_socket_kinds={ + 'Sources': ct.FlowKind.Func, + 'Structures': ct.FlowKind.Func, + 'Monitors': ct.FlowKind.Func, + }, + output_sockets={'Sim'}, + output_socket_kinds={'Sim': ct.FlowKind.Params}, + ) + def compute_fdtd_sim_func( + self, props, input_sockets, output_sockets + ) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal: + """Compute a single simulation, given that all inputs are non-symbolic.""" + sim_domain = input_sockets['Domain'] + sources = input_sockets['Sources'] + structures = input_sockets['Structures'] + bounds = input_sockets['BCs'] + monitors = input_sockets['Monitors'] + output_params = output_sockets['Sim'] + + has_sim_domain = not ct.FlowSignal.check(sim_domain) + has_sources = not ct.FlowSignal.check(sources) + has_structures = not ct.FlowSignal.check(structures) + has_bounds = not ct.FlowSignal.check(bounds) + has_monitors = not ct.FlowSignal.check(monitors) + has_output_params = not ct.FlowSignal.check(output_params) + + if ( + has_sim_domain + and has_sources + and has_structures + and has_bounds + and has_monitors + and has_output_params + ): + differentiable = props['differentiable'] + if differentiable: + return ( + sim_domain | sources | structures | bounds | monitors + ).compose_within( + enclosing_func=lambda els: tdadj.JaxSimulation( + **els[0], + sources=els[1], + structures=els[2]['static'], + input_structures=els[2]['differentiable'], + boundary_spec=els[3], + monitors=els[4]['static'], + output_monitors=els[4]['differentiable'], + ), + supports_jax=True, + ) + return ( + sim_domain | sources | structures | bounds | monitors + ).compose_within( + enclosing_func=lambda els: td.Simulation( + **els[0], + sources=els[1], + structures=els[2], + boundary_spec=els[3], + monitors=els[4], + ), + supports_jax=False, + ) + return ct.FlowSignal.FlowPending + + #################### + # - FlowKind.Params + #################### + @events.computes_output_socket( + 'Sim', + kind=ct.FlowKind.Params, + # Loaded + props={'differentiable'}, + input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'}, + input_socket_kinds={ + 'Sources': ct.FlowKind.Params, + 'Structures': ct.FlowKind.Params, + 'Monitors': ct.FlowKind.Params, + }, + ) + def compute_fdtd_sim_params( + self, props, input_sockets + ) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal: + """Compute a single simulation, given that all inputs are non-symbolic.""" + sim_domain = input_sockets['Domain'] + sources = input_sockets['Sources'] + structures = input_sockets['Structures'] + bounds = input_sockets['BCs'] + monitors = input_sockets['Monitors'] + + has_sim_domain = not ct.FlowSignal.check(sim_domain) + has_sources = not ct.FlowSignal.check(sources) + has_structures = not ct.FlowSignal.check(structures) + has_bounds = not ct.FlowSignal.check(bounds) + has_monitors = not ct.FlowSignal.check(monitors) + + if ( + has_sim_domain + and has_sources + and has_structures + and has_bounds + and has_monitors + ): + # Determine Differentiable Match + ## -> 'structures' is diff when **any** are diff. + ## -> 'monitors' is also diff when **any** are diff. + ## -> Only parameters through diff structs can be diff'ed by. + ## -> Similarly, only diff monitors will have gradients computed. + return sim_domain | sources | structures | bounds | monitors + return ct.FlowSignal.FlowPending #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py index 66a0696..4f8a0bb 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py @@ -14,6 +14,8 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +"""Implements `SimDomainNode`.""" + import typing as typ import sympy as sp @@ -31,6 +33,8 @@ log = logger.get(__name__) class SimDomainNode(base.MaxwellSimNode): + """The domain of a simulation in space and time, including bounds, discretization strategy, and the ambient medium.""" + node_type = ct.NodeType.SimDomain bl_label = 'Sim Domain' use_sim_node_name = True @@ -69,26 +73,109 @@ class SimDomainNode(base.MaxwellSimNode): } #################### - # - Outputs + # - FlowKind.Value #################### @events.computes_output_socket( 'Domain', + kind=ct.FlowKind.Value, + # Loaded + output_sockets={'Domain'}, + output_socket_kinds={'Domain': {ct.FlowKind.Func, ct.FlowKind.Params}}, + ) + def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal: + """Compute the particular value of the simulation domain from strictly non-symbolic inputs.""" + output_func = output_sockets['Domain'][ct.FlowKind.Func] + output_params = output_sockets['Domain'][ct.FlowKind.Params] + + has_output_func = not ct.FlowSignal.check(output_func) + has_output_params = not ct.FlowSignal.check(output_params) + + if has_output_func and has_output_params and not output_params.symbols: + return output_func.realize(output_params) + return ct.FlowSignal.FlowPending + + #################### + # - FlowKind.Func + #################### + @events.computes_output_socket( + 'Domain', + kind=ct.FlowKind.Func, + # Loaded input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'}, - unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D}, - scale_input_sockets={ - 'Duration': 'Tidy3DUnits', - 'Center': 'Tidy3DUnits', - 'Size': 'Tidy3DUnits', + input_socket_kinds={ + 'Duration': ct.FlowKind.Func, + 'Center': ct.FlowKind.Func, + 'Size': ct.FlowKind.Func, + 'Grid': ct.FlowKind.Func, + 'Ambient Medium': ct.FlowKind.Func, }, ) - def compute_domain(self, input_sockets, unit_systems) -> sp.Expr: - return { - 'run_time': input_sockets['Duration'], - 'center': input_sockets['Center'], - 'size': input_sockets['Size'], - 'grid_spec': input_sockets['Grid'], - 'medium': input_sockets['Ambient Medium'], - } + def compute_domain_func(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal: + """Compute the particular value of the simulation domain from strictly non-symbolic inputs.""" + duration = input_sockets['Duration'] + center = input_sockets['Center'] + size = input_sockets['Size'] + grid = input_sockets['Grid'] + medium = input_sockets['Ambient Medium'] + + has_duration = not ct.FlowSignal.check(duration) + has_center = not ct.FlowSignal.check(center) + has_size = not ct.FlowSignal.check(size) + has_grid = not ct.FlowSignal.check(grid) + has_medium = not ct.FlowSignal.check(medium) + + if has_duration and has_center and has_size and has_grid and has_medium: + return ( + duration.scale_to_unit_system(ct.UNITS_TIDY3D) + | center.scale_to_unit_system(ct.UNITS_TIDY3D) + | size.scale_to_unit_system(ct.UNITS_TIDY3D) + | grid + | medium + ).compose_within( + enclosing_func=lambda els: { + 'run_time': els[0], + 'center': tuple(els[1].flatten()), + 'size': tuple(els[2].flatten()), + 'grid_spec': els[3], + 'medium': els[4], + }, + supports_jax=False, + ) + return ct.FlowSignal.FlowPending + + #################### + # - FlowKind.Params + #################### + @events.computes_output_socket( + 'Domain', + kind=ct.FlowKind.Params, + # Loaded + input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'}, + input_socket_kinds={ + 'Duration': ct.FlowKind.Params, + 'Center': ct.FlowKind.Params, + 'Size': ct.FlowKind.Params, + 'Grid': ct.FlowKind.Params, + 'Ambient Medium': ct.FlowKind.Params, + }, + ) + def compute_domain_params(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal: + """Compute the output `ParamsFlow` of the simulation domain from strictly non-symbolic inputs.""" + duration = input_sockets['Duration'] + center = input_sockets['Center'] + size = input_sockets['Size'] + grid = input_sockets['Grid'] + medium = input_sockets['Ambient Medium'] + + has_duration = not ct.FlowSignal.check(duration) + has_center = not ct.FlowSignal.check(center) + has_size = not ct.FlowSignal.check(size) + has_grid = not ct.FlowSignal.check(grid) + has_medium = not ct.FlowSignal.check(medium) + + if has_duration and has_center and has_size and has_grid and has_medium: + return duration | center | size | grid | medium + return ct.FlowSignal.FlowPending #################### # - Preview @@ -100,38 +187,40 @@ class SimDomainNode(base.MaxwellSimNode): props={'sim_node_name'}, ) def compute_previews(self, props): + """Mark the managed preview object for preview when `Domain` is linked to a viewer.""" return ct.PreviewsFlow(bl_object_names={props['sim_node_name']}) @events.on_value_changed( - ## Trigger + # Trigger socket_name={'Center', 'Size'}, run_on_init=True, # Loaded input_sockets={'Center', 'Size'}, managed_objs={'modifier'}, - unit_systems={'BlenderUnits': ct.UNITS_BLENDER}, - scale_input_sockets={ - 'Center': 'BlenderUnits', - }, + output_sockets={'Domain'}, + output_socket_kinds={'Domain': ct.FlowKind.Params}, ) - def on_input_changed( - self, - managed_objs, - input_sockets, - unit_systems, - ): - # Push Loose Input Values to GeoNodes Modifier - managed_objs['modifier'].bl_modifier( - 'NODES', - { - 'node_group': import_geonodes(GeoNodes.SimulationSimDomain), - 'unit_system': unit_systems['BlenderUnits'], - 'inputs': { - 'Size': input_sockets['Size'], + def on_input_changed(self, managed_objs, input_sockets, output_sockets) -> None: + """Preview the simulation domain based on input parameters, so long as they are not dependent on unrealized symbols.""" + output_params = output_sockets['Domain'] + center = input_sockets['Center'] + + has_output_params = not ct.FlowSignal.check(output_params) + has_center = not ct.FlowSignal.check(center) + + if has_center and has_output_params and not output_params.symbols: + # Push Loose Input Values to GeoNodes Modifier + managed_objs['modifier'].bl_modifier( + 'NODES', + { + 'node_group': import_geonodes(GeoNodes.SimulationSimDomain), + 'unit_system': ct.UNITS_BLENDER, + 'inputs': { + 'Size': input_sockets['Size'], + }, }, - }, - location=input_sockets['Center'], - ) + location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER), + ) #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py index ad438bc..dad848f 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py @@ -71,35 +71,129 @@ class PointDipoleSourceNode(base.MaxwellSimNode): layout.prop(self, self.blfields['pol_axis'], expand=True) #################### - # - Outputs + # - FlowKind.Value #################### @events.computes_output_socket( 'Source', - input_sockets={'Temporal Shape', 'Center', 'Interpolate'}, + # Loaded props={'pol_axis'}, - unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D}, - scale_input_sockets={ - 'Center': 'Tidy3DUnits', + input_sockets={'Temporal Shape', 'Center', 'Interpolate'}, + output_sockets={'Source'}, + output_socket_kinds={'Source': ct.FlowKind.Params}, + ) + def compute_source_value( + self, input_sockets, props, output_sockets + ) -> td.PointDipole | ct.FlowSignal: + """Compute the point dipole source, given that all inputs are non-symbolic.""" + temporal_shape = input_sockets['Temporal Shape'] + center = input_sockets['Center'] + interpolate = input_sockets['Interpolate'] + output_params = output_sockets['Source'] + + has_temporal_shape = not ct.FlowSignal.check(temporal_shape) + has_center = not ct.FlowSignal.check(center) + has_interpolate = not ct.FlowSignal.check(interpolate) + has_output_params = not ct.FlowSignal.check(output_params) + + if ( + has_temporal_shape + and has_center + and has_interpolate + and has_output_params + and not output_params.symbols + ): + pol_axis = { + ct.SimSpaceAxis.X: 'Ex', + ct.SimSpaceAxis.Y: 'Ey', + ct.SimSpaceAxis.Z: 'Ez', + }[props['pol_axis']] + ## TODO: Need Hx, Hy, Hz too? + + return td.PointDipole( + center=spux.convert_to_unit_system(center, ct.UNITS_TIDY3D), + source_time=temporal_shape, + interpolate=interpolate, + polarization=pol_axis, + ) + return ct.FlowSignal.FlowPending + + #################### + # - FlowKind.Func + #################### + @events.computes_output_socket( + 'Source', + kind=ct.FlowKind.Func, + # Loaded + props={'pol_axis'}, + input_sockets={'Temporal Shape', 'Center', 'Interpolate'}, + input_socket_kinds={ + 'Temporal Shape': ct.FlowKind.Func, + 'Center': ct.FlowKind.Func, + 'Interpolate': ct.FlowKind.Func, + }, + output_sockets={'Source'}, + output_socket_kinds={'Source': ct.FlowKind.Params}, + ) + def compute_source_func(self, props, input_sockets, output_sockets) -> td.Box: + """Compute a lazy function for the point dipole source.""" + center = input_sockets['Center'] + temporal_shape = input_sockets['Temporal Shape'] + interpolate = input_sockets['Interpolate'] + output_params = output_sockets['Source'] + + has_center = not ct.FlowSignal.check(center) + has_temporal_shape = not ct.FlowSignal.check(temporal_shape) + has_interpolate = not ct.FlowSignal.check(interpolate) + has_output_params = not ct.FlowSignal.check(output_params) + + if has_temporal_shape and has_center and has_interpolate and has_output_params: + pol_axis = { + ct.SimSpaceAxis.X: 'Ex', + ct.SimSpaceAxis.Y: 'Ey', + ct.SimSpaceAxis.Z: 'Ez', + }[props['pol_axis']] + ## TODO: Need Hx, Hy, Hz too? + + return (center | temporal_shape | interpolate).compose_within( + enclosing_func=lambda els: td.PointDipole( + center=els[0], + source_time=els[1], + interpolate=els[2], + polarization=pol_axis, + ) + ) + return ct.FlowSignal.FlowPending + + #################### + # - FlowKind.Params + #################### + @events.computes_output_socket( + 'Source', + kind=ct.FlowKind.Params, + # Loaded + input_sockets={'Temporal Shape', 'Center', 'Interpolate'}, + input_socket_kinds={ + 'Temporal Shape': ct.FlowKind.Params, + 'Center': ct.FlowKind.Params, + 'Interpolate': ct.FlowKind.Params, }, ) - def compute_source( + def compute_params( self, - input_sockets: dict[str, typ.Any], - props: dict[str, typ.Any], - unit_systems: dict, - ) -> td.PointDipole: - pol_axis = { - ct.SimSpaceAxis.X: 'Ex', - ct.SimSpaceAxis.Y: 'Ey', - ct.SimSpaceAxis.Z: 'Ez', - }[props['pol_axis']] + input_sockets, + ) -> td.PointDipole | ct.FlowSignal: + """Compute the point dipole source, given that all inputs are non-symbolic.""" + temporal_shape = input_sockets['Temporal Shape'] + center = input_sockets['Center'] + interpolate = input_sockets['Interpolate'] - return td.PointDipole( - center=input_sockets['Center'], - source_time=input_sockets['Temporal Shape'], - interpolate=input_sockets['Interpolate'], - polarization=pol_axis, - ) + has_temporal_shape = not ct.FlowSignal.check(temporal_shape) + has_center = not ct.FlowSignal.check(center) + has_interpolate = not ct.FlowSignal.check(interpolate) + + if has_temporal_shape and has_center and has_interpolate: + return temporal_shape | center | interpolate + return ct.FlowSignal.FlowPending #################### # - Preview diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py index ba2ae1f..1f48f31 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py @@ -16,15 +16,19 @@ """Implements the `TemporalShapeNode`.""" +import enum import typing as typ import bpy +import numpy as np import sympy as sp import sympy.physics.units as spu import tidy3d as td +from tidy3d.components.data.data_array import TimeDataArray as td_TimeDataArray +from tidy3d.components.data.dataset import TimeDataset as td_TimeDataset +from blender_maxwell.utils import bl_cache, logger, sim_symbols from blender_maxwell.utils import extra_sympy_units as spux -from blender_maxwell.utils import logger, sim_symbols from ... import contracts as ct from ... import managed_objs, sockets @@ -33,14 +37,10 @@ from .. import base, events log = logger.get(__name__) -_max_e_socket_def = sockets.ExprSocketDef( - mathtype=spux.MathType.Complex, - physical_type=spux.PhysicalType.EField, - default_value=1 + 0j, -) -_offset_socket_def = sockets.ExprSocketDef(default_value=5, abs_min=2.5) - -t_ps = sim_symbols.t(spu.picosecond) +# Select Default Time Unit for Envelope +## -> Chosen to align with the default envelope_time_unit. +## -> This causes it to be correct from the start. +t_def = sim_symbols.t(spux.PhysicalType.Time.valid_units[0]) class TemporalShapeNode(base.MaxwellSimNode): @@ -63,17 +63,18 @@ class TemporalShapeNode(base.MaxwellSimNode): default_unit=spux.THz, default_value=200, ), + 'max E': sockets.ExprSocketDef( + mathtype=spux.MathType.Complex, + physical_type=spux.PhysicalType.EField, + default_value=1 + 0j, + ), + 'Offset Time': sockets.ExprSocketDef(default_value=5, abs_min=2.5), } input_socket_sets: typ.ClassVar = { 'Pulse': { - 'max E': _max_e_socket_def, - 'Offset Time': _offset_socket_def, 'Remove DC': sockets.BoolSocketDef(default_value=True), }, - 'Constant': { - 'max E': _max_e_socket_def, - 'Offset Time': _offset_socket_def, - }, + 'Constant': {}, 'Symbolic': { 't Range': sockets.ExprSocketDef( active_kind=ct.FlowKind.Range, @@ -84,8 +85,8 @@ class TemporalShapeNode(base.MaxwellSimNode): default_steps=100, ), 'Envelope': sockets.ExprSocketDef( - default_symbols=[t_ps], - default_value=10 * t_ps.sp_symbol, + default_symbols=[t_def], + default_value=10 * t_def.sp_symbol, ), }, } @@ -98,6 +99,55 @@ class TemporalShapeNode(base.MaxwellSimNode): 'plot': managed_objs.ManagedBLImage, } + #################### + # - Properties + #################### + active_envelope_time_unit: enum.StrEnum = bl_cache.BLField( + enum_cb=lambda self, _: self.search_time_units(), + ) + + def search_time_units(self) -> list[ct.BLEnumElement]: + """Compute all valid time units.""" + return [ + (sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i) + for i, unit in enumerate(spux.PhysicalType.Time.valid_units) + ] + + @bl_cache.cached_bl_property(depends_on={'active_envelope_time_unit'}) + def envelope_time_unit(self) -> spux.Unit | None: + """Gets the current active unit for the envelope time symbol. + + Returns: + The current active `sympy` unit. + + If the socket expression is unitless, this returns `None`. + """ + if self.active_envelope_time_unit is not None: + return spux.unit_str_to_unit(self.active_envelope_time_unit) + + return None + + #################### + # - UI + #################### + def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout): + if ( + self.active_socket_set == 'Symbolic' + and self.inputs.get('Envelope') + and not self.inputs['Envelope'].is_linked + ): + row = layout.row() + row.alignment = 'CENTER' + row.label(text='Envelope Time Unit') + + row = layout.row() + row.prop( + self, + self.blfields['active_envelope_time_unit'], + text='', + toggle=True, + ) + def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: if self.active_socket_set != 'Symbolic': box = layout.box() @@ -118,10 +168,53 @@ class TemporalShapeNode(base.MaxwellSimNode): col.label(text='1 / 2π·σ(𝑓)') #################### - # - FlowKind: Value + # - Events + #################### + @events.on_value_changed( + # Trigger + prop_name={'active_socket_set', 'envelope_time_unit'}, + # Loaded + props={'active_socket_set', 'envelope_time_unit'}, + ) + def on_envelope_time_unit_changed(self, props) -> None: + """Ensure the envelope expression's time symbol has the time unit defined by the node.""" + active_socket_set = props['active_socket_set'] + envelope_time_unit = props['envelope_time_unit'] + if active_socket_set == 'Symbolic': + bl_socket = self.inputs['Envelope'] + wanted_t_sym = sim_symbols.t(envelope_time_unit) + + if not bl_socket.symbols or bl_socket.symbols[0] != wanted_t_sym: + bl_socket.symbols = [wanted_t_sym] + + #################### + # - FlowKind.Value #################### @events.computes_output_socket( 'Temporal Shape', + kind=ct.FlowKind.Value, + # Loaded + output_sockets={'Temporal Shape'}, + output_socket_kinds={'Temporal Shape': {ct.FlowKind.Func, ct.FlowKind.Params}}, + ) + def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal: + """Compute a single temporal shape.""" + output_func = output_sockets['Temporal Shape'][ct.FlowKind.Func] + output_params = output_sockets['Temporal Shape'][ct.FlowKind.Params] + + has_output_func = not ct.FlowSignal.check(output_func) + has_output_params = not ct.FlowSignal.check(output_params) + + if has_output_func and has_output_params and not output_params.symbols: + return output_func.realize(output_params) + return ct.FlowSignal.FlowPending + + #################### + # - FlowKind: Func + #################### + @events.computes_output_socket( + 'Temporal Shape', + kind=ct.FlowKind.Func, # Loaded props={'active_socket_set'}, input_sockets={ @@ -134,60 +227,178 @@ class TemporalShapeNode(base.MaxwellSimNode): 'Envelope', }, input_socket_kinds={ - 't Range': ct.FlowKind.Range, - 'Envelope': ct.FlowKind.Func, - }, - input_sockets_optional={ - 'max E': True, - 'Offset Time': True, - 'Remove DC': True, - 't Range': True, - 'Envelope': True, - }, - unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D}, - scale_input_sockets={ - 'max E': 'Tidy3DUnits', - 'μ Freq': 'Tidy3DUnits', - 'σ Freq': 'Tidy3DUnits', - 't Range': 'Tidy3DUnits', - 'Offset Time': 'Tidy3DUnits', + 'max E': ct.FlowKind.Func, + 'μ Freq': ct.FlowKind.Func, + 'σ Freq': ct.FlowKind.Func, + 'Offset Time': ct.FlowKind.Func, + 'Remove DC': ct.FlowKind.Value, + 't Range': ct.FlowKind.Func, + 'Envelope': {ct.FlowKind.Func, ct.FlowKind.Params}, }, ) - def compute_temporal_shape( - self, props, input_sockets, unit_systems + def compute_temporal_shape_func( + self, + props, + input_sockets, ) -> td.GaussianPulse: - match props['active_socket_set']: - case 'Pulse': - return td.GaussianPulse( - amplitude=sp.re(input_sockets['max E']), - phase=sp.im(input_sockets['max E']), - freq0=input_sockets['μ Freq'], - fwidth=input_sockets['σ Freq'], - offset=input_sockets['Offset Time'], - remove_dc_component=input_sockets['Remove DC'], - ) + """Compute a single temporal shape from non-parameterized inputs.""" + mean_freq = input_sockets['μ Freq'] + std_freq = input_sockets['σ Freq'] + max_e = input_sockets['max E'] + offset = input_sockets['Offset Time'] - case 'Constant': - return td.ContinuousWave( - amplitude=sp.re(input_sockets['max E']), - phase=sp.im(input_sockets['max E']), - freq0=input_sockets['μ Freq'], - fwidth=input_sockets['σ Freq'], - offset=input_sockets['Offset Time'], - ) + has_mean_freq = not ct.FlowSignal.check(mean_freq) + has_std_freq = not ct.FlowSignal.check(std_freq) + has_max_e = not ct.FlowSignal.check(max_e) + has_offset = not ct.FlowSignal.check(offset) - case 'Symbolic': - lzrange = input_sockets['t Range'] - envelope_ps = input_sockets['Envelope'].func_jax + if has_mean_freq and has_std_freq and has_max_e and has_offset: + common_func = ( + max_e.scale_to_unit_system(ct.UNITS_TIDY3D) + | mean_freq.scale_to_unit_system(ct.UNITS_TIDY3D) + | std_freq.scale_to_unit_system(ct.UNITS_TIDY3D) + | offset ## Already unitless + ) + match props['active_socket_set']: + case 'Pulse': + remove_dc = input_sockets['Remove DC'] - return td.CustomSourceTime.from_values( - freq0=input_sockets['μ Freq'], - fwidth=input_sockets['σ Freq'], - values=envelope_ps( - lzrange.rescale_to_unit(spu.ps).realize_array.values - ), - dt=input_sockets['t Range'].realize_step_size(), - ) + has_remove_dc = not ct.FlowSignal.check(remove_dc) + + if has_remove_dc: + return common_func.compose_within( + lambda els: td.GaussianPulse( + amplitude=complex(els[0]).real, + phase=complex(els[0]).imag, + freq0=els[1], + fwidth=els[2], + offset=els[3], + remove_dc_component=remove_dc, + ), + ) + + case 'Constant': + return common_func.compose_within( + lambda els: td.GaussianPulse( + amplitude=complex(els[0]).real, + phase=complex(els[0]).imag, + freq0=els[1], + fwidth=els[2], + offset=els[3], + ), + ) + + case 'Symbolic': + t_range = input_sockets['t Range'] + envelope = input_sockets['Envelope'][ct.FlowKind.Func] + envelope_params = input_sockets['Envelope'][ct.FlowKind.Params] + + has_t_range = not ct.FlowSignal.check(t_range) + has_envelope = not ct.FlowSignal.check(envelope) + has_envelope_params = not ct.FlowSignal.check(envelope_params) + + if ( + has_t_range + and has_envelope + and has_envelope_params + and len(envelope_params.symbols) == 1 + ## TODO: Allow unrealized envelope symbols + and any( + sym.physical_type is spux.PhysicalType.Time + for sym in envelope_params.symbols + ) + ): + envelope_time_unit = next( + sym.unit + for sym in envelope_params.symbols + if sym.physical_type is spux.PhysicalType.Time + ) + + # Deduce Partially Realized Envelope Function + ## -> We need a pure-numerical function w/pre-realized stuff baked in. + ## -> 'realize_partial' does this for us. + envelope_realizer = envelope.realize_partial(envelope_params) + + # Compose w/Envelope Function + ## -> First, the numerical time values must be converted. + ## -> This ensures that the raw array is compatible w/the envelope. + ## -> Then, we can compose w/the purely numerical 'envelope_realizer'. + ## -> Because of the checks, we've guaranteed that all this is correct. + return ( + common_func ## 1 | freq0, 2 | fwidth, 3 | offset + | t_range.scale_to_unit_system(ct.UNITS_TIDY3D) ## 4 + | t_range.scale_to_unit(envelope_time_unit).compose_within( + lambda t: envelope_realizer(t) + ) ## 5 + ).compose_within( + lambda els: td.CustomSourceTime( + amplitude=complex(els[0]).real, + phase=complex(els[0]).imag, + freq0=els[1], + fwidth=els[2], + offset=els[3], + source_time_dataset=td_TimeDataset( + values=td_TimeDataArray( + els[5], coords={'t': np.array(els[4])} + ) + ), + ) + ) + + return ct.FlowSignal.FlowPending + + #################### + # - FlowKind: Params + #################### + @events.computes_output_socket( + 'Temporal Shape', + kind=ct.FlowKind.Params, + # Loaded + props={'active_socket_set', 'envelope_time_unit'}, + input_sockets={ + 'max E', + 'μ Freq', + 'σ Freq', + 'Offset Time', + 't Range', + }, + input_socket_kinds={ + 'max E': ct.FlowKind.Params, + 'μ Freq': ct.FlowKind.Params, + 'σ Freq': ct.FlowKind.Params, + 'Offset Time': ct.FlowKind.Params, + 't Range': ct.FlowKind.Params, + }, + ) + def compute_temporal_shape_params( + self, + props, + input_sockets, + ) -> td.GaussianPulse: + """Compute a single temporal shape from non-parameterized inputs.""" + mean_freq = input_sockets['μ Freq'] + std_freq = input_sockets['σ Freq'] + max_e = input_sockets['max E'] + offset = input_sockets['Offset Time'] + + has_mean_freq = not ct.FlowSignal.check(mean_freq) + has_std_freq = not ct.FlowSignal.check(std_freq) + has_max_e = not ct.FlowSignal.check(max_e) + has_offset = not ct.FlowSignal.check(offset) + + if has_mean_freq and has_std_freq and has_max_e and has_offset: + common_params = max_e | mean_freq | std_freq | offset + match props['active_socket_set']: + case 'Pulse' | 'Constant': + return common_params + + case 'Symbolic': + t_range = input_sockets['t Range'] + has_t_range = not ct.FlowSignal.check(t_range) + + if has_t_range: + return common_params | t_range | t_range + return ct.FlowSignal.FlowPending #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py index 6dbf856..0d955df 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py @@ -88,28 +88,27 @@ class BoxStructureNode(base.MaxwellSimNode): 'Structure', kind=ct.FlowKind.Value, # Loaded - props={'differentiable'}, input_sockets={'Medium', 'Center', 'Size'}, output_sockets={'Structure'}, output_socket_kinds={'Structure': ct.FlowKind.Params}, ) - def compute_value(self, props, input_sockets, output_sockets) -> td.Box: - output_params = output_sockets['Structure'] + def compute_value(self, input_sockets, output_sockets) -> td.Box: + """Compute a single box structure object, given that all inputs are non-symbolic.""" center = input_sockets['Center'] size = input_sockets['Size'] medium = input_sockets['Medium'] + output_params = output_sockets['Structure'] - has_output_params = not ct.FlowSignal.check(output_params) has_center = not ct.FlowSignal.check(center) has_size = not ct.FlowSignal.check(size) has_medium = not ct.FlowSignal.check(medium) + has_output_params = not ct.FlowSignal.check(output_params) if ( has_center and has_size and has_medium and has_output_params - and not props['differentiable'] and not output_params.symbols ): return td.Structure( @@ -138,7 +137,8 @@ class BoxStructureNode(base.MaxwellSimNode): output_sockets={'Structure'}, output_socket_kinds={'Structure': ct.FlowKind.Params}, ) - def compute_lazy_structure(self, props, input_sockets, output_sockets) -> td.Box: + def compute_structure_func(self, props, input_sockets, output_sockets) -> td.Box: + """Compute a possibly-differentiable function, producing a box structure from the input parameters.""" output_params = output_sockets['Structure'] center = input_sockets['Center'] size = input_sockets['Size'] @@ -149,14 +149,8 @@ class BoxStructureNode(base.MaxwellSimNode): has_size = not ct.FlowSignal.check(size) has_medium = not ct.FlowSignal.check(medium) - differentiable = props['differentiable'] - if ( - has_output_params - and has_center - and has_size - and has_medium - and differentiable == output_params.is_differentiable - ): + if has_output_params and has_center and has_size and has_medium: + differentiable = props['differentiable'] if differentiable: return (center | size | medium).compose_within( enclosing_func=lambda els: tdadj.JaxStructure( @@ -169,6 +163,12 @@ class BoxStructureNode(base.MaxwellSimNode): supports_jax=True, ) return (center | size | medium).compose_within( + ## TODO: Unit conversion within the composed function?? + ## -- We do need Tidy3D to be given ex. micrometers in particular. + ## -- But the previous numerical output might not be micrometers. + ## -- There must be a way to add a conversion in, without strangeness. + ## -- Ex. can compose_within() take a unit system? + ## -- This would require enclosing_func=lambda els: td.Structure( geometry=td.Box( center=tuple(els[0].flatten()), @@ -205,13 +205,7 @@ class BoxStructureNode(base.MaxwellSimNode): has_medium = not ct.FlowSignal.check(medium) if has_center and has_size and has_medium: - if props['differentiable'] == ( - center.is_differentiable - and size.is_differentiable - and medium.is_differentiable - ): - return center | size | medium - return ct.FlowSignal.FlowPending + return center | size | medium return ct.FlowSignal.FlowPending #################### @@ -226,6 +220,7 @@ class BoxStructureNode(base.MaxwellSimNode): output_socket_kinds={'Structure': ct.FlowKind.Params}, ) def compute_previews(self, props, output_sockets): + """Mark the managed preview object when recursively linked to a viewer.""" output_params = output_sockets['Structure'] has_output_params = not ct.FlowSignal.check(output_params) @@ -245,10 +240,14 @@ class BoxStructureNode(base.MaxwellSimNode): ) def on_inputs_changed(self, managed_objs, input_sockets, output_sockets): output_params = output_sockets['Structure'] + center = input_sockets['Center'] + has_output_params = not ct.FlowSignal.check(output_params) - if has_output_params and not output_params.symbols: + has_center = not ct.FlowSignal.check(center) + if has_center and has_output_params and not output_params.symbols: + ## TODO: There are strategies for handling examples of symbol values. + # Push Loose Input Values to GeoNodes Modifier - center = input_sockets['Center'] managed_objs['modifier'].bl_modifier( 'NODES', { diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py index 2cb3138..ea66346 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py @@ -43,17 +43,28 @@ class SocketDef(pyd.BaseModel, abc.ABC): """ socket_type: ct.SocketType + active_kind: typ.Literal[ + ct.FlowKind.Value, + ct.FlowKind.Array, + ct.FlowKind.Range, + ct.FlowKind.Func, + ] = ct.FlowKind.Value + #################### + # - Socket Interaction + #################### def preinit(self, bl_socket: bpy.types.NodeSocket) -> None: """Pre-initialize a real Blender node socket from this socket definition. Parameters: bl_socket: The Blender node socket to alter using data from this SocketDef. """ - log.debug('%s: Start Socket Preinit', bl_socket.bl_label) + # log.debug('%s: Start Socket Preinit', bl_socket.bl_label) bl_socket.reset_instance_id() bl_socket.regenerate_dynamic_field_persistance() - log.debug('%s: End Socket Preinit', bl_socket.bl_label) + + bl_socket.active_kind = self.active_kind + # log.debug('%s: End Socket Preinit', bl_socket.bl_label) def postinit(self, bl_socket: bpy.types.NodeSocket) -> None: """Pre-initialize a real Blender node socket from this socket definition. @@ -61,12 +72,12 @@ class SocketDef(pyd.BaseModel, abc.ABC): Parameters: bl_socket: The Blender node socket to alter using data from this SocketDef. """ - log.debug('%s: Start Socket Postinit', bl_socket.bl_label) + # log.debug('%s: Start Socket Postinit', bl_socket.bl_label) bl_socket.is_initializing = False bl_socket.on_active_kind_changed() bl_socket.on_socket_props_changed(set(bl_socket.blfields)) bl_socket.on_data_changed(set(ct.FlowKind)) - log.debug('%s: End Socket Postinit', bl_socket.bl_label) + # log.debug('%s: End Socket Postinit', bl_socket.bl_label) @abc.abstractmethod def init(self, bl_socket: bpy.types.NodeSocket) -> None: @@ -76,6 +87,43 @@ class SocketDef(pyd.BaseModel, abc.ABC): bl_socket: The Blender node socket to alter using data from this SocketDef. """ + #################### + # - Comparison + #################### + def compare(self, bl_socket: bpy.types.NodeSocket) -> bool: + """Whether this `SocketDef` can be considered to uniquely define the given `bl_socket`. + + The general criteria for "uniquely defines" is whether **the same `bl_socket`** could be created using this `SocketDef`. + The extent to which user-altered properties are considered in this regard is a matter of taste, encapsulated entirely within `self.local_compare()`. + + Notes: + Used when determining whether to replace sockets with newer variants when synchronizing changes. + + **NOTE**: Removing/replacing loose input sockets + + Parameters: + bl_socket: The Blender node socket to alter using data from this SocketDef. + """ + return ( + bl_socket.socket_type is self.socket_type + and bl_socket.active_kind is self.active_kind + and self.local_compare(bl_socket) + ) + + def local_compare(self, bl_socket: bpy.types.NodeSocket) -> None: + """Compare this `SocketDef` to an established `bl_socket` in a manner specific to the node. + + Notes: + Run by `self.compare()`. + Optionally overriden by individual sockets. + + When not overridden, it will always return `False`, indicating that the socket is _never_ uniquely defined by this `SocketDef`. + + Parameters: + bl_socket: The Blender node socket to alter using data from this SocketDef. + """ + return False + #################### # - Serialization #################### @@ -426,8 +474,34 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance): Parameters: socket_kinds: The altered `ct.FlowKind`s flowing through. """ + # Run Socket Callbacks self.on_socket_data_changed(socket_kinds) + # Mark Active FlowKind Links as Invalid + ## -> Mark link as invalid (very red) if a FlowSignal is traveling. + ## -> This helps explain why whatever isn't working isn't working. + ## -> TODO: We need a different approach. + # log.debug( + # '[%s] Checking FlowKind Validity (socket_kinds=%s)', + # self.name, + # str(socket_kinds), + # ) + # if self.is_linked and not self.is_output: + # link = self.links[0] + # linked_flow = self.compute_data(kind=self.active_kind) + + # if ( + # link.is_valid + # and self.active_kind in socket_kinds + # and ct.FlowSignal.check_single(linked_flow, ct.FlowSignal.FlowPending) + # ): + # node_tree = self.id_data + # node_tree.report_link_validity(link, False) + + # elif not link.is_valid: + # node_tree = self.id_data + # node_tree.report_link_validity(link, True) + def on_socket_data_changed(self, socket_kinds: set[ct.FlowKind]) -> None: """Called when `ct.FlowEvent.DataChanged` flows through this socket. @@ -479,7 +553,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance): The value of `ct.FlowEvent.flow_direction[event]` (`input` or `output`) determines the direction that an event flows. """ # log.debug( - # '[%s] [%s] Triggered (socket_kinds=%s)', + # '[%s] [%s] Socket Triggered (socket_kinds=%s)', # self.name, # event, # str(socket_kinds), @@ -757,7 +831,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance): linked_values = [link.from_socket.compute_data(kind) for link in self.links] # Return Single Value / List of Values - ## -> Multi-input sockets are not yet supported. + ## -> Multi-input sockets are not (yet) supported. if linked_values: return linked_values[0] @@ -891,10 +965,14 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance): # FlowKind Draw Row col = row.column(align=True) { + ct.FlowKind.Capabilities: lambda *_: None, + ct.FlowKind.Previews: lambda *_: None, ct.FlowKind.Value: self.draw_value, ct.FlowKind.Array: self.draw_array, ct.FlowKind.Range: self.draw_lazy_range, ct.FlowKind.Func: self.draw_lazy_func, + ct.FlowKind.Params: lambda *_: None, + ct.FlowKind.Info: lambda *_: None, }[self.active_kind](col) # Info Drawing diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/bool.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/bool.py index d2c4f4b..df108f9 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/bool.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/bool.py @@ -51,6 +51,16 @@ class BoolBLSocket(base.MaxwellSimSocket): def value(self, value: bool) -> None: self.raw_value = value + @bl_cache.cached_bl_property(depends_on={'value'}) + def lazy_func(self) -> ct.FuncFlow: + return ct.FuncFlow( + func=lambda: self.value, + ) + + @bl_cache.cached_bl_property() + def params(self) -> ct.FuncFlow: + return ct.ParamsFlow() + #################### # - Socket Configuration diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py index cf7d47e..1cdc9d7 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py @@ -130,6 +130,7 @@ class ExprBLSocket(base.MaxwellSimSocket): 'physical_type', 'unit', 'size', + 'value', } ) def output_sym(self) -> sim_symbols.SimSymbol | None: @@ -140,13 +141,29 @@ class ExprBLSocket(base.MaxwellSimSocket): Raises: NotImplementedError: When `active_kind` is neither `Value`, `Func`, or `Range`. """ - if self.symbols: - if self.active_kind in [ct.FlowKind.Value, ct.FlowKind.Func]: + match self.active_kind: + case ct.FlowKind.Value | ct.FlowKind.Func if self.symbols: return self._parse_expr_symbol( self._parse_expr_str(self.raw_value_spstr) ) - if self.active_kind is ct.FlowKind.Range: + case ct.FlowKind.Value | ct.FlowKind.Func if not self.symbols: + return sim_symbols.SimSymbol( + sym_name=self.output_name, + mathtype=self.mathtype, + physical_type=self.physical_type, + unit=self.unit, + rows=self.size.rows, + cols=self.size.cols, + exclude_zero=( + not self.value.is_zero + if self.value.is_zero is not None + else False + ), + ## TODO: Does this work for matrix elements? + ) + + case ct.FlowKind.Range if self.symbols: ## TODO: Support RangeFlow ## -- It's hard; we need a min-span set over bound domains. ## -- We... Don't use this anywhere. Yet? @@ -159,20 +176,37 @@ class ExprBLSocket(base.MaxwellSimSocket): msg = 'RangeFlow support not yet implemented for when self.symbols is not empty' raise NotImplementedError(msg) - raise NotImplementedError + case ct.FlowKind.Range if not self.symbols: + return sim_symbols.SimSymbol( + sym_name=self.output_name, + mathtype=self.mathtype, + physical_type=self.physical_type, + unit=self.unit, + rows=self.lazy_range.steps, + cols=1, + exclude_zero=not self.lazy_range.is_always_nonzero, + ) - return sim_symbols.SimSymbol( - sym_name=self.output_name, - mathtype=self.mathtype, - physical_type=self.physical_type, - unit=self.unit, - rows=self.size.rows, - cols=self.size.cols, - ) + #################### + # - Value|Range Swapper + #################### + use_value_range_swapper: bool = bl_cache.BLField(False) + selected_value_range: ct.FlowKind = bl_cache.BLField( + enum_cb=lambda self, _: self._value_or_range(), + ) + + def _value_or_range(self): + return [ + flow_kind.bl_enum_element(i) + for i, flow_kind in enumerate([ct.FlowKind.Value, ct.FlowKind.Range]) + ] #################### # - Symbols #################### + lazy_range_name: sim_symbols.SimSymbolName = bl_cache.BLField( + sim_symbols.SimSymbolName.Expr + ) output_name: sim_symbols.SimSymbolName = bl_cache.BLField( sim_symbols.SimSymbolName.Expr ) @@ -343,7 +377,7 @@ class ExprBLSocket(base.MaxwellSimSocket): See `MaxwellSimTree` for more detail on the link callbacks. """ - ## NODE: Depends on suppressed on_prop_changed + ## NOTE: Depends on suppressed on_prop_changed if ct.FlowKind.Info in socket_kinds: info = self.compute_data(kind=ct.FlowKind.Info) @@ -371,7 +405,10 @@ class ExprBLSocket(base.MaxwellSimSocket): See `MaxwellSimTree` for more detail on the link callbacks. """ - ## NODE: Depends on suppressed on_prop_changed + ## NOTE: Depends on suppressed on_prop_changed + if ('selected_value_range', 'invalidate') in cleared_blfields: + self.active_kind = self.selected_value_range + self.on_active_kind_changed() # Conditional Unit-Conversion ## -> This is niche functionality, but the only way to convert units. @@ -757,7 +794,6 @@ class ExprBLSocket(base.MaxwellSimSocket): @bl_cache.cached_bl_property( depends_on={ 'value', - 'symbols', 'sorted_sp_symbols', 'sorted_symbols', 'output_sym', @@ -769,82 +805,87 @@ class ExprBLSocket(base.MaxwellSimSocket): If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be the arguments of the `FuncFlow`. Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`. """ - # Symbolic - ## -> `self.value` is guaranteed to be an expression with unknowns. - ## -> The function computes `self.value` with unknowns as arguments. - if self.symbols: - value = self.value - has_value = not ct.FlowSignal.check(value) + if self.output_sym is not None: + match self.active_kind: + case ct.FlowKind.Value | ct.FlowKind.Func if ( + self.sorted_symbols and not ct.FlowSignal.check(self.value) + ): + return ct.FuncFlow( + func=sp.lambdify( + self.sorted_sp_symbols, + self.output_sym.conform(self.value, strip_unit=True), + 'jax', + ), + func_args=list(self.sorted_symbols), + func_output=self.output_sym, + supports_jax=True, + ) - output_sym = self.output_sym - if output_sym is not None and has_value: - return ct.FuncFlow( - func=sp.lambdify( - self.sorted_sp_symbols, - output_sym.conform(value, strip_unit=True), - 'jax', - ), - func_args=list(self.sorted_symbols), - supports_jax=True, - ) - return ct.FlowSignal.FlowPending + case ct.FlowKind.Value | ct.FlowKind.Func if not self.sorted_symbols: + return ct.FuncFlow( + func=lambda v: v, + func_args=[self.output_sym], + func_output=self.output_sym, + supports_jax=True, + ) - # Constant - ## -> When a `self.value` has no unknowns, use a dummy function. - ## -> ("Dummy" as in returns the same argument that it takes). - ## -> This is an excuse to let `ParamsFlow` pass `self.value` verbatim. - ## -> Generally only useful for operations with other expressions. - return ct.FuncFlow( - func=lambda v: v, - func_args=[self.output_sym], - supports_jax=True, - ) + case ct.FlowKind.Range if self.sorted_symbols: + msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty' + raise NotImplementedError(msg) - @bl_cache.cached_bl_property(depends_on={'sorted_symbols'}) - def is_differentiable(self) -> bool: - """Whether all symbols are differentiable. + case ct.FlowKind.Range if ( + not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range) + ): + return ct.FuncFlow( + func=lambda v: v, + func_args=[self.output_sym], + func_output=self.output_sym, + supports_jax=True, + ) - If there are no symbols, then there is nothing to differentiate, and thus the expression is differentiable. - """ - if not self.sorted_symbols: - return True + return ct.FlowSignal.FlowPending - return all( - sym.mathtype in [spux.MathType.Real, spux.MathType.Complex] - for sym in self.sorted_symbols - ) - - @bl_cache.cached_bl_property(depends_on={'sorted_symbols', 'output_sym', 'value'}) + @bl_cache.cached_bl_property( + depends_on={'sorted_symbols', 'output_sym', 'value', 'lazy_range'} + ) def params(self) -> ct.ParamsFlow: """Returns parameter symbols/values to accompany `self.lazy_func`. If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use). Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`. """ - # Symbolic - ## -> The Expr socket does not declare actual values for the symbols. - ## -> They should be realized later, ex. in a Viz node. - ## -> Therefore, we just dump the symbols. Easy! - ## -> NOTE: func_args must have the same symbol order as was lambdified. - if self.sorted_symbols: - output_sym = self.output_sym - if output_sym is not None: - return ct.ParamsFlow( - arg_targets=list(self.sorted_symbols), - func_args=[sym.sp_symbol for sym in self.sorted_symbols], - symbols=self.sorted_symbols, - is_differentiable=self.is_differentiable, - ) - return ct.FlowSignal.FlowPending + output_sym = self.output_sym + if output_sym is not None: + match self.active_kind: + case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols: + return ct.ParamsFlow( + arg_targets=list(self.sorted_symbols), + func_args=[sym.sp_symbol for sym in self.sorted_symbols], + symbols=set(self.sorted_symbols), + ) - # Constant - ## -> Simply pass self.value verbatim as a function argument. - ## -> Easy dice, easy life! - return ct.ParamsFlow( - arg_targets=[self.output_sym], - func_args=[self.value], - is_differentiable=self.is_differentiable, - ) + case ct.FlowKind.Value | ct.FlowKind.Func if ( + not self.sorted_symbols and not ct.FlowSignal.check(self.value) + ): + return ct.ParamsFlow( + arg_targets=[self.output_sym], + func_args=[self.value], + ) + + case ct.FlowKind.Range if self.sorted_symbols: + msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty' + raise NotImplementedError(msg) + + case ct.FlowKind.Range if ( + not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range) + ): + return ct.ParamsFlow( + arg_targets=[self.output_sym], + func_args=[self.output_sym.sp_symbol_matsym], + symbols={self.output_sym}, + ).realize_partial({self.output_sym: self.lazy_range}) + + return ct.FlowSignal.FlowPending @bl_cache.cached_bl_property(depends_on={'sorted_symbols', 'output_sym'}) def info(self) -> ct.InfoFlow: @@ -858,21 +899,33 @@ class ExprBLSocket(base.MaxwellSimSocket): Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along. """ - # Constant - ## -> The input SimSymbols become continuous dimensional indices. - ## -> All domain validity information is defined on the SimSymbol keys. - if self.sorted_symbols: - output_sym = self.output_sym - if output_sym is not None: - return ct.InfoFlow( - dims={sym: None for sym in self.sorted_symbols}, - output=self.output_sym, - ) - return ct.FlowSignal.FlowPending + output_sym = self.output_sym + if output_sym is not None: + match self.active_kind: + case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols: + return ct.InfoFlow( + dims={sym: None for sym in self.sorted_symbols}, + output=self.output_sym, + ) - # Constant - ## -> We only need the output symbol to describe the raw data. - return ct.InfoFlow(output=self.output_sym) + case ct.FlowKind.Value | ct.FlowKind.Func if ( + not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range) + ): + return ct.InfoFlow(output=self.output_sym) + + case ct.FlowKind.Range if self.sorted_symbols: + msg = 'InfoFlow support not yet implemented for when self.sorted_symbols is not empty' + raise NotImplementedError(msg) + + case ct.FlowKind.Range if ( + not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range) + ): + return ct.InfoFlow( + dims={self.output_sym: self.lazy_range}, + output=self.output_sym.update(rows=1), + ) + + return ct.FlowSignal.FlowPending #################### # - FlowKind: Capabilities @@ -1039,6 +1092,9 @@ class ExprBLSocket(base.MaxwellSimSocket): However, `draw_value` may also be called by the `draw_*` methods of other `FlowKinds`, who may choose to layer more flexibility around this base UI. """ + if self.use_value_range_swapper: + col.prop(self, self.blfields['selected_value_range'], text='') + if self.symbols: col.prop(self, self.blfields['raw_value_spstr'], text='') @@ -1097,6 +1153,9 @@ class ExprBLSocket(base.MaxwellSimSocket): If `self.steps == 0`, then the `Range` is considered to have a to-be-determined number of steps. As such, `self.steps` won't be exposed in the UI. """ + if self.use_value_range_swapper: + col.prop(self, self.blfields['selected_value_range'], text='') + if self.symbols: col.prop(self, self.blfields['raw_min_spstr'], text='') col.prop(self, self.blfields['raw_max_spstr'], text='') @@ -1198,13 +1257,11 @@ class ExprBLSocket(base.MaxwellSimSocket): # - Socket Configuration #################### class ExprSocketDef(base.SocketDef): + """Interface for defining an `ExprSocket`.""" + socket_type: ct.SocketType = ct.SocketType.Expr - active_kind: typ.Literal[ - ct.FlowKind.Value, - ct.FlowKind.Range, - ct.FlowKind.Func, - ] = ct.FlowKind.Value output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName.Expr + use_value_range_swapper: bool = False # Socket Interface size: spux.NumberSize1D = spux.NumberSize1D.Scalar @@ -1458,7 +1515,7 @@ class ExprSocketDef(base.SocketDef): # Check ActiveKind and Size ## -> NOTE: This doesn't protect against dynamic changes to either. if ( - self.active_kind == ct.FlowKind.Range + self.active_kind is ct.FlowKind.Range and self.size is not spux.NumberSize1D.Scalar ): msg = "Can't have a non-Scalar size when Range is set as the active kind." @@ -1504,9 +1561,9 @@ class ExprSocketDef(base.SocketDef): # - Initialization #################### def init(self, bl_socket: ExprBLSocket) -> None: - bl_socket.active_kind = self.active_kind bl_socket.output_name = self.output_name bl_socket.use_linked_capabilities = True + bl_socket.use_value_range_swapper = self.use_value_range_swapper # Socket Interface ## -> Recall that auto-updates are turned off during init() @@ -1543,6 +1600,25 @@ class ExprSocketDef(base.SocketDef): # Info Draw bl_socket.use_info_draw = True + def local_compare(self, bl_socket: ExprBLSocket) -> None: + """Determine whether an updateable socket should be re-initialized from this `SocketDef`.""" + + def cmp(attr: str): + return getattr(bl_socket, attr) == getattr(self, attr) + + return ( + bl_socket.use_linked_capabilities + and cmp('output_name') + and cmp('use_value_range_swapper') + and cmp('size') + and cmp('mathtype') + and cmp('physical_type') + and cmp('show_func_ui') + and cmp('show_info_columns') + and cmp('show_name_selector') + and bl_socket.use_info_draw + ) + #################### # - Blender Registration diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/bound_conds.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/bound_conds.py index ea69736..97c434a 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/bound_conds.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/bound_conds.py @@ -117,6 +117,16 @@ class MaxwellBoundCondsBLSocket(base.MaxwellSimSocket): ), ) + @bl_cache.cached_bl_property(depends_on={'value'}) + def lazy_func(self) -> ct.FuncFlow: + return ct.FuncFlow( + func=lambda: self.value, + ) + + @bl_cache.cached_bl_property() + def params(self) -> ct.FuncFlow: + return ct.ParamsFlow() + #################### # - Socket Configuration diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/fdtd_sim.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/fdtd_sim.py index ff59720..d4d0801 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/fdtd_sim.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/fdtd_sim.py @@ -14,6 +14,8 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import typing as typ + from ... import contracts as ct from .. import base @@ -32,6 +34,9 @@ class MaxwellFDTDSimSocketDef(base.SocketDef): def init(self, bl_socket: MaxwellFDTDSimBLSocket) -> None: pass + def local_compare(self, _: MaxwellFDTDSimBLSocket) -> None: + return True + #################### # - Blender Registration diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py index 4b92a24..521605d 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py @@ -73,7 +73,7 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket): def value(self, eps_rel: tuple[float, float]) -> None: self.eps_rel = eps_rel - @bl_cache.cached_bl_property(depends_on={'value', 'differentiable'}) + @bl_cache.cached_bl_property(depends_on={'value'}) def lazy_func(self) -> ct.FuncFlow: return ct.FuncFlow( func=lambda: self.value, @@ -82,7 +82,7 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket): @bl_cache.cached_bl_property(depends_on={'differentiable'}) def params(self) -> ct.FuncFlow: - return ct.ParamsFlow(is_differentiable=self.differentiable) + return ct.ParamsFlow() #################### # - UI diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/monitor.py index 4e2754d..255e86c 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/monitor.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/monitor.py @@ -29,11 +29,11 @@ class MaxwellMonitorBLSocket(base.MaxwellSimSocket): class MaxwellMonitorSocketDef(base.SocketDef): socket_type: ct.SocketType = ct.SocketType.MaxwellMonitor - is_list: bool = False - def init(self, bl_socket: MaxwellMonitorBLSocket) -> None: - if self.is_list: - bl_socket.active_kind = ct.FlowKind.Array + pass + + def local_compare(self, _: MaxwellMonitorBLSocket) -> None: + return True #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/sim_grid.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/sim_grid.py index 80e4635..7d34081 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/sim_grid.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/sim_grid.py @@ -55,6 +55,16 @@ class MaxwellSimGridBLSocket(base.MaxwellSimSocket): min_steps_per_wvl=self.min_steps_per_wl, ) + @bl_cache.cached_bl_property(depends_on={'value'}) + def lazy_func(self) -> ct.FuncFlow: + return ct.FuncFlow( + func=lambda: self.value, + ) + + @bl_cache.cached_bl_property() + def params(self) -> ct.FuncFlow: + return ct.ParamsFlow() + #################### # - Socket Configuration diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/source.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/source.py index d13e7de..b7f4c7c 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/source.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/source.py @@ -29,11 +29,11 @@ class MaxwellSourceBLSocket(base.MaxwellSimSocket): class MaxwellSourceSocketDef(base.SocketDef): socket_type: ct.SocketType = ct.SocketType.MaxwellSource - is_list: bool = False - def init(self, bl_socket: MaxwellSourceBLSocket) -> None: - if self.is_list: - bl_socket.active_kind = ct.FlowKind.Array + pass + + def local_compare(self, _: MaxwellSourceBLSocket) -> None: + return True #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/structure.py index 7b3ae06..29cc1ad 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/structure.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/structure.py @@ -29,11 +29,11 @@ class MaxwellStructureBLSocket(base.MaxwellSimSocket): class MaxwellStructureSocketDef(base.SocketDef): socket_type: ct.SocketType = ct.SocketType.MaxwellStructure - is_list: bool = False - def init(self, bl_socket: MaxwellStructureBLSocket) -> None: - if self.is_list: - bl_socket.active_kind = ct.FlowKind.Array + pass + + def local_compare(self, _: MaxwellStructureBLSocket) -> None: + return True #################### diff --git a/src/blender_maxwell/utils/bl_cache/__init__.py b/src/blender_maxwell/utils/bl_cache/__init__.py index fa81d92..0a8120a 100644 --- a/src/blender_maxwell/utils/bl_cache/__init__.py +++ b/src/blender_maxwell/utils/bl_cache/__init__.py @@ -16,10 +16,10 @@ """Package providing various tools to handle cached data on Blender objects, especially nodes and node socket classes.""" +from ..keyed_cache import KeyedCache, keyed_cache from .bl_field import BLField from .bl_prop import BLProp, BLPropType from .cached_bl_property import CachedBLProperty, cached_bl_property -from .keyed_cache import KeyedCache, keyed_cache from .managed_cache import invalidate_nonpersist_instance_id from .signal import Signal diff --git a/src/blender_maxwell/utils/bl_instance.py b/src/blender_maxwell/utils/bl_instance.py index f17c12a..1a0b1b6 100644 --- a/src/blender_maxwell/utils/bl_instance.py +++ b/src/blender_maxwell/utils/bl_instance.py @@ -21,6 +21,7 @@ from types import MappingProxyType import bpy from blender_maxwell.utils import bl_cache, logger +from blender_maxwell.utils.keyed_cache import keyed_cache InstanceID: typ.TypeAlias = str ## Stringified UUID4 @@ -220,11 +221,14 @@ class BLInstance: for str_search_prop_name in self.blfields_str_search: setattr(self, str_search_prop_name, bl_cache.Signal.ResetStrSearch) + @keyed_cache( + exclude={'self'}, ## No dynamic elements of 'self' can be used. + ) def trace_blfields_to_clear( self, prop_name: str, - prev_blfields_to_clear: list[ - tuple[str, typ.Literal['invalidate', 'reset_enum', 'reset_strsearch']] + prev_blfields_to_clear: tuple[ + tuple[str, typ.Literal['invalidate', 'reset_enum', 'reset_strsearch']], ... ] = (), ) -> list[str]: """Invalidates all properties that depend on `prop_name`. @@ -239,7 +243,7 @@ class BLInstance: All of these are filled when creating the `BLInstance` subclass, using `self.declare_blfield_dep()`, generally via the `BLField` descriptor (which internally uses `BLProp`). """ if prev_blfields_to_clear: - blfields_to_clear = prev_blfields_to_clear.copy() + blfields_to_clear = list(prev_blfields_to_clear) else: blfields_to_clear = [] @@ -268,7 +272,7 @@ class BLInstance: if dst_prop_name in self.blfields: blfields_to_clear += self.trace_blfields_to_clear( dst_prop_name, - prev_blfields_to_clear=blfields_to_clear, + prev_blfields_to_clear=tuple(blfields_to_clear), ) match (bool(prev_blfields_to_clear), bool(blfields_to_clear)): @@ -297,7 +301,7 @@ class BLInstance: ## -> As such, deduplication would not be wrong, just extraneous. ## -> Since invalidation is in a hot-loop, don't do such things. case (True, True): - return blfields_to_clear + return list(reversed(dict.fromkeys(reversed(blfields_to_clear)))) def clear_blfields_after(self, prop_name: str) -> list[str]: """Clear (invalidate) all `BLField`s that have become invalid as a result of a change to `prop_name`. diff --git a/src/blender_maxwell/utils/image_ops.py b/src/blender_maxwell/utils/image_ops.py index 205cb71..6ca4412 100644 --- a/src/blender_maxwell/utils/image_ops.py +++ b/src/blender_maxwell/utils/image_ops.py @@ -17,6 +17,7 @@ """Useful image processing operations for use in the addon.""" import enum +import functools import typing as typ import jax @@ -26,7 +27,6 @@ import matplotlib import matplotlib.axis as mpl_ax import matplotlib.backends.backend_agg import matplotlib.figure -import numpy as np import seaborn as sns from blender_maxwell import contracts as ct @@ -138,7 +138,7 @@ def rgba_image_from_2d_map( #################### # - MPL Helpers #################### -# @functools.lru_cache(maxsize=16) +@functools.lru_cache(maxsize=4) def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int): fig = matplotlib.figure.Figure( figsize=[width_inches, height_inches], dpi=dpi, layout='tight' diff --git a/src/blender_maxwell/utils/bl_cache/keyed_cache.py b/src/blender_maxwell/utils/keyed_cache.py similarity index 95% rename from src/blender_maxwell/utils/bl_cache/keyed_cache.py rename to src/blender_maxwell/utils/keyed_cache.py index d1b8758..bd9c3db 100644 --- a/src/blender_maxwell/utils/bl_cache/keyed_cache.py +++ b/src/blender_maxwell/utils/keyed_cache.py @@ -18,11 +18,15 @@ import functools import inspect import typing as typ -from blender_maxwell.utils import bl_instance, logger, serialize +from blender_maxwell.utils import logger, serialize log = logger.get(__name__) +class BLInstance(typ.Protocol): + instance_id: str + + class KeyedCache: def __init__( self, @@ -75,8 +79,8 @@ class KeyedCache: def __get__( self, - bl_instance: bl_instance.BLInstance | None, - owner: type[bl_instance.BLInstance], + bl_instance: BLInstance | None, + owner: type[BLInstance], ) -> typ.Callable: _func = functools.partial(self, bl_instance) _func.invalidate = functools.partial( @@ -110,7 +114,7 @@ class KeyedCache: def invalidate( self, - bl_instance: bl_instance.BLInstance | None, + bl_instance: BLInstance | None, **arguments: dict[str, typ.Any], ) -> dict[str, typ.Any]: # Determine Wildcard Arguments diff --git a/src/blender_maxwell/utils/sim_symbols.py b/src/blender_maxwell/utils/sim_symbols.py index 3be4750..1efaec3 100644 --- a/src/blender_maxwell/utils/sim_symbols.py +++ b/src/blender_maxwell/utils/sim_symbols.py @@ -264,13 +264,16 @@ class SimSymbol(pyd.BaseModel): interval_closed_im: tuple[bool, bool] = (False, False) #################### - # - Labels + # - Core #################### @functools.cached_property def name(self) -> str: """Usable name for the symbol.""" return self.sym_name.name + #################### + # - Labels + #################### @functools.cached_property def name_pretty(self) -> str: """Pretty (possibly unicode) name for the thing.""" @@ -307,6 +310,8 @@ class SimSymbol(pyd.BaseModel): @functools.cached_property def plot_label(self) -> str: """Pretty plot-oriented label.""" + if self.unit is None: + return self.name_pretty return f'{self.name_pretty} ({self.unit_label})' #################### @@ -420,6 +425,11 @@ class SimSymbol(pyd.BaseModel): @functools.cached_property def is_nonzero(self) -> bool: + """Whether or not the value of this symbol can ever be $0$. + + Notes: + Most notably, this symbol cannot be used as the right hand side of a division operation when this property is `False`. + """ if self.exclude_zero: return True @@ -441,6 +451,18 @@ class SimSymbol(pyd.BaseModel): ) return check_real_domain(self.domain) + @functools.cached_property + def can_diff(self) -> bool: + """Whether this symbol can be used as the input / output variable when differentiating.""" + # Check Constants + ## -> Constants (w/pinned values) are never differentiable. + if self.is_constant: + return False + + # TODO: Discontinuities (especially across 0)? + + return self.mathtype in [spux.MathType.Real, spux.MathType.Complex] + #################### # - Properties #################### @@ -664,8 +686,10 @@ class SimSymbol(pyd.BaseModel): res = spux.strip_unit_system(sp_obj) # Broadcast Expansion - if self.rows > 1 or self.cols > 1 and not isinstance(res, spux.MatrixBase): - res = sp_obj * sp.ImmutableMatrix.ones(self.rows, self.cols) + if (self.rows > 1 or self.cols > 1) and not isinstance( + res, sp.MatrixBase | sp.MatrixSymbol + ): + res = res * sp.ImmutableMatrix.ones(self.rows, self.cols) return res @@ -753,7 +777,9 @@ class SimSymbol(pyd.BaseModel): unit = None # Rows/Cols from Expr (if Matrix) - rows, cols = expr.shape if isinstance(expr, sp.MatrixBase) else (1, 1) + rows, cols = ( + expr.shape if isinstance(expr, sp.MatrixBase | sp.MatrixSymbol) else (1, 1) + ) return SimSymbol( sym_name=sym_name,