fix: map node up to current standards
parent
532f0246b5
commit
39747e2d68
|
@ -32,14 +32,14 @@ from ... import base, events
|
|||
|
||||
log = logger.get(__name__)
|
||||
|
||||
X_COMPLEX = sp.Symbol('x', complex=True)
|
||||
|
||||
|
||||
####################
|
||||
# - Operation Enum
|
||||
####################
|
||||
class MapOperation(enum.StrEnum):
|
||||
"""Valid operations for the `MapMathNode`.
|
||||
|
||||
Attributes:
|
||||
UserExpr: Use a user-provided mapping expression.
|
||||
Real: Compute the real part of the input.
|
||||
Imag: Compute the imaginary part of the input.
|
||||
Abs: Compute the absolute value of the input.
|
||||
|
@ -67,8 +67,6 @@ class MapOperation(enum.StrEnum):
|
|||
Svd: Compute the SVD-factorized matrices of the input matrix.
|
||||
"""
|
||||
|
||||
# By User Expression
|
||||
UserExpr = enum.auto()
|
||||
# By Number
|
||||
Real = enum.auto()
|
||||
Imag = enum.auto()
|
||||
|
@ -106,8 +104,6 @@ class MapOperation(enum.StrEnum):
|
|||
def to_name(value: typ.Self) -> str:
|
||||
MO = MapOperation
|
||||
return {
|
||||
# By User Expression
|
||||
MO.UserExpr: '*',
|
||||
# By Number
|
||||
MO.Real: 'ℝ(v)',
|
||||
MO.Imag: 'Im(v)',
|
||||
|
@ -159,11 +155,13 @@ class MapOperation(enum.StrEnum):
|
|||
@staticmethod
|
||||
def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]:
|
||||
MO = MapOperation
|
||||
if shape == 'noshape':
|
||||
|
||||
match shape:
|
||||
case 'noshape':
|
||||
return []
|
||||
|
||||
# By Number
|
||||
if shape is None:
|
||||
case None:
|
||||
return [
|
||||
MO.Real,
|
||||
MO.Imag,
|
||||
|
@ -180,13 +178,14 @@ class MapOperation(enum.StrEnum):
|
|||
MO.Sinc,
|
||||
]
|
||||
|
||||
match len(shape):
|
||||
# By Vector
|
||||
if len(shape) == 1:
|
||||
case 1:
|
||||
return [
|
||||
MO.Norm2,
|
||||
]
|
||||
# By Matrix
|
||||
if len(shape) == 2:
|
||||
case 2:
|
||||
return [
|
||||
MO.Det,
|
||||
MO.Cond,
|
||||
|
@ -204,63 +203,92 @@ class MapOperation(enum.StrEnum):
|
|||
|
||||
return []
|
||||
|
||||
def jax_func_expr(self, user_expr_func: ct.LazyValueFuncFlow):
|
||||
####################
|
||||
# - Function Properties
|
||||
####################
|
||||
@property
|
||||
def sp_func(self):
|
||||
MO = MapOperation
|
||||
if self == MO.UserExpr:
|
||||
return lambda data: user_expr_func.func(data)
|
||||
|
||||
msg = "Can't generate JAX function for user-provided expression when MapOperation is not `UserExpr`"
|
||||
raise ValueError(msg)
|
||||
return {
|
||||
# By Number
|
||||
MO.Real: lambda expr: sp.re(expr),
|
||||
MO.Imag: lambda expr: sp.im(expr),
|
||||
MO.Abs: lambda expr: sp.Abs(expr),
|
||||
MO.Sq: lambda expr: expr**2,
|
||||
MO.Sqrt: lambda expr: sp.sqrt(expr),
|
||||
MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr),
|
||||
MO.Cos: lambda expr: sp.cos(expr),
|
||||
MO.Sin: lambda expr: sp.sin(expr),
|
||||
MO.Tan: lambda expr: sp.tan(expr),
|
||||
MO.Acos: lambda expr: sp.acos(expr),
|
||||
MO.Asin: lambda expr: sp.asin(expr),
|
||||
MO.Atan: lambda expr: sp.atan(expr),
|
||||
MO.Sinc: lambda expr: sp.sinc(expr),
|
||||
# By Vector
|
||||
# Vector -> Number
|
||||
MO.Norm2: lambda expr: sp.sqrt(expr.T @ expr),
|
||||
# By Matrix
|
||||
# Matrix -> Number
|
||||
MO.Det: lambda expr: sp.det(expr),
|
||||
MO.Cond: lambda expr: expr.condition_number(),
|
||||
MO.NormFro: lambda expr: expr.norm(ord='fro'),
|
||||
MO.Rank: lambda expr: expr.rank(),
|
||||
# Matrix -> Vec
|
||||
MO.Diag: lambda expr: expr.diagonal(),
|
||||
MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())),
|
||||
MO.SvdVals: lambda expr: expr.singular_values(),
|
||||
# Matrix -> Matrix
|
||||
MO.Inv: lambda expr: expr.inv(),
|
||||
MO.Tra: lambda expr: expr.T,
|
||||
# Matrix -> Matrices
|
||||
MO.Qr: lambda expr: expr.QRdecomposition(),
|
||||
MO.Chol: lambda expr: expr.cholesky(),
|
||||
MO.Svd: lambda expr: expr.singular_value_decomposition(),
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def jax_func(self):
|
||||
MO = MapOperation
|
||||
if self == MO.UserExpr:
|
||||
msg = "Can't generate JAX function without user-provided expression when MapOperation is `UserExpr`"
|
||||
raise ValueError(msg)
|
||||
|
||||
return {
|
||||
# By Number
|
||||
MO.Real: lambda data: jnp.real(data),
|
||||
MO.Imag: lambda data: jnp.imag(data),
|
||||
MO.Abs: lambda data: jnp.abs(data),
|
||||
MO.Sq: lambda data: jnp.square(data),
|
||||
MO.Sqrt: lambda data: jnp.sqrt(data),
|
||||
MO.InvSqrt: lambda data: 1 / jnp.sqrt(data),
|
||||
MO.Cos: lambda data: jnp.cos(data),
|
||||
MO.Sin: lambda data: jnp.sin(data),
|
||||
MO.Tan: lambda data: jnp.tan(data),
|
||||
MO.Acos: lambda data: jnp.acos(data),
|
||||
MO.Asin: lambda data: jnp.asin(data),
|
||||
MO.Atan: lambda data: jnp.atan(data),
|
||||
MO.Sinc: lambda data: jnp.sinc(data),
|
||||
MO.Real: lambda expr: jnp.real(expr),
|
||||
MO.Imag: lambda expr: jnp.imag(expr),
|
||||
MO.Abs: lambda expr: jnp.abs(expr),
|
||||
MO.Sq: lambda expr: jnp.square(expr),
|
||||
MO.Sqrt: lambda expr: jnp.sqrt(expr),
|
||||
MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr),
|
||||
MO.Cos: lambda expr: jnp.cos(expr),
|
||||
MO.Sin: lambda expr: jnp.sin(expr),
|
||||
MO.Tan: lambda expr: jnp.tan(expr),
|
||||
MO.Acos: lambda expr: jnp.acos(expr),
|
||||
MO.Asin: lambda expr: jnp.asin(expr),
|
||||
MO.Atan: lambda expr: jnp.atan(expr),
|
||||
MO.Sinc: lambda expr: jnp.sinc(expr),
|
||||
# By Vector
|
||||
# Vector -> Number
|
||||
MO.Norm2: lambda data: jnp.linalg.norm(data, ord=2, axis=-1),
|
||||
MO.Norm2: lambda expr: jnp.linalg.norm(expr, ord=2, axis=-1),
|
||||
# By Matrix
|
||||
# Matrix -> Number
|
||||
MO.Det: lambda data: jnp.linalg.det(data),
|
||||
MO.Cond: lambda data: jnp.linalg.cond(data),
|
||||
MO.NormFro: lambda data: jnp.linalg.matrix_norm(data, ord='fro'),
|
||||
MO.Rank: lambda data: jnp.linalg.matrix_rank(data),
|
||||
MO.Det: lambda expr: jnp.linalg.det(expr),
|
||||
MO.Cond: lambda expr: jnp.linalg.cond(expr),
|
||||
MO.NormFro: lambda expr: jnp.linalg.matrix_norm(expr, ord='fro'),
|
||||
MO.Rank: lambda expr: jnp.linalg.matrix_rank(expr),
|
||||
# Matrix -> Vec
|
||||
MO.Diag: lambda data: jnp.diagonal(data, axis1=-2, axis2=-1),
|
||||
MO.EigVals: lambda data: jnp.linalg.eigvals(data),
|
||||
MO.SvdVals: lambda data: jnp.linalg.svdvals(data),
|
||||
MO.Diag: lambda expr: jnp.diagonal(expr, axis1=-2, axis2=-1),
|
||||
MO.EigVals: lambda expr: jnp.linalg.eigvals(expr),
|
||||
MO.SvdVals: lambda expr: jnp.linalg.svdvals(expr),
|
||||
# Matrix -> Matrix
|
||||
MO.Inv: lambda data: jnp.linalg.inv(data),
|
||||
MO.Tra: lambda data: jnp.matrix_transpose(data),
|
||||
MO.Inv: lambda expr: jnp.linalg.inv(expr),
|
||||
MO.Tra: lambda expr: jnp.matrix_transpose(expr),
|
||||
# Matrix -> Matrices
|
||||
MO.Qr: lambda data: jnp.linalg.qr(data),
|
||||
MO.Chol: lambda data: jnp.linalg.cholesky(data),
|
||||
MO.Svd: lambda data: jnp.linalg.svd(data),
|
||||
MO.Qr: lambda expr: jnp.linalg.qr(expr),
|
||||
MO.Chol: lambda expr: jnp.linalg.cholesky(expr),
|
||||
MO.Svd: lambda expr: jnp.linalg.svd(expr),
|
||||
}[self]
|
||||
|
||||
def transform_info(self, info: ct.InfoFlow):
|
||||
MO = MapOperation
|
||||
return {
|
||||
# By User Expression
|
||||
MO.UserExpr: '*',
|
||||
# By Number
|
||||
MO.Real: lambda: info.set_output_mathtype(spux.MathType.Real),
|
||||
MO.Imag: lambda: info.set_output_mathtype(spux.MathType.Real),
|
||||
|
@ -294,9 +322,12 @@ class MapOperation(enum.StrEnum):
|
|||
),
|
||||
## TODO: Matrix -> Vec
|
||||
## TODO: Matrix -> Matrices
|
||||
}.get(self, info)()
|
||||
}.get(self, lambda: info)()
|
||||
|
||||
|
||||
####################
|
||||
# - Node
|
||||
####################
|
||||
class MapMathNode(base.MaxwellSimNode):
|
||||
r"""Applies a function by-structure to the data.
|
||||
|
||||
|
@ -390,11 +421,23 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
####################
|
||||
# - Properties
|
||||
####################
|
||||
operation: MapOperation = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_operations()
|
||||
@events.on_value_changed(
|
||||
socket_name={'Expr'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||
input_sockets_optional={'Expr': True},
|
||||
)
|
||||
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
|
||||
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
|
||||
|
||||
info_pending = ct.FlowSignal.check_single(
|
||||
input_sockets['Expr'], ct.FlowSignal.FlowPending
|
||||
)
|
||||
|
||||
@property
|
||||
if has_info and not info_pending:
|
||||
self.expr_output_shape = bl_cache.Signal.InvalidateCache
|
||||
|
||||
@bl_cache.cached_bl_property()
|
||||
def expr_output_shape(self) -> ct.InfoFlow | None:
|
||||
info = self._compute_input('Expr', kind=ct.FlowKind.Info)
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
|
@ -403,6 +446,11 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
|
||||
return 'noshape'
|
||||
|
||||
operation: MapOperation = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_operations(),
|
||||
cb_depends_on={'expr_output_shape'},
|
||||
)
|
||||
|
||||
def search_operations(self) -> list[ct.BLEnumElement]:
|
||||
if self.expr_output_shape != 'noshape':
|
||||
return [
|
||||
|
@ -426,110 +474,58 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
layout.prop(self, self.blfields['operation'], text='')
|
||||
|
||||
####################
|
||||
# - Events
|
||||
# - FlowKind.Value|LazyValueFunc
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
# Trigger
|
||||
socket_name='Expr',
|
||||
run_on_init=True,
|
||||
)
|
||||
def on_input_changed(self):
|
||||
if self.operation not in MapOperation.by_element_shape(self.expr_output_shape):
|
||||
self.operation = bl_cache.Signal.ResetEnumItems
|
||||
|
||||
@events.on_value_changed(
|
||||
# Trigger
|
||||
prop_name={'operation'},
|
||||
run_on_init=True,
|
||||
# Loaded
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Value,
|
||||
props={'operation'},
|
||||
input_sockets={'Expr'},
|
||||
)
|
||||
def on_operation_changed(self, props: dict) -> None:
|
||||
def compute_value(self, props, input_sockets) -> ct.ValueFlow | ct.FlowSignal:
|
||||
operation = props['operation']
|
||||
expr = input_sockets['Expr']
|
||||
|
||||
# UserExpr: Add/Remove Input Socket
|
||||
if operation == MapOperation.UserExpr:
|
||||
current_bl_socket = self.loose_input_sockets.get('Mapper')
|
||||
if current_bl_socket is None:
|
||||
self.loose_input_sockets = {
|
||||
'Mapper': sockets.ExprSocketDef(
|
||||
symbols={X_COMPLEX},
|
||||
default_value=X_COMPLEX,
|
||||
mathtype=spux.MathType.Complex,
|
||||
),
|
||||
}
|
||||
has_expr_value = not ct.FlowSignal.check(expr)
|
||||
|
||||
elif self.loose_input_sockets:
|
||||
self.loose_input_sockets = {}
|
||||
# Compute Sympy Function
|
||||
## -> The operation enum directly provides the appropriate function.
|
||||
if has_expr_value and operation is not None:
|
||||
operation.sp_func(expr)
|
||||
|
||||
return ct.Flowsignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Compute: LazyValueFunc / Array
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.LazyValueFunc,
|
||||
props={'operation'},
|
||||
input_sockets={'Expr', 'Mapper'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={
|
||||
'Expr': ct.FlowKind.LazyValueFunc,
|
||||
'Mapper': ct.FlowKind.LazyValueFunc,
|
||||
},
|
||||
input_sockets_optional={'Mapper': True},
|
||||
)
|
||||
def compute_data(self, props: dict, input_sockets: dict):
|
||||
def compute_func(
|
||||
self, props, input_sockets
|
||||
) -> ct.LazyValueFuncFlow | ct.FlowSignal:
|
||||
operation = props['operation']
|
||||
expr = input_sockets['Expr']
|
||||
mapper = input_sockets['Mapper']
|
||||
|
||||
has_expr = not ct.FlowSignal.check(expr)
|
||||
has_mapper = not ct.FlowSignal.check(mapper)
|
||||
|
||||
if has_expr and operation is not None:
|
||||
if not has_mapper:
|
||||
return expr.compose_within(
|
||||
operation.jax_func,
|
||||
supports_jax=True,
|
||||
)
|
||||
if operation == MapOperation.UserExpr and has_mapper:
|
||||
return expr.compose_within(
|
||||
operation.jax_func_expr(user_expr_func=mapper),
|
||||
supports_jax=True,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Array,
|
||||
output_sockets={'Expr'},
|
||||
output_socket_kinds={
|
||||
'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
|
||||
},
|
||||
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
|
||||
)
|
||||
def compute_array(self, output_sockets, unit_systems) -> ct.ArrayFlow:
|
||||
lazy_value_func = output_sockets['Expr'][ct.FlowKind.LazyValueFunc]
|
||||
params = output_sockets['Expr'][ct.FlowKind.Params]
|
||||
|
||||
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
|
||||
has_params = not ct.FlowSignal.check(params)
|
||||
|
||||
if has_lazy_value_func and has_params:
|
||||
unit_system = unit_systems['BlenderUnits']
|
||||
return ct.ArrayFlow(
|
||||
values=lazy_value_func.func_jax(
|
||||
*params.scaled_func_args(unit_system),
|
||||
**params.scaled_func_kwargs(unit_system),
|
||||
),
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Compute Auxiliary: Info / Params
|
||||
# - FlowKind.Info|Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Info,
|
||||
props={'active_socket_set', 'operation'},
|
||||
props={'operation'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue