fix: map node up to current standards

main
Sofus Albert Høgsbro Rose 2024-05-18 18:42:54 +02:00
parent 532f0246b5
commit 39747e2d68
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
1 changed files with 160 additions and 164 deletions

View File

@ -32,14 +32,14 @@ from ... import base, events
log = logger.get(__name__) log = logger.get(__name__)
X_COMPLEX = sp.Symbol('x', complex=True)
####################
# - Operation Enum
####################
class MapOperation(enum.StrEnum): class MapOperation(enum.StrEnum):
"""Valid operations for the `MapMathNode`. """Valid operations for the `MapMathNode`.
Attributes: Attributes:
UserExpr: Use a user-provided mapping expression.
Real: Compute the real part of the input. Real: Compute the real part of the input.
Imag: Compute the imaginary part of the input. Imag: Compute the imaginary part of the input.
Abs: Compute the absolute value of the input. Abs: Compute the absolute value of the input.
@ -67,8 +67,6 @@ class MapOperation(enum.StrEnum):
Svd: Compute the SVD-factorized matrices of the input matrix. Svd: Compute the SVD-factorized matrices of the input matrix.
""" """
# By User Expression
UserExpr = enum.auto()
# By Number # By Number
Real = enum.auto() Real = enum.auto()
Imag = enum.auto() Imag = enum.auto()
@ -106,8 +104,6 @@ class MapOperation(enum.StrEnum):
def to_name(value: typ.Self) -> str: def to_name(value: typ.Self) -> str:
MO = MapOperation MO = MapOperation
return { return {
# By User Expression
MO.UserExpr: '*',
# By Number # By Number
MO.Real: '(v)', MO.Real: '(v)',
MO.Imag: 'Im(v)', MO.Imag: 'Im(v)',
@ -159,108 +155,140 @@ class MapOperation(enum.StrEnum):
@staticmethod @staticmethod
def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]: def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]:
MO = MapOperation MO = MapOperation
if shape == 'noshape':
return []
# By Number match shape:
if shape is None: case 'noshape':
return [ return []
MO.Real,
MO.Imag,
MO.Abs,
MO.Sq,
MO.Sqrt,
MO.InvSqrt,
MO.Cos,
MO.Sin,
MO.Tan,
MO.Acos,
MO.Asin,
MO.Atan,
MO.Sinc,
]
# By Vector # By Number
if len(shape) == 1: case None:
return [ return [
MO.Norm2, MO.Real,
] MO.Imag,
# By Matrix MO.Abs,
if len(shape) == 2: MO.Sq,
return [ MO.Sqrt,
MO.Det, MO.InvSqrt,
MO.Cond, MO.Cos,
MO.NormFro, MO.Sin,
MO.Rank, MO.Tan,
MO.Diag, MO.Acos,
MO.EigVals, MO.Asin,
MO.SvdVals, MO.Atan,
MO.Inv, MO.Sinc,
MO.Tra, ]
MO.Qr,
MO.Chol, match len(shape):
MO.Svd, # By Vector
] case 1:
return [
MO.Norm2,
]
# By Matrix
case 2:
return [
MO.Det,
MO.Cond,
MO.NormFro,
MO.Rank,
MO.Diag,
MO.EigVals,
MO.SvdVals,
MO.Inv,
MO.Tra,
MO.Qr,
MO.Chol,
MO.Svd,
]
return [] return []
def jax_func_expr(self, user_expr_func: ct.LazyValueFuncFlow): ####################
# - Function Properties
####################
@property
def sp_func(self):
MO = MapOperation MO = MapOperation
if self == MO.UserExpr: return {
return lambda data: user_expr_func.func(data) # By Number
MO.Real: lambda expr: sp.re(expr),
msg = "Can't generate JAX function for user-provided expression when MapOperation is not `UserExpr`" MO.Imag: lambda expr: sp.im(expr),
raise ValueError(msg) MO.Abs: lambda expr: sp.Abs(expr),
MO.Sq: lambda expr: expr**2,
MO.Sqrt: lambda expr: sp.sqrt(expr),
MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr),
MO.Cos: lambda expr: sp.cos(expr),
MO.Sin: lambda expr: sp.sin(expr),
MO.Tan: lambda expr: sp.tan(expr),
MO.Acos: lambda expr: sp.acos(expr),
MO.Asin: lambda expr: sp.asin(expr),
MO.Atan: lambda expr: sp.atan(expr),
MO.Sinc: lambda expr: sp.sinc(expr),
# By Vector
# Vector -> Number
MO.Norm2: lambda expr: sp.sqrt(expr.T @ expr),
# By Matrix
# Matrix -> Number
MO.Det: lambda expr: sp.det(expr),
MO.Cond: lambda expr: expr.condition_number(),
MO.NormFro: lambda expr: expr.norm(ord='fro'),
MO.Rank: lambda expr: expr.rank(),
# Matrix -> Vec
MO.Diag: lambda expr: expr.diagonal(),
MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())),
MO.SvdVals: lambda expr: expr.singular_values(),
# Matrix -> Matrix
MO.Inv: lambda expr: expr.inv(),
MO.Tra: lambda expr: expr.T,
# Matrix -> Matrices
MO.Qr: lambda expr: expr.QRdecomposition(),
MO.Chol: lambda expr: expr.cholesky(),
MO.Svd: lambda expr: expr.singular_value_decomposition(),
}[self]
@property @property
def jax_func(self): def jax_func(self):
MO = MapOperation MO = MapOperation
if self == MO.UserExpr:
msg = "Can't generate JAX function without user-provided expression when MapOperation is `UserExpr`"
raise ValueError(msg)
return { return {
# By Number # By Number
MO.Real: lambda data: jnp.real(data), MO.Real: lambda expr: jnp.real(expr),
MO.Imag: lambda data: jnp.imag(data), MO.Imag: lambda expr: jnp.imag(expr),
MO.Abs: lambda data: jnp.abs(data), MO.Abs: lambda expr: jnp.abs(expr),
MO.Sq: lambda data: jnp.square(data), MO.Sq: lambda expr: jnp.square(expr),
MO.Sqrt: lambda data: jnp.sqrt(data), MO.Sqrt: lambda expr: jnp.sqrt(expr),
MO.InvSqrt: lambda data: 1 / jnp.sqrt(data), MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr),
MO.Cos: lambda data: jnp.cos(data), MO.Cos: lambda expr: jnp.cos(expr),
MO.Sin: lambda data: jnp.sin(data), MO.Sin: lambda expr: jnp.sin(expr),
MO.Tan: lambda data: jnp.tan(data), MO.Tan: lambda expr: jnp.tan(expr),
MO.Acos: lambda data: jnp.acos(data), MO.Acos: lambda expr: jnp.acos(expr),
MO.Asin: lambda data: jnp.asin(data), MO.Asin: lambda expr: jnp.asin(expr),
MO.Atan: lambda data: jnp.atan(data), MO.Atan: lambda expr: jnp.atan(expr),
MO.Sinc: lambda data: jnp.sinc(data), MO.Sinc: lambda expr: jnp.sinc(expr),
# By Vector # By Vector
# Vector -> Number # Vector -> Number
MO.Norm2: lambda data: jnp.linalg.norm(data, ord=2, axis=-1), MO.Norm2: lambda expr: jnp.linalg.norm(expr, ord=2, axis=-1),
# By Matrix # By Matrix
# Matrix -> Number # Matrix -> Number
MO.Det: lambda data: jnp.linalg.det(data), MO.Det: lambda expr: jnp.linalg.det(expr),
MO.Cond: lambda data: jnp.linalg.cond(data), MO.Cond: lambda expr: jnp.linalg.cond(expr),
MO.NormFro: lambda data: jnp.linalg.matrix_norm(data, ord='fro'), MO.NormFro: lambda expr: jnp.linalg.matrix_norm(expr, ord='fro'),
MO.Rank: lambda data: jnp.linalg.matrix_rank(data), MO.Rank: lambda expr: jnp.linalg.matrix_rank(expr),
# Matrix -> Vec # Matrix -> Vec
MO.Diag: lambda data: jnp.diagonal(data, axis1=-2, axis2=-1), MO.Diag: lambda expr: jnp.diagonal(expr, axis1=-2, axis2=-1),
MO.EigVals: lambda data: jnp.linalg.eigvals(data), MO.EigVals: lambda expr: jnp.linalg.eigvals(expr),
MO.SvdVals: lambda data: jnp.linalg.svdvals(data), MO.SvdVals: lambda expr: jnp.linalg.svdvals(expr),
# Matrix -> Matrix # Matrix -> Matrix
MO.Inv: lambda data: jnp.linalg.inv(data), MO.Inv: lambda expr: jnp.linalg.inv(expr),
MO.Tra: lambda data: jnp.matrix_transpose(data), MO.Tra: lambda expr: jnp.matrix_transpose(expr),
# Matrix -> Matrices # Matrix -> Matrices
MO.Qr: lambda data: jnp.linalg.qr(data), MO.Qr: lambda expr: jnp.linalg.qr(expr),
MO.Chol: lambda data: jnp.linalg.cholesky(data), MO.Chol: lambda expr: jnp.linalg.cholesky(expr),
MO.Svd: lambda data: jnp.linalg.svd(data), MO.Svd: lambda expr: jnp.linalg.svd(expr),
}[self] }[self]
def transform_info(self, info: ct.InfoFlow): def transform_info(self, info: ct.InfoFlow):
MO = MapOperation MO = MapOperation
return { return {
# By User Expression
MO.UserExpr: '*',
# By Number # By Number
MO.Real: lambda: info.set_output_mathtype(spux.MathType.Real), MO.Real: lambda: info.set_output_mathtype(spux.MathType.Real),
MO.Imag: lambda: info.set_output_mathtype(spux.MathType.Real), MO.Imag: lambda: info.set_output_mathtype(spux.MathType.Real),
@ -294,9 +322,12 @@ class MapOperation(enum.StrEnum):
), ),
## TODO: Matrix -> Vec ## TODO: Matrix -> Vec
## TODO: Matrix -> Matrices ## TODO: Matrix -> Matrices
}.get(self, info)() }.get(self, lambda: info)()
####################
# - Node
####################
class MapMathNode(base.MaxwellSimNode): class MapMathNode(base.MaxwellSimNode):
r"""Applies a function by-structure to the data. r"""Applies a function by-structure to the data.
@ -390,11 +421,23 @@ class MapMathNode(base.MaxwellSimNode):
#################### ####################
# - Properties # - Properties
#################### ####################
operation: MapOperation = bl_cache.BLField( @events.on_value_changed(
enum_cb=lambda self, _: self.search_operations() socket_name={'Expr'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
input_sockets_optional={'Expr': True},
) )
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
@property info_pending = ct.FlowSignal.check_single(
input_sockets['Expr'], ct.FlowSignal.FlowPending
)
if has_info and not info_pending:
self.expr_output_shape = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def expr_output_shape(self) -> ct.InfoFlow | None: def expr_output_shape(self) -> ct.InfoFlow | None:
info = self._compute_input('Expr', kind=ct.FlowKind.Info) info = self._compute_input('Expr', kind=ct.FlowKind.Info)
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
@ -403,6 +446,11 @@ class MapMathNode(base.MaxwellSimNode):
return 'noshape' return 'noshape'
operation: MapOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_output_shape'},
)
def search_operations(self) -> list[ct.BLEnumElement]: def search_operations(self) -> list[ct.BLEnumElement]:
if self.expr_output_shape != 'noshape': if self.expr_output_shape != 'noshape':
return [ return [
@ -426,110 +474,58 @@ class MapMathNode(base.MaxwellSimNode):
layout.prop(self, self.blfields['operation'], text='') layout.prop(self, self.blfields['operation'], text='')
#################### ####################
# - Events # - FlowKind.Value|LazyValueFunc
#################### ####################
@events.on_value_changed( @events.computes_output_socket(
# Trigger 'Expr',
socket_name='Expr', kind=ct.FlowKind.Value,
run_on_init=True,
)
def on_input_changed(self):
if self.operation not in MapOperation.by_element_shape(self.expr_output_shape):
self.operation = bl_cache.Signal.ResetEnumItems
@events.on_value_changed(
# Trigger
prop_name={'operation'},
run_on_init=True,
# Loaded
props={'operation'}, props={'operation'},
input_sockets={'Expr'},
) )
def on_operation_changed(self, props: dict) -> None: def compute_value(self, props, input_sockets) -> ct.ValueFlow | ct.FlowSignal:
operation = props['operation'] operation = props['operation']
expr = input_sockets['Expr']
# UserExpr: Add/Remove Input Socket has_expr_value = not ct.FlowSignal.check(expr)
if operation == MapOperation.UserExpr:
current_bl_socket = self.loose_input_sockets.get('Mapper')
if current_bl_socket is None:
self.loose_input_sockets = {
'Mapper': sockets.ExprSocketDef(
symbols={X_COMPLEX},
default_value=X_COMPLEX,
mathtype=spux.MathType.Complex,
),
}
elif self.loose_input_sockets: # Compute Sympy Function
self.loose_input_sockets = {} ## -> The operation enum directly provides the appropriate function.
if has_expr_value and operation is not None:
operation.sp_func(expr)
return ct.Flowsignal.FlowPending
####################
# - Compute: LazyValueFunc / Array
####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.LazyValueFunc, kind=ct.FlowKind.LazyValueFunc,
props={'operation'}, props={'operation'},
input_sockets={'Expr', 'Mapper'}, input_sockets={'Expr'},
input_socket_kinds={ input_socket_kinds={
'Expr': ct.FlowKind.LazyValueFunc, 'Expr': ct.FlowKind.LazyValueFunc,
'Mapper': ct.FlowKind.LazyValueFunc,
}, },
input_sockets_optional={'Mapper': True},
) )
def compute_data(self, props: dict, input_sockets: dict): def compute_func(
self, props, input_sockets
) -> ct.LazyValueFuncFlow | ct.FlowSignal:
operation = props['operation'] operation = props['operation']
expr = input_sockets['Expr'] expr = input_sockets['Expr']
mapper = input_sockets['Mapper']
has_expr = not ct.FlowSignal.check(expr) has_expr = not ct.FlowSignal.check(expr)
has_mapper = not ct.FlowSignal.check(mapper)
if has_expr and operation is not None: if has_expr and operation is not None:
if not has_mapper: return expr.compose_within(
return expr.compose_within( operation.jax_func,
operation.jax_func, supports_jax=True,
supports_jax=True,
)
if operation == MapOperation.UserExpr and has_mapper:
return expr.compose_within(
operation.jax_func_expr(user_expr_func=mapper),
supports_jax=True,
)
return ct.FlowSignal.FlowPending
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Array,
output_sockets={'Expr'},
output_socket_kinds={
'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
},
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
)
def compute_array(self, output_sockets, unit_systems) -> ct.ArrayFlow:
lazy_value_func = output_sockets['Expr'][ct.FlowKind.LazyValueFunc]
params = output_sockets['Expr'][ct.FlowKind.Params]
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
has_params = not ct.FlowSignal.check(params)
if has_lazy_value_func and has_params:
unit_system = unit_systems['BlenderUnits']
return ct.ArrayFlow(
values=lazy_value_func.func_jax(
*params.scaled_func_args(unit_system),
**params.scaled_func_kwargs(unit_system),
),
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
#################### ####################
# - Compute Auxiliary: Info / Params # - FlowKind.Info|Params
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Info, kind=ct.FlowKind.Info,
props={'active_socket_set', 'operation'}, props={'operation'},
input_sockets={'Expr'}, input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info}, input_socket_kinds={'Expr': ct.FlowKind.Info},
) )