fix: operate node and "remembered" math ops
We reimplemented the OperateMathNode entirely, and it should now be relatively to entirely robust. It works based on an enum, just like `MapMathNode`, which enormously simplifies the node itself. Note that we're not quite done with sympy+jax implementations of everything, nor validating that all operations have an implementation valid for all the `InfoFlow`s that prompted it, but, nonetheless, such fixes should be easy to work out as we go. We also finally managed to implement "remembering". In essence, "remembering" causes `FlowPending` to not reset the dynamic enums in math / extract nodes. The implementation is generally per-node, and a bit of boilerplate, but it seems quite robust on the whole - allowing one to "keep a node tree around" even after removing simulation data. Extract of sim data is of course more leniant, as it also "remembers" when there is `NoFlow`. Only new `Sim Data` will reset the extraction filter enum. Caches are kept (not memory efficient, but convenient for me), but only `FlowPending` is actually passed along the `InfoFlow`. This is the semantics we want - having to deal with this all is a tradeoff of having data-driven enums, but all in all, it still feels like the right approach.main
parent
035d8971f3
commit
532f0246b5
|
@ -80,7 +80,6 @@ class FlowKind(enum.StrEnum):
|
|||
unit_system,
|
||||
)
|
||||
if kind == FlowKind.LazyArrayRange:
|
||||
log.debug([kind, flow_obj, unit_system])
|
||||
return flow_obj.rescale_to_unit_system(unit_system)
|
||||
|
||||
if kind == FlowKind.Params:
|
||||
|
|
|
@ -74,6 +74,12 @@ class InfoFlow:
|
|||
output_mathtype: spux.MathType = dataclasses.field()
|
||||
output_unit: spux.Unit | None = dataclasses.field()
|
||||
|
||||
@property
|
||||
def output_shape_len(self) -> int:
|
||||
if self.output_shape is None:
|
||||
return 0
|
||||
return len(self.output_shape)
|
||||
|
||||
# Pinned Dimension Information
|
||||
## TODO: Add PhysicalType
|
||||
pinned_dim_names: list[str] = dataclasses.field(default_factory=list)
|
||||
|
|
|
@ -477,7 +477,7 @@ def populate_missing_persistence(_) -> None:
|
|||
# - Blender Registration
|
||||
####################
|
||||
bpy.app.handlers.load_post.append(initialize_sim_tree_node_link_cache)
|
||||
bpy.app.handlers.load_post.append(populate_missing_persistence)
|
||||
# bpy.app.handlers.load_post.append(populate_missing_persistence)
|
||||
## TODO: Move to top-level registration.
|
||||
|
||||
BL_REGISTER = [
|
||||
|
|
|
@ -73,9 +73,14 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
####################
|
||||
# - Computed: Sim Data
|
||||
####################
|
||||
@events.on_value_changed(socket_name='Sim Data')
|
||||
def on_sim_data_changed(self) -> None: # noqa: D102
|
||||
log.critical('On Value Changed: Sim Data')
|
||||
@events.on_value_changed(
|
||||
socket_name='Sim Data',
|
||||
input_sockets={'Sim Data'},
|
||||
input_sockets_optional={'Sim Data': True},
|
||||
)
|
||||
def on_sim_data_changed(self, input_sockets) -> None: # noqa: D102
|
||||
has_sim_data = not ct.FlowSignal.check(input_sockets['Sim Data'])
|
||||
if has_sim_data:
|
||||
self.sim_data = bl_cache.Signal.InvalidateCache
|
||||
|
||||
@bl_cache.cached_bl_property()
|
||||
|
@ -112,9 +117,14 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
####################
|
||||
# - Computed Properties: Monitor Data
|
||||
####################
|
||||
@events.on_value_changed(socket_name='Monitor Data')
|
||||
def on_monitor_data_changed(self) -> None: # noqa: D102
|
||||
log.critical('On Value Changed: Sim Data')
|
||||
@events.on_value_changed(
|
||||
socket_name='Monitor Data',
|
||||
input_sockets={'Monitor Data'},
|
||||
input_sockets_optional={'Monitor Data': True},
|
||||
)
|
||||
def on_monitor_data_changed(self, input_sockets) -> None: # noqa: D102
|
||||
has_monitor_data = not ct.FlowSignal.check(input_sockets['Monitor Data'])
|
||||
if has_monitor_data:
|
||||
self.monitor_data = bl_cache.Signal.InvalidateCache
|
||||
|
||||
@bl_cache.cached_bl_property()
|
||||
|
@ -319,6 +329,7 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
# Loaded
|
||||
props={'extract_filter'},
|
||||
input_sockets={'Sim Data'},
|
||||
input_sockets_optional={'Sim Data': True},
|
||||
)
|
||||
def compute_monitor_data(
|
||||
self, props: dict, input_sockets: dict
|
||||
|
@ -347,6 +358,7 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
props={'extract_filter'},
|
||||
input_sockets={'Monitor Data'},
|
||||
input_socket_kinds={'Monitor Data': ct.FlowKind.Value},
|
||||
input_sockets_optional={'Monitor Data': True},
|
||||
)
|
||||
def compute_expr(
|
||||
self, props: dict, input_sockets: dict
|
||||
|
@ -376,6 +388,7 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
# Loaded
|
||||
output_sockets={'Expr'},
|
||||
output_socket_kinds={'Expr': ct.FlowKind.Array},
|
||||
output_sockets_optional={'Expr': True},
|
||||
)
|
||||
def compute_extracted_data_lazy(
|
||||
self, output_sockets: dict
|
||||
|
@ -436,9 +449,18 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
|
||||
has_monitor_data = not ct.FlowSignal.check(monitor_data)
|
||||
|
||||
# Edge Case: Dangling 'flux' Access on 'FieldMonitor'
|
||||
## -> Sometimes works - UNLESS the FieldMonitor doesn't have all fields.
|
||||
## -> We don't allow 'flux' attribute access, but it can dangle.
|
||||
## -> (The method is called when updating each depschain component.)
|
||||
if monitor_data_type == 'Field' and extract_filter == 'flux':
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
# Retrieve XArray
|
||||
if has_monitor_data and extract_filter is not None:
|
||||
xarr = getattr(monitor_data, extract_filter)
|
||||
xarr = getattr(monitor_data, extract_filter, None)
|
||||
if xarr is None:
|
||||
return ct.FlowSignal.FlowPending
|
||||
else:
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
|
|
@ -99,6 +99,9 @@ class MapOperation(enum.StrEnum):
|
|||
Chol = enum.auto()
|
||||
Svd = enum.auto()
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
@staticmethod
|
||||
def to_name(value: typ.Self) -> str:
|
||||
MO = MapOperation
|
||||
|
@ -150,6 +153,9 @@ class MapOperation(enum.StrEnum):
|
|||
i,
|
||||
)
|
||||
|
||||
####################
|
||||
# - Ops from Shape
|
||||
####################
|
||||
@staticmethod
|
||||
def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]:
|
||||
MO = MapOperation
|
||||
|
@ -198,10 +204,21 @@ class MapOperation(enum.StrEnum):
|
|||
|
||||
return []
|
||||
|
||||
def jax_func(self, user_expr_func: ct.LazyValueFuncFlow | None = None):
|
||||
def jax_func_expr(self, user_expr_func: ct.LazyValueFuncFlow):
|
||||
MO = MapOperation
|
||||
if self == MO.UserExpr and user_expr_func is not None:
|
||||
if self == MO.UserExpr:
|
||||
return lambda data: user_expr_func.func(data)
|
||||
|
||||
msg = "Can't generate JAX function for user-provided expression when MapOperation is not `UserExpr`"
|
||||
raise ValueError(msg)
|
||||
|
||||
@property
|
||||
def jax_func(self):
|
||||
MO = MapOperation
|
||||
if self == MO.UserExpr:
|
||||
msg = "Can't generate JAX function without user-provided expression when MapOperation is `UserExpr`"
|
||||
raise ValueError(msg)
|
||||
|
||||
return {
|
||||
# By Number
|
||||
MO.Real: lambda data: jnp.real(data),
|
||||
|
@ -386,8 +403,6 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
|
||||
return 'noshape'
|
||||
|
||||
output_shape: tuple[int, ...] | None = bl_cache.BLField(None)
|
||||
|
||||
def search_operations(self) -> list[ct.BLEnumElement]:
|
||||
if self.expr_output_shape != 'noshape':
|
||||
return [
|
||||
|
@ -472,12 +487,12 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
if has_expr and operation is not None:
|
||||
if not has_mapper:
|
||||
return expr.compose_within(
|
||||
operation.jax_func(),
|
||||
operation.jax_func,
|
||||
supports_jax=True,
|
||||
)
|
||||
if operation == MapOperation.UserExpr and has_mapper:
|
||||
return expr.compose_within(
|
||||
operation.jax_func(user_expr_func=mapper),
|
||||
operation.jax_func_expr(user_expr_func=mapper),
|
||||
supports_jax=True,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
|
|
@ -29,37 +29,226 @@ from ... import base, events
|
|||
|
||||
log = logger.get(__name__)
|
||||
|
||||
FUNCS = {
|
||||
|
||||
####################
|
||||
# - Operation Enum
|
||||
####################
|
||||
class BinaryOperation(enum.StrEnum):
|
||||
"""Valid operations for the `OperateMathNode`.
|
||||
|
||||
Attributes:
|
||||
Add: Addition w/broadcasting.
|
||||
Sub: Subtraction w/broadcasting.
|
||||
Mul: Hadamard-product multiplication.
|
||||
Div: Hadamard-product based division.
|
||||
Pow: Elementwise expontiation.
|
||||
Atan2: Quadrant-respecting arctangent variant.
|
||||
VecVecDot: Dot product for vectors.
|
||||
Cross: Cross product.
|
||||
MatVecDot: Matrix-Vector dot product.
|
||||
LinSolve: Solve a linear system.
|
||||
LsqSolve: Minimize error of an underdetermined linear system.
|
||||
MatMatDot: Matrix-Matrix dot product.
|
||||
"""
|
||||
|
||||
# Number | Number
|
||||
'ADD': lambda exprs: exprs[0] + exprs[1],
|
||||
'SUB': lambda exprs: exprs[0] - exprs[1],
|
||||
'MUL': lambda exprs: exprs[0] * exprs[1],
|
||||
'DIV': lambda exprs: exprs[0] / exprs[1],
|
||||
'POW': lambda exprs: exprs[0] ** exprs[1],
|
||||
'ATAN2': lambda exprs: sp.atan2(exprs[1], exprs[0]),
|
||||
# Vector | Vector
|
||||
'VEC_VEC_DOT': lambda exprs: exprs[0].dot(exprs[1]),
|
||||
'CROSS': lambda exprs: exprs[0].cross(exprs[1]),
|
||||
}
|
||||
Add = enum.auto()
|
||||
Sub = enum.auto()
|
||||
Mul = enum.auto()
|
||||
Div = enum.auto()
|
||||
Pow = enum.auto()
|
||||
Atan2 = enum.auto()
|
||||
|
||||
SP_FUNCS = FUNCS
|
||||
JAX_FUNCS = FUNCS | {
|
||||
# Number | *
|
||||
'ATAN2': lambda exprs: jnp.atan2(exprs[1], exprs[0]),
|
||||
# Vector | Vector
|
||||
'VEC_VEC_DOT': lambda exprs: jnp.matmul(exprs[0], exprs[1]),
|
||||
'CROSS': lambda exprs: jnp.cross(exprs[0], exprs[1]),
|
||||
VecVecDot = enum.auto()
|
||||
Cross = enum.auto()
|
||||
|
||||
# Matrix | Vector
|
||||
'MAT_VEC_DOT': lambda exprs: jnp.matmul(exprs[0], exprs[1]),
|
||||
'LIN_SOLVE': lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]),
|
||||
'LSQ_SOLVE': lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]),
|
||||
MatVecDot = enum.auto()
|
||||
LinSolve = enum.auto()
|
||||
LsqSolve = enum.auto()
|
||||
|
||||
# Matrix | Matrix
|
||||
'MAT_MAT_DOT': lambda exprs: jnp.matmul(exprs[0], exprs[1]),
|
||||
}
|
||||
MatMatDot = enum.auto()
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
@staticmethod
|
||||
def to_name(value: typ.Self) -> str:
|
||||
BO = BinaryOperation
|
||||
return {
|
||||
# Number | Number
|
||||
BO.Add: 'ℓ + r',
|
||||
BO.Sub: 'ℓ - r',
|
||||
BO.Mul: 'ℓ ⊙ r', ## Notation for Hadamard Product
|
||||
BO.Div: 'ℓ / r',
|
||||
BO.Pow: 'ℓʳ',
|
||||
BO.Atan2: 'atan2(ℓ,r)',
|
||||
# Vector | Vector
|
||||
BO.VecVecDot: '𝐥 · 𝐫',
|
||||
BO.Cross: 'cross(L,R)',
|
||||
# Matrix | Vector
|
||||
BO.MatVecDot: '𝐋 · 𝐫',
|
||||
BO.LinSolve: '𝐋 ∖ 𝐫',
|
||||
BO.LsqSolve: 'argminₓ∥𝐋𝐱−𝐫∥₂',
|
||||
# Matrix | Matrix
|
||||
BO.MatMatDot: '𝐋 · 𝐑',
|
||||
}[value]
|
||||
|
||||
@staticmethod
|
||||
def to_icon(value: typ.Self) -> str:
|
||||
return ''
|
||||
|
||||
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
|
||||
BO = BinaryOperation
|
||||
return (
|
||||
str(self),
|
||||
BO.to_name(self),
|
||||
BO.to_name(self),
|
||||
BO.to_icon(self),
|
||||
i,
|
||||
)
|
||||
|
||||
####################
|
||||
# - Ops from Shape
|
||||
####################
|
||||
@staticmethod
|
||||
def by_infos(info_l: int, info_r: int) -> list[typ.Self]:
|
||||
"""Deduce valid binary operations from the shapes of the inputs."""
|
||||
BO = BinaryOperation
|
||||
|
||||
ops_number_number = [
|
||||
BO.Add,
|
||||
BO.Sub,
|
||||
BO.Mul,
|
||||
BO.Div,
|
||||
BO.Pow,
|
||||
BO.Atan2,
|
||||
]
|
||||
|
||||
match (info_l.output_shape_len, info_r.output_shape_len):
|
||||
# Number | *
|
||||
## Number | Number
|
||||
case (0, 0):
|
||||
return ops_number_number
|
||||
|
||||
## Number | Vector
|
||||
## -> Broadcasting allows Number|Number ops to work as-is.
|
||||
case (0, 1):
|
||||
return ops_number_number
|
||||
|
||||
## Number | Matrix
|
||||
## -> Broadcasting allows Number|Number ops to work as-is.
|
||||
case (0, 2):
|
||||
return ops_number_number
|
||||
|
||||
# Vector | *
|
||||
## Vector | Number
|
||||
case (1, 0):
|
||||
return ops_number_number
|
||||
|
||||
## Vector | Number
|
||||
case (1, 1):
|
||||
return [*ops_number_number, BO.VecVecDot, BO.Cross]
|
||||
|
||||
## Vector | Matrix
|
||||
case (1, 2):
|
||||
return []
|
||||
|
||||
# Matrix | *
|
||||
## Matrix | Number
|
||||
case (2, 0):
|
||||
return [*ops_number_number, BO.MatMatDot]
|
||||
|
||||
## Matrix | Vector
|
||||
case (2, 1):
|
||||
return [BO.MatVecDot, BO.LinSolve, BO.LsqSolve]
|
||||
|
||||
## Matrix | Matrix
|
||||
case (2, 2):
|
||||
return [*ops_number_number, BO.MatMatDot]
|
||||
|
||||
return []
|
||||
|
||||
####################
|
||||
# - Function Properties
|
||||
####################
|
||||
@property
|
||||
def sp_func(self):
|
||||
"""Deduce an appropriate sympy-based function that implements the binary operation for symbolic inputs."""
|
||||
BO = BinaryOperation
|
||||
|
||||
## TODO: Make this compatible with sp.Matrix inputs
|
||||
return {
|
||||
# Number | Number
|
||||
BO.Add: lambda exprs: exprs[0] + exprs[1],
|
||||
BO.Sub: lambda exprs: exprs[0] - exprs[1],
|
||||
BO.Mul: lambda exprs: exprs[0] * exprs[1],
|
||||
BO.Div: lambda exprs: exprs[0] / exprs[1],
|
||||
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
|
||||
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def jax_func(self):
|
||||
"""Deduce an appropriate jax-based function that implements the binary operation for array inputs."""
|
||||
BO = BinaryOperation
|
||||
|
||||
return {
|
||||
# Number | Number
|
||||
BO.Add: lambda exprs: exprs[0] + exprs[1],
|
||||
BO.Sub: lambda exprs: exprs[0] - exprs[1],
|
||||
BO.Mul: lambda exprs: exprs[0] * exprs[1],
|
||||
BO.Div: lambda exprs: exprs[0] / exprs[1],
|
||||
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
|
||||
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
|
||||
# Vector | Vector
|
||||
BO.VecVecDot: lambda exprs: jnp.dot(exprs[0], exprs[1]),
|
||||
BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]),
|
||||
# Matrix | Vector
|
||||
BO.MatVecDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
|
||||
BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]),
|
||||
BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]),
|
||||
# Matrix | Matrix
|
||||
BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
|
||||
}[self]
|
||||
|
||||
####################
|
||||
# - InfoFlow Transform
|
||||
####################
|
||||
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
|
||||
BO = BinaryOperation
|
||||
|
||||
info_largest = (
|
||||
info_l if info_l.output_shape_len > info_l.output_shape_len else info_l
|
||||
)
|
||||
info_any = info_largest
|
||||
return {
|
||||
# Number | * or * | Number
|
||||
BO.Add: info_largest,
|
||||
BO.Sub: info_largest,
|
||||
BO.Mul: info_largest,
|
||||
BO.Div: info_largest,
|
||||
BO.Pow: info_largest,
|
||||
BO.Atan2: info_largest,
|
||||
# Vector | Vector
|
||||
BO.VecVecDot: info_any,
|
||||
BO.Cross: info_any,
|
||||
# Matrix | Vector
|
||||
BO.MatVecDot: info_r,
|
||||
BO.LinSolve: info_r,
|
||||
BO.LsqSolve: info_r,
|
||||
# Matrix | Matrix
|
||||
BO.MatMatDot: info_any,
|
||||
}[self]
|
||||
|
||||
|
||||
####################
|
||||
# - Node
|
||||
####################
|
||||
class OperateMathNode(base.MaxwellSimNode):
|
||||
r"""Applies a function that depends on two inputs.
|
||||
r"""Applies a binary function between two expressions.
|
||||
|
||||
Attributes:
|
||||
category: The category of operations to apply to the inputs.
|
||||
|
@ -76,196 +265,86 @@ class OperateMathNode(base.MaxwellSimNode):
|
|||
'Expr R': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
|
||||
'Expr': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyValueFunc, show_info_columns=True
|
||||
),
|
||||
}
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
category: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_categories()
|
||||
@events.on_value_changed(
|
||||
socket_name={'Expr L', 'Expr R'},
|
||||
input_sockets={'Expr L', 'Expr R'},
|
||||
input_socket_kinds={'Expr L': ct.FlowKind.Info, 'Expr R': ct.FlowKind.Info},
|
||||
input_sockets_optional={'Expr L': True, 'Expr R': True},
|
||||
)
|
||||
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
|
||||
has_info_l = not ct.FlowSignal.check(input_sockets['Expr L'])
|
||||
has_info_r = not ct.FlowSignal.check(input_sockets['Expr R'])
|
||||
|
||||
info_l_pending = ct.FlowSignal.check_single(
|
||||
input_sockets['Expr L'], ct.FlowSignal.FlowPending
|
||||
)
|
||||
info_r_pending = ct.FlowSignal.check_single(
|
||||
input_sockets['Expr R'], ct.FlowSignal.FlowPending
|
||||
)
|
||||
|
||||
operation: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_operations()
|
||||
if has_info_l and has_info_r and not info_l_pending and not info_r_pending:
|
||||
self.expr_infos = bl_cache.Signal.InvalidateCache
|
||||
|
||||
@bl_cache.cached_bl_property()
|
||||
def expr_infos(self) -> tuple[ct.InfoFlow, ct.InfoFlow] | None:
|
||||
info_l = self._compute_input('Expr L', kind=ct.FlowKind.Info)
|
||||
info_r = self._compute_input('Expr R', kind=ct.FlowKind.Info)
|
||||
|
||||
has_info_l = not ct.FlowSignal.check(info_l)
|
||||
has_info_r = not ct.FlowSignal.check(info_r)
|
||||
|
||||
if has_info_l and has_info_r:
|
||||
return (info_l, info_r)
|
||||
|
||||
return None
|
||||
|
||||
operation: BinaryOperation = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_operations(),
|
||||
cb_depends_on={'expr_infos'},
|
||||
)
|
||||
|
||||
def search_categories(self) -> list[ct.BLEnumElement]:
|
||||
"""Deduce and return a list of valid categories for the current socket set and input data."""
|
||||
expr_l_info = self._compute_input(
|
||||
'Expr L',
|
||||
kind=ct.FlowKind.Info,
|
||||
)
|
||||
expr_r_info = self._compute_input(
|
||||
'Expr R',
|
||||
kind=ct.FlowKind.Info,
|
||||
)
|
||||
|
||||
has_expr_l_info = not ct.FlowSignal.check(expr_l_info)
|
||||
has_expr_r_info = not ct.FlowSignal.check(expr_r_info)
|
||||
|
||||
# Categories by Socket Set
|
||||
NUMBER_NUMBER = (
|
||||
'Number | Number',
|
||||
'Number | Number',
|
||||
'Operations between numerical elements',
|
||||
)
|
||||
NUMBER_VECTOR = (
|
||||
'Number | Vector',
|
||||
'Number | Vector',
|
||||
'Operations between numerical and vector elements',
|
||||
)
|
||||
NUMBER_MATRIX = (
|
||||
'Number | Matrix',
|
||||
'Number | Matrix',
|
||||
'Operations between numerical and matrix elements',
|
||||
)
|
||||
VECTOR_VECTOR = (
|
||||
'Vector | Vector',
|
||||
'Vector | Vector',
|
||||
'Operations between vector elements',
|
||||
)
|
||||
MATRIX_VECTOR = (
|
||||
'Matrix | Vector',
|
||||
'Matrix | Vector',
|
||||
'Operations between vector and matrix elements',
|
||||
)
|
||||
MATRIX_MATRIX = (
|
||||
'Matrix | Matrix',
|
||||
'Matrix | Matrix',
|
||||
'Operations between matrix elements',
|
||||
)
|
||||
categories = []
|
||||
|
||||
if has_expr_l_info and has_expr_r_info:
|
||||
# Check Valid Broadcasting
|
||||
## Number | Number
|
||||
if expr_l_info.output_shape is None and expr_r_info.output_shape is None:
|
||||
categories = [NUMBER_NUMBER]
|
||||
|
||||
## * | Number
|
||||
elif expr_r_info.output_shape is None:
|
||||
categories = []
|
||||
|
||||
## Number | Vector
|
||||
elif (
|
||||
expr_l_info.output_shape is None and len(expr_r_info.output_shape) == 1
|
||||
):
|
||||
categories = [NUMBER_VECTOR]
|
||||
|
||||
## Number | Matrix
|
||||
elif (
|
||||
expr_l_info.output_shape is None and len(expr_r_info.output_shape) == 2
|
||||
):
|
||||
categories = [NUMBER_MATRIX]
|
||||
|
||||
## Vector | Vector
|
||||
elif (
|
||||
len(expr_l_info.output_shape) == 1
|
||||
and len(expr_r_info.output_shape) == 1
|
||||
):
|
||||
categories = [VECTOR_VECTOR]
|
||||
|
||||
## Matrix | Vector
|
||||
elif (
|
||||
len(expr_l_info.output_shape) == 2 # noqa: PLR2004
|
||||
and len(expr_r_info.output_shape) == 1
|
||||
):
|
||||
categories = [MATRIX_VECTOR]
|
||||
|
||||
## Matrix | Matrix
|
||||
elif (
|
||||
len(expr_l_info.output_shape) == 2 # noqa: PLR2004
|
||||
and len(expr_r_info.output_shape) == 2 # noqa: PLR2004
|
||||
):
|
||||
categories = [MATRIX_MATRIX]
|
||||
|
||||
return [
|
||||
(*category, '', i) if category is not None else None
|
||||
for i, category in enumerate(categories)
|
||||
]
|
||||
|
||||
def search_operations(self) -> list[ct.BLEnumElement]:
|
||||
items = []
|
||||
if self.category in ['Number | Number', 'Number | Vector', 'Number | Matrix']:
|
||||
items += [
|
||||
('ADD', 'L + R', 'Add'),
|
||||
('SUB', 'L - R', 'Subtract'),
|
||||
('MUL', 'L · R', 'Multiply'),
|
||||
('DIV', 'L / R', 'Divide'),
|
||||
('POW', 'L^R', 'Power'),
|
||||
('ATAN2', 'atan2(L,R)', 'atan2(L,R)'),
|
||||
]
|
||||
if self.category == 'Vector | Vector':
|
||||
if items:
|
||||
items += [None]
|
||||
items += [
|
||||
('VEC_VEC_DOT', 'L · R', 'Vector-Vector Product'),
|
||||
('CROSS', 'L x R', 'Cross Product'),
|
||||
]
|
||||
if self.category == 'Matrix | Vector':
|
||||
if items:
|
||||
items += [None]
|
||||
items += [
|
||||
('MAT_VEC_DOT', 'L · R', 'Matrix-Vector Product'),
|
||||
('LIN_SOLVE', 'Lx = R -> x', 'Linear Solve'),
|
||||
('LSQ_SOLVE', 'Lx = R ~> x', 'Least Squares Solve'),
|
||||
]
|
||||
if self.category == 'Matrix | Matrix':
|
||||
if items:
|
||||
items += [None]
|
||||
items += [
|
||||
('MAT_MAT_DOT', 'L · R', 'Matrix-Matrix Product'),
|
||||
]
|
||||
|
||||
if self.expr_infos is not None:
|
||||
return [
|
||||
(*item, '', i) if item is not None else None for i, item in enumerate(items)
|
||||
operation.bl_enum_element(i)
|
||||
for i, operation in enumerate(
|
||||
BinaryOperation.by_infos(*self.expr_infos)
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
def draw_label(self):
|
||||
labels = {
|
||||
'ADD': lambda: 'L + R',
|
||||
'SUB': lambda: 'L - R',
|
||||
'MUL': lambda: 'L · R',
|
||||
'DIV': lambda: 'L / R',
|
||||
'POW': lambda: 'L^R',
|
||||
'ATAN2': lambda: 'atan2(L,R)',
|
||||
}
|
||||
"""Show the current operation (if any) in the node's header label.
|
||||
|
||||
if (label := labels.get(self.operation)) is not None:
|
||||
return 'Operate: ' + label()
|
||||
Notes:
|
||||
Called by Blender to determine the text to place in the node's header.
|
||||
"""
|
||||
if self.operation is not None:
|
||||
return 'Op: ' + BinaryOperation.to_name(self.operation)
|
||||
|
||||
return self.bl_label
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
layout.prop(self, self.blfields['category'], text='')
|
||||
"""Draw node properties in the node.
|
||||
|
||||
Parameters:
|
||||
col: UI target for drawing.
|
||||
"""
|
||||
layout.prop(self, self.blfields['operation'], text='')
|
||||
|
||||
####################
|
||||
# - Events
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
# Trigger
|
||||
socket_name={'Expr L', 'Expr R'},
|
||||
run_on_init=True,
|
||||
)
|
||||
def on_socket_changed(self) -> None:
|
||||
# Recompute Valid Categories
|
||||
self.category = bl_cache.Signal.ResetEnumItems
|
||||
self.operation = bl_cache.Signal.ResetEnumItems
|
||||
|
||||
@events.on_value_changed(
|
||||
prop_name='category',
|
||||
run_on_init=True,
|
||||
)
|
||||
def on_category_changed(self) -> None:
|
||||
self.operation = bl_cache.Signal.ResetEnumItems
|
||||
|
||||
####################
|
||||
# - Output
|
||||
# - FlowKind.Value|LazyValueFunc
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
|
@ -285,8 +364,10 @@ class OperateMathNode(base.MaxwellSimNode):
|
|||
has_expr_l_value = not ct.FlowSignal.check(expr_l)
|
||||
has_expr_r_value = not ct.FlowSignal.check(expr_r)
|
||||
|
||||
# Compute Sympy Function
|
||||
## -> The operation enum directly provides the appropriate function.
|
||||
if has_expr_l_value and has_expr_r_value and operation is not None:
|
||||
return SP_FUNCS[operation]([expr_l, expr_r])
|
||||
operation.sp_func([expr_l, expr_r])
|
||||
|
||||
return ct.Flowsignal.FlowPending
|
||||
|
||||
|
@ -311,43 +392,17 @@ class OperateMathNode(base.MaxwellSimNode):
|
|||
has_expr_l = not ct.FlowSignal.check(expr_l)
|
||||
has_expr_r = not ct.FlowSignal.check(expr_r)
|
||||
|
||||
# Compute Jax Function
|
||||
## -> The operation enum directly provides the appropriate function.
|
||||
if has_expr_l and has_expr_r:
|
||||
return (expr_l | expr_r).compose_within(
|
||||
JAX_FUNCS[operation],
|
||||
operation.jax_func,
|
||||
supports_jax=True,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Array,
|
||||
output_sockets={'Expr'},
|
||||
output_socket_kinds={
|
||||
'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
|
||||
},
|
||||
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
|
||||
)
|
||||
def compute_array(self, output_sockets, unit_systems) -> ct.ArrayFlow:
|
||||
lazy_value_func = output_sockets['Expr'][ct.FlowKind.LazyValueFunc]
|
||||
params = output_sockets['Expr'][ct.FlowKind.Params]
|
||||
|
||||
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
|
||||
has_params = not ct.FlowSignal.check(params)
|
||||
|
||||
if has_lazy_value_func and has_params:
|
||||
unit_system = unit_systems['BlenderUnits']
|
||||
return ct.ArrayFlow(
|
||||
values=lazy_value_func.func_jax(
|
||||
*params.scaled_func_args(unit_system),
|
||||
**params.scaled_func_kwargs(unit_system),
|
||||
),
|
||||
unit=None,
|
||||
)
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Auxiliary: Info
|
||||
# - FlowKind.Info
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
|
@ -367,17 +422,20 @@ class OperateMathNode(base.MaxwellSimNode):
|
|||
has_info_l = not ct.FlowSignal.check(info_l)
|
||||
has_info_r = not ct.FlowSignal.check(info_r)
|
||||
|
||||
# Return Info of RHS
|
||||
## -> Fundamentall, this is why 'category' only has the given options.
|
||||
## -> Via 'category', we enforce that the operated-on structure is always RHS.
|
||||
## -> That makes it super duper easy to track info changes.
|
||||
if has_info_l and has_info_r and operation is not None:
|
||||
return info_r
|
||||
# Compute Info
|
||||
## -> The operation enum directly provides the appropriate transform.
|
||||
if (
|
||||
has_info_l
|
||||
and has_info_r
|
||||
and operation is not None
|
||||
and operation in BinaryOperation.by_infos(info_l, info_r)
|
||||
):
|
||||
return operation.transform_infos(info_l, info_r)
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Auxiliary: Params
|
||||
# - FlowKind.Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
|
@ -397,8 +455,11 @@ class OperateMathNode(base.MaxwellSimNode):
|
|||
has_params_l = not ct.FlowSignal.check(params_l)
|
||||
has_params_r = not ct.FlowSignal.check(params_r)
|
||||
|
||||
# Compute Params
|
||||
## -> Operations don't add new parameters, so just concatenate L|R.
|
||||
if has_params_l and has_params_r and operation is not None:
|
||||
return params_l | params_r
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
||||
|
|
|
@ -186,7 +186,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
# UI: Info
|
||||
show_info_columns: bool = bl_cache.BLField(False)
|
||||
info_columns: set[InfoDisplayCol] = bl_cache.BLField(
|
||||
{InfoDisplayCol.MathType, InfoDisplayCol.Unit}
|
||||
{InfoDisplayCol.Length, InfoDisplayCol.MathType}
|
||||
)
|
||||
|
||||
####################
|
||||
|
|
|
@ -329,8 +329,8 @@ class BLField:
|
|||
# Retrieve Old Enum Item
|
||||
## -> This is verbatim what is being used.
|
||||
## -> De-Coerce None -> 'NONE' to avoid special-cased search.
|
||||
_old_item = self.bl_prop.read(bl_instance)
|
||||
old_item = 'NONE' if _old_item is None else _old_item
|
||||
old_item = self.bl_prop.read(bl_instance)
|
||||
raw_old_item = 'NONE' if old_item is None else str(old_item)
|
||||
|
||||
# Swap Enum Items
|
||||
## -> This is the hot stuff - the enum elements are overwritten.
|
||||
|
@ -343,7 +343,7 @@ class BLField:
|
|||
## -> If so, the user will expect it to "remain".
|
||||
## -> Thus, we set it - Blender sees a change, user doesn't.
|
||||
## -> DO NOT trigger on_prop_changed (since "nothing changed").
|
||||
if any(old_item == item[0] for item in current_items):
|
||||
if any(raw_old_item == item[0] for item in current_items):
|
||||
self.suppress_next_update(bl_instance)
|
||||
self.bl_prop.write(bl_instance, old_item)
|
||||
## -> TODO: Don't write if not needed.
|
||||
|
@ -352,10 +352,8 @@ class BLField:
|
|||
## -> In this case, fallback to the first current item.
|
||||
## -> DO trigger on_prop_changed (since it changed!)
|
||||
else:
|
||||
_first_current_item = current_items[0][0]
|
||||
first_current_item = (
|
||||
_first_current_item if _first_current_item != 'NONE' else None
|
||||
)
|
||||
raw_first_current_item = current_items[0][0]
|
||||
first_current_item = self.bl_prop.decode(raw_first_current_item)
|
||||
|
||||
self.suppress_next_update(bl_instance)
|
||||
self.bl_prop.write(bl_instance, first_current_item)
|
||||
|
|
|
@ -198,7 +198,12 @@ class BLProp:
|
|||
)
|
||||
|
||||
def read(self, bl_instance: bl_instance.BLInstance) -> typ.Any:
|
||||
"""Read the Blender property's particular value on the given `bl_instance`."""
|
||||
"""Read the persisted Blender property value for this property, from a particular `BLInstance`.
|
||||
|
||||
Parameters:
|
||||
bl_instance: The Blender object to
|
||||
**NOTE**: `bl_instance` must not be `None`, as neighboring methods sometimes allow.
|
||||
"""
|
||||
persisted_value = self.decode(
|
||||
managed_cache.read(
|
||||
bl_instance,
|
||||
|
|
|
@ -50,6 +50,7 @@ class BLInstance:
|
|||
# - Attributes
|
||||
####################
|
||||
instance_id: bpy.props.StringProperty(default='')
|
||||
is_updating: bpy.props.BoolProperty(default=False)
|
||||
|
||||
blfields: typ.ClassVar[dict[str, str]] = MappingProxyType({})
|
||||
blfield_deps: typ.ClassVar[dict[str, list[str]]] = MappingProxyType({})
|
||||
|
|
Loading…
Reference in New Issue