fix: map node up to current standards
parent
532f0246b5
commit
39747e2d68
|
@ -32,14 +32,14 @@ from ... import base, events
|
||||||
|
|
||||||
log = logger.get(__name__)
|
log = logger.get(__name__)
|
||||||
|
|
||||||
X_COMPLEX = sp.Symbol('x', complex=True)
|
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Operation Enum
|
||||||
|
####################
|
||||||
class MapOperation(enum.StrEnum):
|
class MapOperation(enum.StrEnum):
|
||||||
"""Valid operations for the `MapMathNode`.
|
"""Valid operations for the `MapMathNode`.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
UserExpr: Use a user-provided mapping expression.
|
|
||||||
Real: Compute the real part of the input.
|
Real: Compute the real part of the input.
|
||||||
Imag: Compute the imaginary part of the input.
|
Imag: Compute the imaginary part of the input.
|
||||||
Abs: Compute the absolute value of the input.
|
Abs: Compute the absolute value of the input.
|
||||||
|
@ -67,8 +67,6 @@ class MapOperation(enum.StrEnum):
|
||||||
Svd: Compute the SVD-factorized matrices of the input matrix.
|
Svd: Compute the SVD-factorized matrices of the input matrix.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# By User Expression
|
|
||||||
UserExpr = enum.auto()
|
|
||||||
# By Number
|
# By Number
|
||||||
Real = enum.auto()
|
Real = enum.auto()
|
||||||
Imag = enum.auto()
|
Imag = enum.auto()
|
||||||
|
@ -106,8 +104,6 @@ class MapOperation(enum.StrEnum):
|
||||||
def to_name(value: typ.Self) -> str:
|
def to_name(value: typ.Self) -> str:
|
||||||
MO = MapOperation
|
MO = MapOperation
|
||||||
return {
|
return {
|
||||||
# By User Expression
|
|
||||||
MO.UserExpr: '*',
|
|
||||||
# By Number
|
# By Number
|
||||||
MO.Real: 'ℝ(v)',
|
MO.Real: 'ℝ(v)',
|
||||||
MO.Imag: 'Im(v)',
|
MO.Imag: 'Im(v)',
|
||||||
|
@ -159,108 +155,140 @@ class MapOperation(enum.StrEnum):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]:
|
def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]:
|
||||||
MO = MapOperation
|
MO = MapOperation
|
||||||
if shape == 'noshape':
|
|
||||||
return []
|
|
||||||
|
|
||||||
# By Number
|
match shape:
|
||||||
if shape is None:
|
case 'noshape':
|
||||||
return [
|
return []
|
||||||
MO.Real,
|
|
||||||
MO.Imag,
|
|
||||||
MO.Abs,
|
|
||||||
MO.Sq,
|
|
||||||
MO.Sqrt,
|
|
||||||
MO.InvSqrt,
|
|
||||||
MO.Cos,
|
|
||||||
MO.Sin,
|
|
||||||
MO.Tan,
|
|
||||||
MO.Acos,
|
|
||||||
MO.Asin,
|
|
||||||
MO.Atan,
|
|
||||||
MO.Sinc,
|
|
||||||
]
|
|
||||||
|
|
||||||
# By Vector
|
# By Number
|
||||||
if len(shape) == 1:
|
case None:
|
||||||
return [
|
return [
|
||||||
MO.Norm2,
|
MO.Real,
|
||||||
]
|
MO.Imag,
|
||||||
# By Matrix
|
MO.Abs,
|
||||||
if len(shape) == 2:
|
MO.Sq,
|
||||||
return [
|
MO.Sqrt,
|
||||||
MO.Det,
|
MO.InvSqrt,
|
||||||
MO.Cond,
|
MO.Cos,
|
||||||
MO.NormFro,
|
MO.Sin,
|
||||||
MO.Rank,
|
MO.Tan,
|
||||||
MO.Diag,
|
MO.Acos,
|
||||||
MO.EigVals,
|
MO.Asin,
|
||||||
MO.SvdVals,
|
MO.Atan,
|
||||||
MO.Inv,
|
MO.Sinc,
|
||||||
MO.Tra,
|
]
|
||||||
MO.Qr,
|
|
||||||
MO.Chol,
|
match len(shape):
|
||||||
MO.Svd,
|
# By Vector
|
||||||
]
|
case 1:
|
||||||
|
return [
|
||||||
|
MO.Norm2,
|
||||||
|
]
|
||||||
|
# By Matrix
|
||||||
|
case 2:
|
||||||
|
return [
|
||||||
|
MO.Det,
|
||||||
|
MO.Cond,
|
||||||
|
MO.NormFro,
|
||||||
|
MO.Rank,
|
||||||
|
MO.Diag,
|
||||||
|
MO.EigVals,
|
||||||
|
MO.SvdVals,
|
||||||
|
MO.Inv,
|
||||||
|
MO.Tra,
|
||||||
|
MO.Qr,
|
||||||
|
MO.Chol,
|
||||||
|
MO.Svd,
|
||||||
|
]
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def jax_func_expr(self, user_expr_func: ct.LazyValueFuncFlow):
|
####################
|
||||||
|
# - Function Properties
|
||||||
|
####################
|
||||||
|
@property
|
||||||
|
def sp_func(self):
|
||||||
MO = MapOperation
|
MO = MapOperation
|
||||||
if self == MO.UserExpr:
|
return {
|
||||||
return lambda data: user_expr_func.func(data)
|
# By Number
|
||||||
|
MO.Real: lambda expr: sp.re(expr),
|
||||||
msg = "Can't generate JAX function for user-provided expression when MapOperation is not `UserExpr`"
|
MO.Imag: lambda expr: sp.im(expr),
|
||||||
raise ValueError(msg)
|
MO.Abs: lambda expr: sp.Abs(expr),
|
||||||
|
MO.Sq: lambda expr: expr**2,
|
||||||
|
MO.Sqrt: lambda expr: sp.sqrt(expr),
|
||||||
|
MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr),
|
||||||
|
MO.Cos: lambda expr: sp.cos(expr),
|
||||||
|
MO.Sin: lambda expr: sp.sin(expr),
|
||||||
|
MO.Tan: lambda expr: sp.tan(expr),
|
||||||
|
MO.Acos: lambda expr: sp.acos(expr),
|
||||||
|
MO.Asin: lambda expr: sp.asin(expr),
|
||||||
|
MO.Atan: lambda expr: sp.atan(expr),
|
||||||
|
MO.Sinc: lambda expr: sp.sinc(expr),
|
||||||
|
# By Vector
|
||||||
|
# Vector -> Number
|
||||||
|
MO.Norm2: lambda expr: sp.sqrt(expr.T @ expr),
|
||||||
|
# By Matrix
|
||||||
|
# Matrix -> Number
|
||||||
|
MO.Det: lambda expr: sp.det(expr),
|
||||||
|
MO.Cond: lambda expr: expr.condition_number(),
|
||||||
|
MO.NormFro: lambda expr: expr.norm(ord='fro'),
|
||||||
|
MO.Rank: lambda expr: expr.rank(),
|
||||||
|
# Matrix -> Vec
|
||||||
|
MO.Diag: lambda expr: expr.diagonal(),
|
||||||
|
MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())),
|
||||||
|
MO.SvdVals: lambda expr: expr.singular_values(),
|
||||||
|
# Matrix -> Matrix
|
||||||
|
MO.Inv: lambda expr: expr.inv(),
|
||||||
|
MO.Tra: lambda expr: expr.T,
|
||||||
|
# Matrix -> Matrices
|
||||||
|
MO.Qr: lambda expr: expr.QRdecomposition(),
|
||||||
|
MO.Chol: lambda expr: expr.cholesky(),
|
||||||
|
MO.Svd: lambda expr: expr.singular_value_decomposition(),
|
||||||
|
}[self]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def jax_func(self):
|
def jax_func(self):
|
||||||
MO = MapOperation
|
MO = MapOperation
|
||||||
if self == MO.UserExpr:
|
|
||||||
msg = "Can't generate JAX function without user-provided expression when MapOperation is `UserExpr`"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
# By Number
|
# By Number
|
||||||
MO.Real: lambda data: jnp.real(data),
|
MO.Real: lambda expr: jnp.real(expr),
|
||||||
MO.Imag: lambda data: jnp.imag(data),
|
MO.Imag: lambda expr: jnp.imag(expr),
|
||||||
MO.Abs: lambda data: jnp.abs(data),
|
MO.Abs: lambda expr: jnp.abs(expr),
|
||||||
MO.Sq: lambda data: jnp.square(data),
|
MO.Sq: lambda expr: jnp.square(expr),
|
||||||
MO.Sqrt: lambda data: jnp.sqrt(data),
|
MO.Sqrt: lambda expr: jnp.sqrt(expr),
|
||||||
MO.InvSqrt: lambda data: 1 / jnp.sqrt(data),
|
MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr),
|
||||||
MO.Cos: lambda data: jnp.cos(data),
|
MO.Cos: lambda expr: jnp.cos(expr),
|
||||||
MO.Sin: lambda data: jnp.sin(data),
|
MO.Sin: lambda expr: jnp.sin(expr),
|
||||||
MO.Tan: lambda data: jnp.tan(data),
|
MO.Tan: lambda expr: jnp.tan(expr),
|
||||||
MO.Acos: lambda data: jnp.acos(data),
|
MO.Acos: lambda expr: jnp.acos(expr),
|
||||||
MO.Asin: lambda data: jnp.asin(data),
|
MO.Asin: lambda expr: jnp.asin(expr),
|
||||||
MO.Atan: lambda data: jnp.atan(data),
|
MO.Atan: lambda expr: jnp.atan(expr),
|
||||||
MO.Sinc: lambda data: jnp.sinc(data),
|
MO.Sinc: lambda expr: jnp.sinc(expr),
|
||||||
# By Vector
|
# By Vector
|
||||||
# Vector -> Number
|
# Vector -> Number
|
||||||
MO.Norm2: lambda data: jnp.linalg.norm(data, ord=2, axis=-1),
|
MO.Norm2: lambda expr: jnp.linalg.norm(expr, ord=2, axis=-1),
|
||||||
# By Matrix
|
# By Matrix
|
||||||
# Matrix -> Number
|
# Matrix -> Number
|
||||||
MO.Det: lambda data: jnp.linalg.det(data),
|
MO.Det: lambda expr: jnp.linalg.det(expr),
|
||||||
MO.Cond: lambda data: jnp.linalg.cond(data),
|
MO.Cond: lambda expr: jnp.linalg.cond(expr),
|
||||||
MO.NormFro: lambda data: jnp.linalg.matrix_norm(data, ord='fro'),
|
MO.NormFro: lambda expr: jnp.linalg.matrix_norm(expr, ord='fro'),
|
||||||
MO.Rank: lambda data: jnp.linalg.matrix_rank(data),
|
MO.Rank: lambda expr: jnp.linalg.matrix_rank(expr),
|
||||||
# Matrix -> Vec
|
# Matrix -> Vec
|
||||||
MO.Diag: lambda data: jnp.diagonal(data, axis1=-2, axis2=-1),
|
MO.Diag: lambda expr: jnp.diagonal(expr, axis1=-2, axis2=-1),
|
||||||
MO.EigVals: lambda data: jnp.linalg.eigvals(data),
|
MO.EigVals: lambda expr: jnp.linalg.eigvals(expr),
|
||||||
MO.SvdVals: lambda data: jnp.linalg.svdvals(data),
|
MO.SvdVals: lambda expr: jnp.linalg.svdvals(expr),
|
||||||
# Matrix -> Matrix
|
# Matrix -> Matrix
|
||||||
MO.Inv: lambda data: jnp.linalg.inv(data),
|
MO.Inv: lambda expr: jnp.linalg.inv(expr),
|
||||||
MO.Tra: lambda data: jnp.matrix_transpose(data),
|
MO.Tra: lambda expr: jnp.matrix_transpose(expr),
|
||||||
# Matrix -> Matrices
|
# Matrix -> Matrices
|
||||||
MO.Qr: lambda data: jnp.linalg.qr(data),
|
MO.Qr: lambda expr: jnp.linalg.qr(expr),
|
||||||
MO.Chol: lambda data: jnp.linalg.cholesky(data),
|
MO.Chol: lambda expr: jnp.linalg.cholesky(expr),
|
||||||
MO.Svd: lambda data: jnp.linalg.svd(data),
|
MO.Svd: lambda expr: jnp.linalg.svd(expr),
|
||||||
}[self]
|
}[self]
|
||||||
|
|
||||||
def transform_info(self, info: ct.InfoFlow):
|
def transform_info(self, info: ct.InfoFlow):
|
||||||
MO = MapOperation
|
MO = MapOperation
|
||||||
return {
|
return {
|
||||||
# By User Expression
|
|
||||||
MO.UserExpr: '*',
|
|
||||||
# By Number
|
# By Number
|
||||||
MO.Real: lambda: info.set_output_mathtype(spux.MathType.Real),
|
MO.Real: lambda: info.set_output_mathtype(spux.MathType.Real),
|
||||||
MO.Imag: lambda: info.set_output_mathtype(spux.MathType.Real),
|
MO.Imag: lambda: info.set_output_mathtype(spux.MathType.Real),
|
||||||
|
@ -294,9 +322,12 @@ class MapOperation(enum.StrEnum):
|
||||||
),
|
),
|
||||||
## TODO: Matrix -> Vec
|
## TODO: Matrix -> Vec
|
||||||
## TODO: Matrix -> Matrices
|
## TODO: Matrix -> Matrices
|
||||||
}.get(self, info)()
|
}.get(self, lambda: info)()
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Node
|
||||||
|
####################
|
||||||
class MapMathNode(base.MaxwellSimNode):
|
class MapMathNode(base.MaxwellSimNode):
|
||||||
r"""Applies a function by-structure to the data.
|
r"""Applies a function by-structure to the data.
|
||||||
|
|
||||||
|
@ -390,11 +421,23 @@ class MapMathNode(base.MaxwellSimNode):
|
||||||
####################
|
####################
|
||||||
# - Properties
|
# - Properties
|
||||||
####################
|
####################
|
||||||
operation: MapOperation = bl_cache.BLField(
|
@events.on_value_changed(
|
||||||
enum_cb=lambda self, _: self.search_operations()
|
socket_name={'Expr'},
|
||||||
|
input_sockets={'Expr'},
|
||||||
|
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||||
|
input_sockets_optional={'Expr': True},
|
||||||
)
|
)
|
||||||
|
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
|
||||||
|
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
|
||||||
|
|
||||||
@property
|
info_pending = ct.FlowSignal.check_single(
|
||||||
|
input_sockets['Expr'], ct.FlowSignal.FlowPending
|
||||||
|
)
|
||||||
|
|
||||||
|
if has_info and not info_pending:
|
||||||
|
self.expr_output_shape = bl_cache.Signal.InvalidateCache
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property()
|
||||||
def expr_output_shape(self) -> ct.InfoFlow | None:
|
def expr_output_shape(self) -> ct.InfoFlow | None:
|
||||||
info = self._compute_input('Expr', kind=ct.FlowKind.Info)
|
info = self._compute_input('Expr', kind=ct.FlowKind.Info)
|
||||||
has_info = not ct.FlowSignal.check(info)
|
has_info = not ct.FlowSignal.check(info)
|
||||||
|
@ -403,6 +446,11 @@ class MapMathNode(base.MaxwellSimNode):
|
||||||
|
|
||||||
return 'noshape'
|
return 'noshape'
|
||||||
|
|
||||||
|
operation: MapOperation = bl_cache.BLField(
|
||||||
|
enum_cb=lambda self, _: self.search_operations(),
|
||||||
|
cb_depends_on={'expr_output_shape'},
|
||||||
|
)
|
||||||
|
|
||||||
def search_operations(self) -> list[ct.BLEnumElement]:
|
def search_operations(self) -> list[ct.BLEnumElement]:
|
||||||
if self.expr_output_shape != 'noshape':
|
if self.expr_output_shape != 'noshape':
|
||||||
return [
|
return [
|
||||||
|
@ -426,110 +474,58 @@ class MapMathNode(base.MaxwellSimNode):
|
||||||
layout.prop(self, self.blfields['operation'], text='')
|
layout.prop(self, self.blfields['operation'], text='')
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Events
|
# - FlowKind.Value|LazyValueFunc
|
||||||
####################
|
####################
|
||||||
@events.on_value_changed(
|
@events.computes_output_socket(
|
||||||
# Trigger
|
'Expr',
|
||||||
socket_name='Expr',
|
kind=ct.FlowKind.Value,
|
||||||
run_on_init=True,
|
|
||||||
)
|
|
||||||
def on_input_changed(self):
|
|
||||||
if self.operation not in MapOperation.by_element_shape(self.expr_output_shape):
|
|
||||||
self.operation = bl_cache.Signal.ResetEnumItems
|
|
||||||
|
|
||||||
@events.on_value_changed(
|
|
||||||
# Trigger
|
|
||||||
prop_name={'operation'},
|
|
||||||
run_on_init=True,
|
|
||||||
# Loaded
|
|
||||||
props={'operation'},
|
props={'operation'},
|
||||||
|
input_sockets={'Expr'},
|
||||||
)
|
)
|
||||||
def on_operation_changed(self, props: dict) -> None:
|
def compute_value(self, props, input_sockets) -> ct.ValueFlow | ct.FlowSignal:
|
||||||
operation = props['operation']
|
operation = props['operation']
|
||||||
|
expr = input_sockets['Expr']
|
||||||
|
|
||||||
# UserExpr: Add/Remove Input Socket
|
has_expr_value = not ct.FlowSignal.check(expr)
|
||||||
if operation == MapOperation.UserExpr:
|
|
||||||
current_bl_socket = self.loose_input_sockets.get('Mapper')
|
|
||||||
if current_bl_socket is None:
|
|
||||||
self.loose_input_sockets = {
|
|
||||||
'Mapper': sockets.ExprSocketDef(
|
|
||||||
symbols={X_COMPLEX},
|
|
||||||
default_value=X_COMPLEX,
|
|
||||||
mathtype=spux.MathType.Complex,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
elif self.loose_input_sockets:
|
# Compute Sympy Function
|
||||||
self.loose_input_sockets = {}
|
## -> The operation enum directly provides the appropriate function.
|
||||||
|
if has_expr_value and operation is not None:
|
||||||
|
operation.sp_func(expr)
|
||||||
|
|
||||||
|
return ct.Flowsignal.FlowPending
|
||||||
|
|
||||||
####################
|
|
||||||
# - Compute: LazyValueFunc / Array
|
|
||||||
####################
|
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Expr',
|
'Expr',
|
||||||
kind=ct.FlowKind.LazyValueFunc,
|
kind=ct.FlowKind.LazyValueFunc,
|
||||||
props={'operation'},
|
props={'operation'},
|
||||||
input_sockets={'Expr', 'Mapper'},
|
input_sockets={'Expr'},
|
||||||
input_socket_kinds={
|
input_socket_kinds={
|
||||||
'Expr': ct.FlowKind.LazyValueFunc,
|
'Expr': ct.FlowKind.LazyValueFunc,
|
||||||
'Mapper': ct.FlowKind.LazyValueFunc,
|
|
||||||
},
|
},
|
||||||
input_sockets_optional={'Mapper': True},
|
|
||||||
)
|
)
|
||||||
def compute_data(self, props: dict, input_sockets: dict):
|
def compute_func(
|
||||||
|
self, props, input_sockets
|
||||||
|
) -> ct.LazyValueFuncFlow | ct.FlowSignal:
|
||||||
operation = props['operation']
|
operation = props['operation']
|
||||||
expr = input_sockets['Expr']
|
expr = input_sockets['Expr']
|
||||||
mapper = input_sockets['Mapper']
|
|
||||||
|
|
||||||
has_expr = not ct.FlowSignal.check(expr)
|
has_expr = not ct.FlowSignal.check(expr)
|
||||||
has_mapper = not ct.FlowSignal.check(mapper)
|
|
||||||
|
|
||||||
if has_expr and operation is not None:
|
if has_expr and operation is not None:
|
||||||
if not has_mapper:
|
return expr.compose_within(
|
||||||
return expr.compose_within(
|
operation.jax_func,
|
||||||
operation.jax_func,
|
supports_jax=True,
|
||||||
supports_jax=True,
|
|
||||||
)
|
|
||||||
if operation == MapOperation.UserExpr and has_mapper:
|
|
||||||
return expr.compose_within(
|
|
||||||
operation.jax_func_expr(user_expr_func=mapper),
|
|
||||||
supports_jax=True,
|
|
||||||
)
|
|
||||||
return ct.FlowSignal.FlowPending
|
|
||||||
|
|
||||||
@events.computes_output_socket(
|
|
||||||
'Expr',
|
|
||||||
kind=ct.FlowKind.Array,
|
|
||||||
output_sockets={'Expr'},
|
|
||||||
output_socket_kinds={
|
|
||||||
'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
|
|
||||||
},
|
|
||||||
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
|
|
||||||
)
|
|
||||||
def compute_array(self, output_sockets, unit_systems) -> ct.ArrayFlow:
|
|
||||||
lazy_value_func = output_sockets['Expr'][ct.FlowKind.LazyValueFunc]
|
|
||||||
params = output_sockets['Expr'][ct.FlowKind.Params]
|
|
||||||
|
|
||||||
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
|
|
||||||
has_params = not ct.FlowSignal.check(params)
|
|
||||||
|
|
||||||
if has_lazy_value_func and has_params:
|
|
||||||
unit_system = unit_systems['BlenderUnits']
|
|
||||||
return ct.ArrayFlow(
|
|
||||||
values=lazy_value_func.func_jax(
|
|
||||||
*params.scaled_func_args(unit_system),
|
|
||||||
**params.scaled_func_kwargs(unit_system),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Compute Auxiliary: Info / Params
|
# - FlowKind.Info|Params
|
||||||
####################
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Expr',
|
'Expr',
|
||||||
kind=ct.FlowKind.Info,
|
kind=ct.FlowKind.Info,
|
||||||
props={'active_socket_set', 'operation'},
|
props={'operation'},
|
||||||
input_sockets={'Expr'},
|
input_sockets={'Expr'},
|
||||||
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue