fix: map node up to current standards

main
Sofus Albert Høgsbro Rose 2024-05-18 18:42:54 +02:00
parent 532f0246b5
commit 39747e2d68
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
1 changed files with 160 additions and 164 deletions

View File

@ -32,14 +32,14 @@ from ... import base, events
log = logger.get(__name__)
X_COMPLEX = sp.Symbol('x', complex=True)
####################
# - Operation Enum
####################
class MapOperation(enum.StrEnum):
"""Valid operations for the `MapMathNode`.
Attributes:
UserExpr: Use a user-provided mapping expression.
Real: Compute the real part of the input.
Imag: Compute the imaginary part of the input.
Abs: Compute the absolute value of the input.
@ -67,8 +67,6 @@ class MapOperation(enum.StrEnum):
Svd: Compute the SVD-factorized matrices of the input matrix.
"""
# By User Expression
UserExpr = enum.auto()
# By Number
Real = enum.auto()
Imag = enum.auto()
@ -106,8 +104,6 @@ class MapOperation(enum.StrEnum):
def to_name(value: typ.Self) -> str:
MO = MapOperation
return {
# By User Expression
MO.UserExpr: '*',
# By Number
MO.Real: '(v)',
MO.Imag: 'Im(v)',
@ -159,108 +155,140 @@ class MapOperation(enum.StrEnum):
@staticmethod
def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]:
MO = MapOperation
if shape == 'noshape':
return []
# By Number
if shape is None:
return [
MO.Real,
MO.Imag,
MO.Abs,
MO.Sq,
MO.Sqrt,
MO.InvSqrt,
MO.Cos,
MO.Sin,
MO.Tan,
MO.Acos,
MO.Asin,
MO.Atan,
MO.Sinc,
]
match shape:
case 'noshape':
return []
# By Vector
if len(shape) == 1:
return [
MO.Norm2,
]
# By Matrix
if len(shape) == 2:
return [
MO.Det,
MO.Cond,
MO.NormFro,
MO.Rank,
MO.Diag,
MO.EigVals,
MO.SvdVals,
MO.Inv,
MO.Tra,
MO.Qr,
MO.Chol,
MO.Svd,
]
# By Number
case None:
return [
MO.Real,
MO.Imag,
MO.Abs,
MO.Sq,
MO.Sqrt,
MO.InvSqrt,
MO.Cos,
MO.Sin,
MO.Tan,
MO.Acos,
MO.Asin,
MO.Atan,
MO.Sinc,
]
match len(shape):
# By Vector
case 1:
return [
MO.Norm2,
]
# By Matrix
case 2:
return [
MO.Det,
MO.Cond,
MO.NormFro,
MO.Rank,
MO.Diag,
MO.EigVals,
MO.SvdVals,
MO.Inv,
MO.Tra,
MO.Qr,
MO.Chol,
MO.Svd,
]
return []
def jax_func_expr(self, user_expr_func: ct.LazyValueFuncFlow):
####################
# - Function Properties
####################
@property
def sp_func(self):
MO = MapOperation
if self == MO.UserExpr:
return lambda data: user_expr_func.func(data)
msg = "Can't generate JAX function for user-provided expression when MapOperation is not `UserExpr`"
raise ValueError(msg)
return {
# By Number
MO.Real: lambda expr: sp.re(expr),
MO.Imag: lambda expr: sp.im(expr),
MO.Abs: lambda expr: sp.Abs(expr),
MO.Sq: lambda expr: expr**2,
MO.Sqrt: lambda expr: sp.sqrt(expr),
MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr),
MO.Cos: lambda expr: sp.cos(expr),
MO.Sin: lambda expr: sp.sin(expr),
MO.Tan: lambda expr: sp.tan(expr),
MO.Acos: lambda expr: sp.acos(expr),
MO.Asin: lambda expr: sp.asin(expr),
MO.Atan: lambda expr: sp.atan(expr),
MO.Sinc: lambda expr: sp.sinc(expr),
# By Vector
# Vector -> Number
MO.Norm2: lambda expr: sp.sqrt(expr.T @ expr),
# By Matrix
# Matrix -> Number
MO.Det: lambda expr: sp.det(expr),
MO.Cond: lambda expr: expr.condition_number(),
MO.NormFro: lambda expr: expr.norm(ord='fro'),
MO.Rank: lambda expr: expr.rank(),
# Matrix -> Vec
MO.Diag: lambda expr: expr.diagonal(),
MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())),
MO.SvdVals: lambda expr: expr.singular_values(),
# Matrix -> Matrix
MO.Inv: lambda expr: expr.inv(),
MO.Tra: lambda expr: expr.T,
# Matrix -> Matrices
MO.Qr: lambda expr: expr.QRdecomposition(),
MO.Chol: lambda expr: expr.cholesky(),
MO.Svd: lambda expr: expr.singular_value_decomposition(),
}[self]
@property
def jax_func(self):
MO = MapOperation
if self == MO.UserExpr:
msg = "Can't generate JAX function without user-provided expression when MapOperation is `UserExpr`"
raise ValueError(msg)
return {
# By Number
MO.Real: lambda data: jnp.real(data),
MO.Imag: lambda data: jnp.imag(data),
MO.Abs: lambda data: jnp.abs(data),
MO.Sq: lambda data: jnp.square(data),
MO.Sqrt: lambda data: jnp.sqrt(data),
MO.InvSqrt: lambda data: 1 / jnp.sqrt(data),
MO.Cos: lambda data: jnp.cos(data),
MO.Sin: lambda data: jnp.sin(data),
MO.Tan: lambda data: jnp.tan(data),
MO.Acos: lambda data: jnp.acos(data),
MO.Asin: lambda data: jnp.asin(data),
MO.Atan: lambda data: jnp.atan(data),
MO.Sinc: lambda data: jnp.sinc(data),
MO.Real: lambda expr: jnp.real(expr),
MO.Imag: lambda expr: jnp.imag(expr),
MO.Abs: lambda expr: jnp.abs(expr),
MO.Sq: lambda expr: jnp.square(expr),
MO.Sqrt: lambda expr: jnp.sqrt(expr),
MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr),
MO.Cos: lambda expr: jnp.cos(expr),
MO.Sin: lambda expr: jnp.sin(expr),
MO.Tan: lambda expr: jnp.tan(expr),
MO.Acos: lambda expr: jnp.acos(expr),
MO.Asin: lambda expr: jnp.asin(expr),
MO.Atan: lambda expr: jnp.atan(expr),
MO.Sinc: lambda expr: jnp.sinc(expr),
# By Vector
# Vector -> Number
MO.Norm2: lambda data: jnp.linalg.norm(data, ord=2, axis=-1),
MO.Norm2: lambda expr: jnp.linalg.norm(expr, ord=2, axis=-1),
# By Matrix
# Matrix -> Number
MO.Det: lambda data: jnp.linalg.det(data),
MO.Cond: lambda data: jnp.linalg.cond(data),
MO.NormFro: lambda data: jnp.linalg.matrix_norm(data, ord='fro'),
MO.Rank: lambda data: jnp.linalg.matrix_rank(data),
MO.Det: lambda expr: jnp.linalg.det(expr),
MO.Cond: lambda expr: jnp.linalg.cond(expr),
MO.NormFro: lambda expr: jnp.linalg.matrix_norm(expr, ord='fro'),
MO.Rank: lambda expr: jnp.linalg.matrix_rank(expr),
# Matrix -> Vec
MO.Diag: lambda data: jnp.diagonal(data, axis1=-2, axis2=-1),
MO.EigVals: lambda data: jnp.linalg.eigvals(data),
MO.SvdVals: lambda data: jnp.linalg.svdvals(data),
MO.Diag: lambda expr: jnp.diagonal(expr, axis1=-2, axis2=-1),
MO.EigVals: lambda expr: jnp.linalg.eigvals(expr),
MO.SvdVals: lambda expr: jnp.linalg.svdvals(expr),
# Matrix -> Matrix
MO.Inv: lambda data: jnp.linalg.inv(data),
MO.Tra: lambda data: jnp.matrix_transpose(data),
MO.Inv: lambda expr: jnp.linalg.inv(expr),
MO.Tra: lambda expr: jnp.matrix_transpose(expr),
# Matrix -> Matrices
MO.Qr: lambda data: jnp.linalg.qr(data),
MO.Chol: lambda data: jnp.linalg.cholesky(data),
MO.Svd: lambda data: jnp.linalg.svd(data),
MO.Qr: lambda expr: jnp.linalg.qr(expr),
MO.Chol: lambda expr: jnp.linalg.cholesky(expr),
MO.Svd: lambda expr: jnp.linalg.svd(expr),
}[self]
def transform_info(self, info: ct.InfoFlow):
MO = MapOperation
return {
# By User Expression
MO.UserExpr: '*',
# By Number
MO.Real: lambda: info.set_output_mathtype(spux.MathType.Real),
MO.Imag: lambda: info.set_output_mathtype(spux.MathType.Real),
@ -294,9 +322,12 @@ class MapOperation(enum.StrEnum):
),
## TODO: Matrix -> Vec
## TODO: Matrix -> Matrices
}.get(self, info)()
}.get(self, lambda: info)()
####################
# - Node
####################
class MapMathNode(base.MaxwellSimNode):
r"""Applies a function by-structure to the data.
@ -390,11 +421,23 @@ class MapMathNode(base.MaxwellSimNode):
####################
# - Properties
####################
operation: MapOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations()
@events.on_value_changed(
socket_name={'Expr'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
input_sockets_optional={'Expr': True},
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
@property
info_pending = ct.FlowSignal.check_single(
input_sockets['Expr'], ct.FlowSignal.FlowPending
)
if has_info and not info_pending:
self.expr_output_shape = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def expr_output_shape(self) -> ct.InfoFlow | None:
info = self._compute_input('Expr', kind=ct.FlowKind.Info)
has_info = not ct.FlowSignal.check(info)
@ -403,6 +446,11 @@ class MapMathNode(base.MaxwellSimNode):
return 'noshape'
operation: MapOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_output_shape'},
)
def search_operations(self) -> list[ct.BLEnumElement]:
if self.expr_output_shape != 'noshape':
return [
@ -426,110 +474,58 @@ class MapMathNode(base.MaxwellSimNode):
layout.prop(self, self.blfields['operation'], text='')
####################
# - Events
# - FlowKind.Value|LazyValueFunc
####################
@events.on_value_changed(
# Trigger
socket_name='Expr',
run_on_init=True,
)
def on_input_changed(self):
if self.operation not in MapOperation.by_element_shape(self.expr_output_shape):
self.operation = bl_cache.Signal.ResetEnumItems
@events.on_value_changed(
# Trigger
prop_name={'operation'},
run_on_init=True,
# Loaded
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Value,
props={'operation'},
input_sockets={'Expr'},
)
def on_operation_changed(self, props: dict) -> None:
def compute_value(self, props, input_sockets) -> ct.ValueFlow | ct.FlowSignal:
operation = props['operation']
expr = input_sockets['Expr']
# UserExpr: Add/Remove Input Socket
if operation == MapOperation.UserExpr:
current_bl_socket = self.loose_input_sockets.get('Mapper')
if current_bl_socket is None:
self.loose_input_sockets = {
'Mapper': sockets.ExprSocketDef(
symbols={X_COMPLEX},
default_value=X_COMPLEX,
mathtype=spux.MathType.Complex,
),
}
has_expr_value = not ct.FlowSignal.check(expr)
elif self.loose_input_sockets:
self.loose_input_sockets = {}
# Compute Sympy Function
## -> The operation enum directly provides the appropriate function.
if has_expr_value and operation is not None:
operation.sp_func(expr)
return ct.Flowsignal.FlowPending
####################
# - Compute: LazyValueFunc / Array
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.LazyValueFunc,
props={'operation'},
input_sockets={'Expr', 'Mapper'},
input_sockets={'Expr'},
input_socket_kinds={
'Expr': ct.FlowKind.LazyValueFunc,
'Mapper': ct.FlowKind.LazyValueFunc,
},
input_sockets_optional={'Mapper': True},
)
def compute_data(self, props: dict, input_sockets: dict):
def compute_func(
self, props, input_sockets
) -> ct.LazyValueFuncFlow | ct.FlowSignal:
operation = props['operation']
expr = input_sockets['Expr']
mapper = input_sockets['Mapper']
has_expr = not ct.FlowSignal.check(expr)
has_mapper = not ct.FlowSignal.check(mapper)
if has_expr and operation is not None:
if not has_mapper:
return expr.compose_within(
operation.jax_func,
supports_jax=True,
)
if operation == MapOperation.UserExpr and has_mapper:
return expr.compose_within(
operation.jax_func_expr(user_expr_func=mapper),
supports_jax=True,
)
return ct.FlowSignal.FlowPending
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Array,
output_sockets={'Expr'},
output_socket_kinds={
'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
},
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
)
def compute_array(self, output_sockets, unit_systems) -> ct.ArrayFlow:
lazy_value_func = output_sockets['Expr'][ct.FlowKind.LazyValueFunc]
params = output_sockets['Expr'][ct.FlowKind.Params]
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
has_params = not ct.FlowSignal.check(params)
if has_lazy_value_func and has_params:
unit_system = unit_systems['BlenderUnits']
return ct.ArrayFlow(
values=lazy_value_func.func_jax(
*params.scaled_func_args(unit_system),
**params.scaled_func_kwargs(unit_system),
),
return expr.compose_within(
operation.jax_func,
supports_jax=True,
)
return ct.FlowSignal.FlowPending
####################
# - Compute Auxiliary: Info / Params
# - FlowKind.Info|Params
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
props={'active_socket_set', 'operation'},
props={'operation'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
)