diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py index 5700e10..4a4a0a1 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py @@ -32,14 +32,14 @@ from ... import base, events log = logger.get(__name__) -X_COMPLEX = sp.Symbol('x', complex=True) - +#################### +# - Operation Enum +#################### class MapOperation(enum.StrEnum): """Valid operations for the `MapMathNode`. Attributes: - UserExpr: Use a user-provided mapping expression. Real: Compute the real part of the input. Imag: Compute the imaginary part of the input. Abs: Compute the absolute value of the input. @@ -67,8 +67,6 @@ class MapOperation(enum.StrEnum): Svd: Compute the SVD-factorized matrices of the input matrix. """ - # By User Expression - UserExpr = enum.auto() # By Number Real = enum.auto() Imag = enum.auto() @@ -106,8 +104,6 @@ class MapOperation(enum.StrEnum): def to_name(value: typ.Self) -> str: MO = MapOperation return { - # By User Expression - MO.UserExpr: '*', # By Number MO.Real: 'ℝ(v)', MO.Imag: 'Im(v)', @@ -159,108 +155,140 @@ class MapOperation(enum.StrEnum): @staticmethod def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]: MO = MapOperation - if shape == 'noshape': - return [] - # By Number - if shape is None: - return [ - MO.Real, - MO.Imag, - MO.Abs, - MO.Sq, - MO.Sqrt, - MO.InvSqrt, - MO.Cos, - MO.Sin, - MO.Tan, - MO.Acos, - MO.Asin, - MO.Atan, - MO.Sinc, - ] + match shape: + case 'noshape': + return [] - # By Vector - if len(shape) == 1: - return [ - MO.Norm2, - ] - # By Matrix - if len(shape) == 2: - return [ - MO.Det, - MO.Cond, - MO.NormFro, - MO.Rank, - MO.Diag, - MO.EigVals, - MO.SvdVals, - MO.Inv, - MO.Tra, - MO.Qr, - MO.Chol, - MO.Svd, - ] + # By Number + case None: + return [ + MO.Real, + MO.Imag, + MO.Abs, + MO.Sq, + MO.Sqrt, + MO.InvSqrt, + MO.Cos, + MO.Sin, + MO.Tan, + MO.Acos, + MO.Asin, + MO.Atan, + MO.Sinc, + ] + + match len(shape): + # By Vector + case 1: + return [ + MO.Norm2, + ] + # By Matrix + case 2: + return [ + MO.Det, + MO.Cond, + MO.NormFro, + MO.Rank, + MO.Diag, + MO.EigVals, + MO.SvdVals, + MO.Inv, + MO.Tra, + MO.Qr, + MO.Chol, + MO.Svd, + ] return [] - def jax_func_expr(self, user_expr_func: ct.LazyValueFuncFlow): + #################### + # - Function Properties + #################### + @property + def sp_func(self): MO = MapOperation - if self == MO.UserExpr: - return lambda data: user_expr_func.func(data) - - msg = "Can't generate JAX function for user-provided expression when MapOperation is not `UserExpr`" - raise ValueError(msg) + return { + # By Number + MO.Real: lambda expr: sp.re(expr), + MO.Imag: lambda expr: sp.im(expr), + MO.Abs: lambda expr: sp.Abs(expr), + MO.Sq: lambda expr: expr**2, + MO.Sqrt: lambda expr: sp.sqrt(expr), + MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr), + MO.Cos: lambda expr: sp.cos(expr), + MO.Sin: lambda expr: sp.sin(expr), + MO.Tan: lambda expr: sp.tan(expr), + MO.Acos: lambda expr: sp.acos(expr), + MO.Asin: lambda expr: sp.asin(expr), + MO.Atan: lambda expr: sp.atan(expr), + MO.Sinc: lambda expr: sp.sinc(expr), + # By Vector + # Vector -> Number + MO.Norm2: lambda expr: sp.sqrt(expr.T @ expr), + # By Matrix + # Matrix -> Number + MO.Det: lambda expr: sp.det(expr), + MO.Cond: lambda expr: expr.condition_number(), + MO.NormFro: lambda expr: expr.norm(ord='fro'), + MO.Rank: lambda expr: expr.rank(), + # Matrix -> Vec + MO.Diag: lambda expr: expr.diagonal(), + MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())), + MO.SvdVals: lambda expr: expr.singular_values(), + # Matrix -> Matrix + MO.Inv: lambda expr: expr.inv(), + MO.Tra: lambda expr: expr.T, + # Matrix -> Matrices + MO.Qr: lambda expr: expr.QRdecomposition(), + MO.Chol: lambda expr: expr.cholesky(), + MO.Svd: lambda expr: expr.singular_value_decomposition(), + }[self] @property def jax_func(self): MO = MapOperation - if self == MO.UserExpr: - msg = "Can't generate JAX function without user-provided expression when MapOperation is `UserExpr`" - raise ValueError(msg) - return { # By Number - MO.Real: lambda data: jnp.real(data), - MO.Imag: lambda data: jnp.imag(data), - MO.Abs: lambda data: jnp.abs(data), - MO.Sq: lambda data: jnp.square(data), - MO.Sqrt: lambda data: jnp.sqrt(data), - MO.InvSqrt: lambda data: 1 / jnp.sqrt(data), - MO.Cos: lambda data: jnp.cos(data), - MO.Sin: lambda data: jnp.sin(data), - MO.Tan: lambda data: jnp.tan(data), - MO.Acos: lambda data: jnp.acos(data), - MO.Asin: lambda data: jnp.asin(data), - MO.Atan: lambda data: jnp.atan(data), - MO.Sinc: lambda data: jnp.sinc(data), + MO.Real: lambda expr: jnp.real(expr), + MO.Imag: lambda expr: jnp.imag(expr), + MO.Abs: lambda expr: jnp.abs(expr), + MO.Sq: lambda expr: jnp.square(expr), + MO.Sqrt: lambda expr: jnp.sqrt(expr), + MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr), + MO.Cos: lambda expr: jnp.cos(expr), + MO.Sin: lambda expr: jnp.sin(expr), + MO.Tan: lambda expr: jnp.tan(expr), + MO.Acos: lambda expr: jnp.acos(expr), + MO.Asin: lambda expr: jnp.asin(expr), + MO.Atan: lambda expr: jnp.atan(expr), + MO.Sinc: lambda expr: jnp.sinc(expr), # By Vector # Vector -> Number - MO.Norm2: lambda data: jnp.linalg.norm(data, ord=2, axis=-1), + MO.Norm2: lambda expr: jnp.linalg.norm(expr, ord=2, axis=-1), # By Matrix # Matrix -> Number - MO.Det: lambda data: jnp.linalg.det(data), - MO.Cond: lambda data: jnp.linalg.cond(data), - MO.NormFro: lambda data: jnp.linalg.matrix_norm(data, ord='fro'), - MO.Rank: lambda data: jnp.linalg.matrix_rank(data), + MO.Det: lambda expr: jnp.linalg.det(expr), + MO.Cond: lambda expr: jnp.linalg.cond(expr), + MO.NormFro: lambda expr: jnp.linalg.matrix_norm(expr, ord='fro'), + MO.Rank: lambda expr: jnp.linalg.matrix_rank(expr), # Matrix -> Vec - MO.Diag: lambda data: jnp.diagonal(data, axis1=-2, axis2=-1), - MO.EigVals: lambda data: jnp.linalg.eigvals(data), - MO.SvdVals: lambda data: jnp.linalg.svdvals(data), + MO.Diag: lambda expr: jnp.diagonal(expr, axis1=-2, axis2=-1), + MO.EigVals: lambda expr: jnp.linalg.eigvals(expr), + MO.SvdVals: lambda expr: jnp.linalg.svdvals(expr), # Matrix -> Matrix - MO.Inv: lambda data: jnp.linalg.inv(data), - MO.Tra: lambda data: jnp.matrix_transpose(data), + MO.Inv: lambda expr: jnp.linalg.inv(expr), + MO.Tra: lambda expr: jnp.matrix_transpose(expr), # Matrix -> Matrices - MO.Qr: lambda data: jnp.linalg.qr(data), - MO.Chol: lambda data: jnp.linalg.cholesky(data), - MO.Svd: lambda data: jnp.linalg.svd(data), + MO.Qr: lambda expr: jnp.linalg.qr(expr), + MO.Chol: lambda expr: jnp.linalg.cholesky(expr), + MO.Svd: lambda expr: jnp.linalg.svd(expr), }[self] def transform_info(self, info: ct.InfoFlow): MO = MapOperation return { - # By User Expression - MO.UserExpr: '*', # By Number MO.Real: lambda: info.set_output_mathtype(spux.MathType.Real), MO.Imag: lambda: info.set_output_mathtype(spux.MathType.Real), @@ -294,9 +322,12 @@ class MapOperation(enum.StrEnum): ), ## TODO: Matrix -> Vec ## TODO: Matrix -> Matrices - }.get(self, info)() + }.get(self, lambda: info)() +#################### +# - Node +#################### class MapMathNode(base.MaxwellSimNode): r"""Applies a function by-structure to the data. @@ -390,11 +421,23 @@ class MapMathNode(base.MaxwellSimNode): #################### # - Properties #################### - operation: MapOperation = bl_cache.BLField( - enum_cb=lambda self, _: self.search_operations() + @events.on_value_changed( + socket_name={'Expr'}, + input_sockets={'Expr'}, + input_socket_kinds={'Expr': ct.FlowKind.Info}, + input_sockets_optional={'Expr': True}, ) + def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102 + has_info = not ct.FlowSignal.check(input_sockets['Expr']) - @property + info_pending = ct.FlowSignal.check_single( + input_sockets['Expr'], ct.FlowSignal.FlowPending + ) + + if has_info and not info_pending: + self.expr_output_shape = bl_cache.Signal.InvalidateCache + + @bl_cache.cached_bl_property() def expr_output_shape(self) -> ct.InfoFlow | None: info = self._compute_input('Expr', kind=ct.FlowKind.Info) has_info = not ct.FlowSignal.check(info) @@ -403,6 +446,11 @@ class MapMathNode(base.MaxwellSimNode): return 'noshape' + operation: MapOperation = bl_cache.BLField( + enum_cb=lambda self, _: self.search_operations(), + cb_depends_on={'expr_output_shape'}, + ) + def search_operations(self) -> list[ct.BLEnumElement]: if self.expr_output_shape != 'noshape': return [ @@ -426,110 +474,58 @@ class MapMathNode(base.MaxwellSimNode): layout.prop(self, self.blfields['operation'], text='') #################### - # - Events + # - FlowKind.Value|LazyValueFunc #################### - @events.on_value_changed( - # Trigger - socket_name='Expr', - run_on_init=True, - ) - def on_input_changed(self): - if self.operation not in MapOperation.by_element_shape(self.expr_output_shape): - self.operation = bl_cache.Signal.ResetEnumItems - - @events.on_value_changed( - # Trigger - prop_name={'operation'}, - run_on_init=True, - # Loaded + @events.computes_output_socket( + 'Expr', + kind=ct.FlowKind.Value, props={'operation'}, + input_sockets={'Expr'}, ) - def on_operation_changed(self, props: dict) -> None: + def compute_value(self, props, input_sockets) -> ct.ValueFlow | ct.FlowSignal: operation = props['operation'] + expr = input_sockets['Expr'] - # UserExpr: Add/Remove Input Socket - if operation == MapOperation.UserExpr: - current_bl_socket = self.loose_input_sockets.get('Mapper') - if current_bl_socket is None: - self.loose_input_sockets = { - 'Mapper': sockets.ExprSocketDef( - symbols={X_COMPLEX}, - default_value=X_COMPLEX, - mathtype=spux.MathType.Complex, - ), - } + has_expr_value = not ct.FlowSignal.check(expr) - elif self.loose_input_sockets: - self.loose_input_sockets = {} + # Compute Sympy Function + ## -> The operation enum directly provides the appropriate function. + if has_expr_value and operation is not None: + operation.sp_func(expr) + + return ct.Flowsignal.FlowPending - #################### - # - Compute: LazyValueFunc / Array - #################### @events.computes_output_socket( 'Expr', kind=ct.FlowKind.LazyValueFunc, props={'operation'}, - input_sockets={'Expr', 'Mapper'}, + input_sockets={'Expr'}, input_socket_kinds={ 'Expr': ct.FlowKind.LazyValueFunc, - 'Mapper': ct.FlowKind.LazyValueFunc, }, - input_sockets_optional={'Mapper': True}, ) - def compute_data(self, props: dict, input_sockets: dict): + def compute_func( + self, props, input_sockets + ) -> ct.LazyValueFuncFlow | ct.FlowSignal: operation = props['operation'] expr = input_sockets['Expr'] - mapper = input_sockets['Mapper'] has_expr = not ct.FlowSignal.check(expr) - has_mapper = not ct.FlowSignal.check(mapper) if has_expr and operation is not None: - if not has_mapper: - return expr.compose_within( - operation.jax_func, - supports_jax=True, - ) - if operation == MapOperation.UserExpr and has_mapper: - return expr.compose_within( - operation.jax_func_expr(user_expr_func=mapper), - supports_jax=True, - ) - return ct.FlowSignal.FlowPending - - @events.computes_output_socket( - 'Expr', - kind=ct.FlowKind.Array, - output_sockets={'Expr'}, - output_socket_kinds={ - 'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params}, - }, - unit_systems={'BlenderUnits': ct.UNITS_BLENDER}, - ) - def compute_array(self, output_sockets, unit_systems) -> ct.ArrayFlow: - lazy_value_func = output_sockets['Expr'][ct.FlowKind.LazyValueFunc] - params = output_sockets['Expr'][ct.FlowKind.Params] - - has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func) - has_params = not ct.FlowSignal.check(params) - - if has_lazy_value_func and has_params: - unit_system = unit_systems['BlenderUnits'] - return ct.ArrayFlow( - values=lazy_value_func.func_jax( - *params.scaled_func_args(unit_system), - **params.scaled_func_kwargs(unit_system), - ), + return expr.compose_within( + operation.jax_func, + supports_jax=True, ) return ct.FlowSignal.FlowPending #################### - # - Compute Auxiliary: Info / Params + # - FlowKind.Info|Params #################### @events.computes_output_socket( 'Expr', kind=ct.FlowKind.Info, - props={'active_socket_set', 'operation'}, + props={'operation'}, input_sockets={'Expr'}, input_socket_kinds={'Expr': ct.FlowKind.Info}, )