refactor: Huge simplifications from ExprSocket

main
Sofus Albert Høgsbro Rose 2024-04-30 18:42:46 +02:00
parent 80d7b21c34
commit e330b9a451
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
43 changed files with 2919 additions and 2463 deletions

10
TODO.md
View File

@ -527,9 +527,17 @@ Reported:
- (SOLVED) <https://projects.blender.org/blender/blender/issues/119664> - (SOLVED) <https://projects.blender.org/blender/blender/issues/119664>
Unreported: Unreported:
- Units are unruly, and are entirely useless when it comes to going small like this.
- The `__mp_main__` bug. - The `__mp_main__` bug.
- Animated properties within custom node trees don't update with the frame. See: <https://projects.blender.org/blender/blender/issues/66392> - Animated properties within custom node trees don't update with the frame. See: <https://projects.blender.org/blender/blender/issues/66392>
- Can't update `items` using `id_propertie_ui` of `EnumProperty` - Can't update `items` using `id_properties_ui` of `EnumProperty`. Maybe less a bug than an annoyance.
- **Matrix Display Bug**: The data given to matrix properties is entirely ignored in the UI; the data is flattened, then left-to-right, up-to-down, the data is inserted. It's neither row-major nor column-major - it's completely flat.
- Though, if one wanted row-major (**as is consistent with `mathutils.Matrix`**), one would be disappointed - the UI prints the matrix property column-major
- Trying to set the matrix property with a `mathutils.Matrix` is even stranger - firstly, the size of the `mathutils.Matrix` must be transposed with respect to the property size (again the col/row major mismatch). But secondly, even when accounting for the col/row major mismatch, the values of a ex. 2x3 (row-major) matrix (written to with a 3x2 matrix with same flattened sequence) is written in a very strange order:
- Write `mathutils.Matrix` `[[0,1], [2,3], [4,10]]`: Results in (UI displayed row-major) `[[0,3], [4,1], [3,5]]`
- **Workaround (write)**: Simply flatten the 2D array, re-shape by `[cols,rows]`. The UI will display as the original array. `myarray.flatten().reshape([cols,rows])`.
- **Workaround (read)**: `np.array([[el1 for el1 in el0] for el0 in BLENDER_OBJ.matrix_prop]).flatten().reshape([rows,cols])`. Simply flatten the property read 2D array and re-shape by `[rows,cols]`. Mind that data type out is equal to data type in.
- Also, for bool matrices, `toggle=True` has no effect. `alignment='CENTER'` also doesn't align the checkboxes in their cells.
## Tidy3D bugs ## Tidy3D bugs
Unreported: Unreported:

View File

@ -118,11 +118,9 @@ quartodoc:
- subtitle: "`bl_maxwell.utils`" - subtitle: "`bl_maxwell.utils`"
desc: Utilities wo/shared global state. desc: Utilities wo/shared global state.
contents: contents:
- utils.analyze_geonodes
- utils.blender_type_enum - utils.blender_type_enum
- utils.extra_sympy_units - utils.extra_sympy_units
- utils.logger - utils.logger
- utils.pydantic_sympy
- subtitle: "`bl_maxwell.services`" - subtitle: "`bl_maxwell.services`"
desc: Utilities w/shared global state. desc: Utilities w/shared global state.
@ -172,7 +170,6 @@ quartodoc:
- socket_colors - socket_colors
- bl_socket_types - bl_socket_types
- bl_socket_desc_map - bl_socket_desc_map
- socket_units
- unit_systems - unit_systems

View File

@ -5,11 +5,9 @@ Attributes:
BL_SOCKET_4D_TYPE_PREFIXES: Blender socket prefixes which indicate that the Blender socket has four values. BL_SOCKET_4D_TYPE_PREFIXES: Blender socket prefixes which indicate that the Blender socket has four values.
""" """
import functools
import typing as typ import typing as typ
import bpy import bpy
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger as _logger from blender_maxwell.utils import logger as _logger
@ -19,238 +17,54 @@ from . import sockets
log = _logger.get(__name__) log = _logger.get(__name__)
BLSocketType: typ.TypeAlias = str ## A Blender-Defined Socket Type
BLSocketValue: typ.TypeAlias = typ.Any ## A Blender Socket Value
BLSocketSize: typ.TypeAlias = int
DescType: typ.TypeAlias = str
Unit: typ.TypeAlias = typ.Any ## Type of a valid unit
## TODO: Move this kind of thing to contracts
#################### ####################
# - BL Socket Size Parser # - Blender -> Socket Def(s)
#################### ####################
BL_SOCKET_3D_TYPE_PREFIXES = { def socket_def_from_bl_isocket(
'NodeSocketVector', bl_isocket: bpy.types.NodeTreeInterfaceSocket,
'NodeSocketRotation', ) -> sockets.base.SocketDef | None:
"""Deduces and constructs an appropriate SocketDef to match the given `bl_interface_socket`."""
blsck_info = ct.BLSocketType.info_from_bl_isocket(bl_isocket)
if blsck_info.has_support and not blsck_info.is_preview:
# Map Expr Socket
## -> Accounts for any combo of shape/MathType/PhysicalType.
if blsck_info.socket_type == ct.SocketType.Expr:
return sockets.ExprSocketDef(
shape=blsck_info.size.shape,
mathtype=blsck_info.mathtype,
physical_type=blsck_info.physical_type,
default_unit=ct.UNITS_BLENDER[blsck_info.physical_type],
default_value=blsck_info.default_value,
)
## TODO: Explicitly map default to other supported SocketDef constructors
return sockets.SOCKET_DEFS[blsck_info.socket_type]()
return None
def sockets_from_geonodes(
geonodes: bpy.types.GeometryNodeTree,
) -> dict[ct.SocketName, sockets.base.SocketDef]:
"""Deduces and constructs appropriate SocketDefs to match all input sockets to the given GeoNodes tree."""
raw_socket_defs = {
socket_name: socket_def_from_bl_isocket(bl_isocket)
for socket_name, bl_isocket in geonodes.interface.items_tree.items()
} }
BL_SOCKET_4D_TYPE_PREFIXES = { return {
'NodeSocketColor', socket_name: socket_def
for socket_name, socket_def in raw_socket_defs.items()
if socket_def is not None
} }
@functools.lru_cache(maxsize=4096) ## TODO: Make it fast, it's in a hot loop...
def _size_from_bl_socket( def info_from_geonodes(
description: str, geonodes: bpy.types.GeometryNodeTree,
bl_socket_type: BLSocketType, ) -> dict[ct.SocketName, ct.BLSocketInfo]:
): """Deduces and constructs appropriate SocketDefs to match all input sockets to the given GeoNodes tree."""
"""Parses the number of elements contained in a Blender interface socket. return {
socket_name: ct.BLSocketType.info_from_bl_isocket(bl_isocket)
Since there are no 2D sockets in Blender, the user can specify "2D" in the Blender socket's description to "promise" that only the first two values will be used. for socket_name, bl_isocket in geonodes.interface.items_tree.items()
When this is done, the third value is just never altered by the addon. }
A hard-coded set of NodeSocket<Type> prefixes are used to determine which interface sockets are, in fact, 3D.
- For 3D sockets, a hard-coded list of Blender node socket types is used.
- Else, it is a 1D socket type.
"""
if description.startswith('2D'):
return 2
if any(
bl_socket_type.startswith(bl_socket_3d_type_prefix)
for bl_socket_3d_type_prefix in BL_SOCKET_3D_TYPE_PREFIXES
):
return 3
if any(
bl_socket_type.startswith(bl_socket_4d_type_prefix)
for bl_socket_4d_type_prefix in BL_SOCKET_4D_TYPE_PREFIXES
):
return 4
return 1
####################
# - BL Socket Type / Unit Parser
####################
@functools.lru_cache(maxsize=4096)
def _socket_type_from_bl_socket(
description: str,
bl_socket_type: BLSocketType,
) -> ct.SocketType:
"""Parse a Blender socket for a matching BLMaxwell socket type, relying on both the Blender socket type and user-generated hints in the description.
Arguments:
description: The description from Blender socket, aka. `bl_socket.description`.
bl_socket_type: The Blender socket type, aka. `bl_socket.socket_type`.
Returns:
The type of a MaxwellSimSocket that corresponds to the Blender socket.
"""
size = _size_from_bl_socket(description, bl_socket_type)
# Determine Socket Type Directly
## The naive mapping from BL socket -> Maxwell socket may be good enough.
if (
direct_socket_type := ct.BL_SOCKET_DIRECT_TYPE_MAP.get((bl_socket_type, size))
) is None:
msg = "Blender interface socket has no mapping among 'MaxwellSimSocket's."
raise ValueError(msg)
# (No Description) Return Direct Socket Type
if ct.BL_SOCKET_DESCR_ANNOT_STRING not in description:
return direct_socket_type
# Parse Description for Socket Type
## The "2D" token is special; don't include it if it's there.
descr_params = description.split(ct.BL_SOCKET_DESCR_ANNOT_STRING)[0]
directive = (
_tokens[0] if (_tokens := descr_params.split(' '))[0] != '2D' else _tokens[1]
)
if directive == 'Preview':
return direct_socket_type ## TODO: Preview element handling
if (
socket_type := ct.BL_SOCKET_DESCR_TYPE_MAP.get(
(directive, bl_socket_type, size)
)
) is None:
msg = f'Socket description "{(directive, bl_socket_type, size)}" doesn\'t map to a socket type + unit'
raise ValueError(msg)
return socket_type
####################
# - BL Socket Interface Definition
####################
@functools.lru_cache(maxsize=4096)
def _socket_def_from_bl_socket(
description: str,
bl_socket_type: BLSocketType,
) -> ct.SocketType:
return sockets.SOCKET_DEFS[_socket_type_from_bl_socket(description, bl_socket_type)]
def socket_def_from_bl_socket(
bl_interface_socket: bpy.types.NodeTreeInterfaceSocket,
) -> sockets.base.SocketDef:
"""Computes an appropriate (no-arg) SocketDef from the given `bl_interface_socket`, by parsing it."""
return _socket_def_from_bl_socket(
bl_interface_socket.description, bl_interface_socket.bl_socket_idname
)
####################
# - Extract Default Interface Socket Value
####################
def _read_bl_socket_default_value(
description: str,
bl_socket_type: BLSocketType,
bl_socket_value: BLSocketValue,
unit_system: dict | None = None,
allow_unit_not_in_unit_system: bool = False,
) -> typ.Any:
# Parse the BL Socket Type and Value
## The 'lambda' delays construction until size is determined.
socket_type = _socket_type_from_bl_socket(description, bl_socket_type)
parsed_socket_value = {
1: lambda: bl_socket_value,
2: lambda: sp.Matrix(tuple(bl_socket_value)[:2]),
3: lambda: sp.Matrix(tuple(bl_socket_value)),
4: lambda: sp.Matrix(tuple(bl_socket_value)),
}[_size_from_bl_socket(description, bl_socket_type)]()
# Add Unit-System Unit to Parsed
## Use the matching socket type to lookup the unit in the unit system.
if unit_system is not None:
if (unit := unit_system.get(socket_type)) is None:
if allow_unit_not_in_unit_system:
return parsed_socket_value
msg = f'Unit system does not provide a unit for {socket_type}'
raise RuntimeError(msg)
return parsed_socket_value * unit
return parsed_socket_value
def read_bl_socket_default_value(
bl_interface_socket: bpy.types.NodeTreeInterfaceSocket,
unit_system: dict | None = None,
allow_unit_not_in_unit_system: bool = False,
) -> typ.Any:
"""Reads the `default_value` of a Blender socket, guaranteeing a well-formed value consistent with the passed unit system.
Arguments:
bl_interface_socket: The Blender interface socket to analyze for description, socket type, and default value.
unit_system: The mapping from BLMaxwell SocketType to corresponding unit, used to apply the appropriate unit to the output.
Returns:
The parsed, well-formed version of `bl_socket.default_value`, of the appropriate form and unit.
"""
return _read_bl_socket_default_value(
bl_interface_socket.description,
bl_interface_socket.bl_socket_idname,
bl_interface_socket.default_value,
unit_system=unit_system,
allow_unit_not_in_unit_system=allow_unit_not_in_unit_system,
)
def _writable_bl_socket_value(
description: str,
bl_socket_type: BLSocketType,
value: typ.Any,
unit_system: dict | None = None,
allow_unit_not_in_unit_system: bool = False,
) -> typ.Any:
socket_type = _socket_type_from_bl_socket(description, bl_socket_type)
# Retrieve Unit-System Unit
if unit_system is not None:
if (unit := unit_system.get(socket_type)) is None:
if allow_unit_not_in_unit_system:
_bl_socket_value = value
else:
msg = f'Unit system does not provide a unit for {socket_type}'
raise RuntimeError(msg)
else:
_bl_socket_value = spux.scale_to_unit(value, unit)
else:
_bl_socket_value = value
# Compute Blender Socket Value
if isinstance(_bl_socket_value, sp.Basic | sp.MatrixBase):
bl_socket_value = spux.sympy_to_python(_bl_socket_value)
else:
bl_socket_value = _bl_socket_value
if _size_from_bl_socket(description, bl_socket_type) == 2: # noqa: PLR2004
bl_socket_value = bl_socket_value[:2]
return bl_socket_value
def writable_bl_socket_value(
bl_interface_socket: bpy.types.NodeTreeInterfaceSocket,
value: typ.Any,
unit_system: dict | None = None,
allow_unit_not_in_unit_system: bool = False,
) -> typ.Any:
"""Processes a value to be ready-to-write to a Blender socket.
Arguments:
bl_interface_socket: The Blender interface socket to analyze
value: The value to prepare for writing to the given Blender socket.
unit_system: The mapping from BLMaxwell SocketType to corresponding unit, used to scale the value to the the appropriate unit.
Returns:
A value corresponding to the input, which is guaranteed to be compatible with the Blender socket (incl. via a GeoNodes modifier), as well as correctly scaled with respect to the given unit system.
"""
return _writable_bl_socket_value(
bl_interface_socket.description,
bl_interface_socket.bl_socket_idname,
value,
unit_system=unit_system,
allow_unit_not_in_unit_system=allow_unit_not_in_unit_system,
)

View File

@ -22,8 +22,7 @@ from blender_maxwell.contracts import (
addon, addon,
) )
from .bl_socket_desc_map import BL_SOCKET_DESCR_ANNOT_STRING, BL_SOCKET_DESCR_TYPE_MAP from .bl_socket_types import BLSocketInfo, BLSocketType
from .bl_socket_types import BL_SOCKET_DIRECT_TYPE_MAP
from .category_labels import NODE_CAT_LABELS from .category_labels import NODE_CAT_LABELS
from .category_types import NodeCategory from .category_types import NodeCategory
from .flow_events import FlowEvent from .flow_events import FlowEvent
@ -43,7 +42,6 @@ from .mobj_types import ManagedObjType
from .node_types import NodeType from .node_types import NodeType
from .socket_colors import SOCKET_COLORS from .socket_colors import SOCKET_COLORS
from .socket_types import SocketType from .socket_types import SocketType
from .socket_units import SOCKET_UNITS, unit_to_socket_type
from .tree_types import TreeType from .tree_types import TreeType
from .unit_systems import UNITS_BLENDER, UNITS_TIDY3D from .unit_systems import UNITS_BLENDER, UNITS_TIDY3D
@ -72,15 +70,12 @@ __all__ = [
'Icon', 'Icon',
'TreeType', 'TreeType',
'SocketType', 'SocketType',
'SOCKET_UNITS',
'unit_to_socket_type',
'SOCKET_COLORS', 'SOCKET_COLORS',
'SOCKET_SHAPES', 'SOCKET_SHAPES',
'UNITS_BLENDER', 'UNITS_BLENDER',
'UNITS_TIDY3D', 'UNITS_TIDY3D',
'BL_SOCKET_DESCR_TYPE_MAP', 'BLSocketInfo',
'BL_SOCKET_DIRECT_TYPE_MAP', 'BLSocketType',
'BL_SOCKET_DESCR_ANNOT_STRING',
'NodeType', 'NodeType',
'NodeCategory', 'NodeCategory',
'NODE_CAT_LABELS', 'NODE_CAT_LABELS',

View File

@ -1,62 +0,0 @@
from .socket_types import SocketType as ST
BL_SOCKET_DESCR_ANNOT_STRING = ':: '
BL_SOCKET_DESCR_TYPE_MAP = {
('Time', 'NodeSocketFloat', 1): ST.PhysicalTime,
('Angle', 'NodeSocketFloat', 1): ST.PhysicalAngle,
('SolidAngle', 'NodeSocketFloat', 1): ST.PhysicalSolidAngle,
('Rotation', 'NodeSocketVector', 2): ST.PhysicalRot2D,
('Rotation', 'NodeSocketVector', 3): ST.PhysicalRot3D,
('Freq', 'NodeSocketFloat', 1): ST.PhysicalFreq,
('AngFreq', 'NodeSocketFloat', 1): ST.PhysicalAngFreq,
## Cartesian
('Length', 'NodeSocketFloat', 1): ST.PhysicalLength,
('Area', 'NodeSocketFloat', 1): ST.PhysicalArea,
('Volume', 'NodeSocketFloat', 1): ST.PhysicalVolume,
('Disp', 'NodeSocketVector', 2): ST.PhysicalDisp2D,
('Disp', 'NodeSocketVector', 3): ST.PhysicalDisp3D,
('Point', 'NodeSocketFloat', 1): ST.PhysicalPoint1D,
('Point', 'NodeSocketVector', 2): ST.PhysicalPoint2D,
('Point', 'NodeSocketVector', 3): ST.PhysicalPoint3D,
('Size', 'NodeSocketVector', 2): ST.PhysicalSize2D,
('Size', 'NodeSocketVector', 3): ST.PhysicalSize3D,
## Mechanical
('Mass', 'NodeSocketFloat', 1): ST.PhysicalMass,
('Speed', 'NodeSocketFloat', 1): ST.PhysicalSpeed,
('Vel', 'NodeSocketVector', 2): ST.PhysicalVel2D,
('Vel', 'NodeSocketVector', 3): ST.PhysicalVel3D,
('Accel', 'NodeSocketFloat', 1): ST.PhysicalAccelScalar,
('Accel', 'NodeSocketVector', 2): ST.PhysicalAccel2D,
('Accel', 'NodeSocketVector', 3): ST.PhysicalAccel3D,
('Force', 'NodeSocketFloat', 1): ST.PhysicalForceScalar,
('Force', 'NodeSocketVector', 2): ST.PhysicalForce2D,
('Force', 'NodeSocketVector', 3): ST.PhysicalForce3D,
('Pressure', 'NodeSocketFloat', 1): ST.PhysicalPressure,
## Energetic
('Energy', 'NodeSocketFloat', 1): ST.PhysicalEnergy,
('Power', 'NodeSocketFloat', 1): ST.PhysicalPower,
('Temp', 'NodeSocketFloat', 1): ST.PhysicalTemp,
## ELectrodynamical
('Curr', 'NodeSocketFloat', 1): ST.PhysicalCurr,
('CurrDens', 'NodeSocketVector', 2): ST.PhysicalCurrDens2D,
('CurrDens', 'NodeSocketVector', 3): ST.PhysicalCurrDens3D,
('Charge', 'NodeSocketFloat', 1): ST.PhysicalCharge,
('Voltage', 'NodeSocketFloat', 1): ST.PhysicalVoltage,
('Capacitance', 'NodeSocketFloat', 1): ST.PhysicalCapacitance,
('Resistance', 'NodeSocketFloat', 1): ST.PhysicalResistance,
('Conductance', 'NodeSocketFloat', 1): ST.PhysicalConductance,
('MagFlux', 'NodeSocketFloat', 1): ST.PhysicalMagFlux,
('MagFluxDens', 'NodeSocketFloat', 1): ST.PhysicalMagFluxDens,
('Inductance', 'NodeSocketFloat', 1): ST.PhysicalInductance,
('EField', 'NodeSocketFloat', 2): ST.PhysicalEField3D,
('EField', 'NodeSocketFloat', 3): ST.PhysicalEField2D,
('HField', 'NodeSocketFloat', 2): ST.PhysicalHField3D,
('HField', 'NodeSocketFloat', 3): ST.PhysicalHField2D,
## Luminal
('LumIntensity', 'NodeSocketFloat', 1): ST.PhysicalLumIntensity,
('LumFlux', 'NodeSocketFloat', 1): ST.PhysicalLumFlux,
('Illuminance', 'NodeSocketFloat', 1): ST.PhysicalIlluminance,
## Optical
('PolJones', 'NodeSocketFloat', 2): ST.PhysicalPolJones,
('Pol', 'NodeSocketFloat', 4): ST.PhysicalPol,
}

View File

@ -1,39 +1,312 @@
from .socket_types import SocketType as ST import dataclasses
import enum
import typing as typ
BL_SOCKET_DIRECT_TYPE_MAP = { import bpy
import sympy as sp
from blender_maxwell.utils import blender_type_enum
from blender_maxwell.utils import extra_sympy_units as spux
from .socket_types import SocketType
BL_SOCKET_DESCR_ANNOT_STRING = ':: '
@dataclasses.dataclass(kw_only=True, frozen=True)
class BLSocketInfo:
has_support: bool
is_preview: bool
socket_type: SocketType | None
size: spux.NumberSize1D | None
physical_type: spux.PhysicalType | None
default_value: spux.ScalarUnitlessRealExpr
bl_isocket_identifier: spux.ScalarUnitlessRealExpr
@blender_type_enum.prefix_values_with('NodeSocket')
class BLSocketType(enum.StrEnum):
Virtual = 'Virtual'
# Blender # Blender
('NodeSocketCollection', 1): ST.BlenderCollection, Image = 'Image'
('NodeSocketImage', 1): ST.BlenderImage, Shader = 'Shader'
('NodeSocketObject', 1): ST.BlenderObject, Material = 'Material'
('NodeSocketMaterial', 1): ST.BlenderMaterial, Geometry = 'Material'
Object = 'Object'
Collection = 'Collection'
# Basic # Basic
('NodeSocketString', 1): ST.String, Bool = 'Bool'
('NodeSocketBool', 1): ST.Bool, String = 'String'
Menu = 'Menu'
# Float # Float
('NodeSocketFloat', 1): ST.RealNumber, Float = 'Float'
# ("NodeSocketFloatAngle", 1): ST.PhysicalAngle, FloatUnsigned = 'FloatUnsigned'
# ("NodeSocketFloatDistance", 1): ST.PhysicalLength, FloatAngle = 'FloatAngle'
('NodeSocketFloatFactor', 1): ST.RealNumber, FloatDistance = 'FloatDistance'
('NodeSocketFloatPercentage', 1): ST.RealNumber, FloatFactor = 'FloatFactor'
# ("NodeSocketFloatTime", 1): ST.PhysicalTime, FloatPercentage = 'FloatPercentage'
# ("NodeSocketFloatTimeAbsolute", 1): ST.PhysicalTime, FloatTime = 'FloatTime'
FloatTimeAbsolute = 'FloatTimeAbsolute'
# Int # Int
('NodeSocketInt', 1): ST.IntegerNumber, Int = 'Int'
('NodeSocketIntFactor', 1): ST.IntegerNumber, IntFactor = 'IntFactor'
('NodeSocketIntPercentage', 1): ST.IntegerNumber, IntPercentage = 'IntPercentage'
('NodeSocketIntUnsigned', 1): ST.IntegerNumber, IntUnsigned = 'IntUnsigned'
# Vector
Color = 'Color'
Rotation = 'Rotation'
Vector = 'Vector'
VectorAcceleration = 'Acceleration'
VectorDirection = 'Direction'
VectorEuler = 'Euler'
VectorTranslation = 'Translation'
VectorVelocity = 'Velocity'
VectorXYZ = 'XYZ'
@staticmethod
def from_bl_isocket(
bl_isocket: bpy.types.NodeTreeInterfaceSocket,
) -> typ.Self:
return BLSocketType[bl_isocket.bl_socket_idname]
@staticmethod
def info_from_bl_isocket(
bl_isocket: bpy.types.NodeTreeInterfaceSocket,
) -> typ.Self:
return BLSocketType.from_bl_isocket(bl_isocket).parse(
bl_isocket.default_value, bl_isocket.description, bl_isocket.identifier
)
####################
# - Direct Properties
####################
@property
def has_support(self) -> bool:
BLST = BLSocketType
return {
BLST.Virtual: False,
BLST.Geometry: False,
BLST.Shader: False,
BLST.FloatUnsigned: False,
BLST.IntUnsigned: False,
}.get(self, True)
@property
def socket_type(self) -> SocketType | None:
"""Deduce `SocketType` corresponding to the Blender socket type.
**The socket type alone is not enough** to actually create the socket.
To declare a socket in the addon, an appropriate `socket.SocketDef` must be constructed in a manner that respects contextual information.
Returns:
The corresponding socket type, if the addon has support for mapping this Blender socket.
For sockets with support, the fallback is always `SocketType.Expr`.
Support is determined using `self.has_support`
"""
if not self.has_support:
return None
BLST = BLSocketType
ST = SocketType
return {
# Blender
# Basic
BLST.Bool: ST.String,
# Float
# Array-Like # Array-Like
('NodeSocketColor', 3): ST.Color, BLST.Color: ST.Color,
('NodeSocketRotation', 2): ST.PhysicalRot2D, }.get(self, ST.Expr)
('NodeSocketVector', 2): ST.Real2DVector,
('NodeSocketVector', 3): ST.Real3DVector, @property
# ("NodeSocketVectorAcceleration", 2): ST.PhysicalAccel2D, def mathtype(self) -> spux.MathType | None:
# ("NodeSocketVectorAcceleration", 3): ST.PhysicalAccel3D, """Deduce `spux.MathType` corresponding to the Blender socket type.
# ("NodeSocketVectorDirection", 2): ST.Real2DVectorDir,
# ("NodeSocketVectorDirection", 3): ST.Real3DVectorDir, **The socket type alone is not enough** to actually create the socket.
('NodeSocketVectorEuler', 2): ST.PhysicalRot2D, To declare a socket in the addon, an appropriate `socket.SocketDef` must be constructed in a manner that respects contextual information.
('NodeSocketVectorEuler', 3): ST.PhysicalRot3D,
# ("NodeSocketVectorTranslation", 3): ST.PhysicalDisp3D, Returns:
# ("NodeSocketVectorVelocity", 3): ST.PhysicalVel3D, The corresponding socket type, if the addon has support for mapping this Blender socket.
# ("NodeSocketVectorXYZ", 3): ST.PhysicalPoint3D, For sockets with support, the fallback is always `SocketType.Expr`.
}
Support is determined using `self.has_support`
"""
if not self.has_support:
return None
BLST = BLSocketType
MT = spux.MathType
return {
# Blender
# Basic
BLST.Bool: MT.Bool,
# Float
BLST.Float: MT.Real,
BLST.FloatAngle: MT.Real,
BLST.FloatDistance: MT.Real,
BLST.FloatFactor: MT.Real,
BLST.FloatPercentage: MT.Real,
BLST.FloatTime: MT.Real,
BLST.FloatTimeAbsolute: MT.Real,
# Int
BLST.Int: MT.Integer,
BLST.IntFactor: MT.Integer,
BLST.IntPercentage: MT.Integer,
# Vector
BLST.Color: MT.Real,
BLST.Rotation: MT.Real,
BLST.Vector: MT.Real,
BLST.VectorAcceleration: MT.Real,
BLST.VectorDirection: MT.Real,
BLST.VectorEuler: MT.Real,
BLST.VectorTranslation: MT.Real,
BLST.VectorVelocity: MT.Real,
BLST.VectorXYZ: MT.Real,
}.get(self)
@property
def size(
self,
) -> (
typ.Literal[
spux.NumberSize1D.Scalar, spux.NumberSize1D.Vec3, spux.NumberSize1D.Vec4
]
| None
):
"""Deduce the internal size of the Blender socket's data.
Returns:
A `spux.NumberSize1D` reflecting the internal data representation.
Always falls back to `{spux.NumberSize1D.Scalar}`.
"""
if not self.has_support:
return None
S = spux.NumberSize1D
BLST = BLSocketType
return {
BLST.Color: S.Vec4,
BLST.Rotation: S.Vec3,
BLST.VectorAcceleration: S.Vec3,
BLST.VectorDirection: S.Vec3,
BLST.VectorEuler: S.Vec3,
BLST.VectorTranslation: S.Vec3,
BLST.VectorVelocity: S.Vec3,
BLST.VectorXYZ: S.Vec3,
}.get(self, {S.Scalar})
@property
def unambiguous_physical_type(self) -> spux.PhysicalType | None:
"""Deduce an **unambiguous** physical type from the Blender socket, if any.
Blender does have its own unit systems, which leads to some Blender socket subtypes having an obvious choice of physical unit dimension (ex. `BLSocketType.FloatTime`).
In such cases, the `spux.PhysicalType` that matches the Blender socket can be uniquely determined.
When a phsyical type cannot be immediately determined in this way, other mechanisms must be used to deduce what to do.
Returns:
A physical type corresponding to the Blender socket, **if exactly one** can be determined with no ambiguity - else `None`.
If more than one physical type might apply, `None`.
"""
if not self.has_support:
return None
P = spux.PhysicalType
BLST = BLSocketType
{
BLST.FloatAngle: P.Angle,
BLST.FloatDistance: P.Length,
BLST.FloatTime: P.Time,
BLST.FloatTimeAbsolute: P.Time, ## What's the difference?
BLST.VectorAcceleration: P.Accel,
## BLST.VectorDirection: Directions are unitless (within cartesian)
BLST.VectorEuler: P.Angle,
BLST.VectorTranslation: P.Length,
BLST.VectorVelocity: P.Vel,
BLST.VectorXYZ: P.Length,
}.get(self)
@property
def valid_sizes(self) -> set[spux.NumberSize1D] | None:
"""Deduce which sizes it would be valid to interpret a Blender socket as having.
This property's purpose is merely to present a set of options that _are valid_.
Whether an a size _is truly usable_, can only be determined with contextual information, wherein certain decisions can be made:
- **2D vs. 3D**: In general, Blender's vector socket types are **only** 3D, and we cannot ever _really_ have a 2D vector.
If one chooses interpret ex. `BLSocketType.Vector` as 2D, one might do so by pretending the third coordinate doesn't exist.
But **this is a subjective decision**, which always has to align with the logic on the other side of the Blender socket.
- **Colors**: Generally, `BLSocketType.Color` is always 4D, representing `RGBA` (with alpha channel).
However, we often don't care about alpha; therefore, we might choose to "just" push a 3D `RGB` vector.
Again, **this is a subjective decision** which requires one to make a decision about alpha, for example "alpha is always 1".
- **Scalars**: We can generally always interpret a scalar as a vector, using well-defined "broadcasting".
Returns:
The set of `spux.NumberSize1D`s, which it would be safe to interpret the Blender socket as having.
Always falls back to `{spux.NumberSize1D.Scalar}`.
"""
if not self.has_support:
return None
S = spux.NumberSize1D
BLST = BLSocketType
return {
BLST.Color: {S.Scalar, S.Vec3, S.Vec4},
BLST.Rotation: {S.Vec2, S.Vec3},
BLST.VectorAcceleration: {S.Scalar, S.Vec2, S.Vec3},
BLST.VectorDirection: {S.Scalar, S.Vec2, S.Vec3},
BLST.VectorEuler: {S.Vec2, S.Vec3},
BLST.VectorTranslation: {S.Scalar, S.Vec2, S.Vec3},
BLST.VectorVelocity: {S.Scalar, S.Vec2, S.Vec3},
BLST.VectorXYZ: {S.Scalar, S.Vec2, S.Vec3},
}.get(self, {S.Scalar})
####################
# - Parsed Properties
####################
def parse(
self, bl_default_value: typ.Any, description: str, bl_isocket_identifier: str
) -> BLSocketInfo:
# Parse the Description
## TODO: Some kind of error on invalid parse if there is also no unambiguous physical type
descr_params = description.split(BL_SOCKET_DESCR_ANNOT_STRING)[0]
directive = (
_tokens[0]
if (_tokens := descr_params.split(' '))[0] != '2D'
else _tokens[1]
)
## Interpret the Description Parse
parsed_physical_type = getattr(spux.PhysicalType, directive, None)
physical_type = (
self.unambiguous_physical_type
if self.unambiguous_physical_type is not None
else parsed_physical_type
)
# Parse the Default Value
if self.mathtype is not None:
if self.size == spux.NumberSize1D.Scalar:
default_value = self.mathtype.pytype(bl_default_value)
elif description.startswith('2D'):
default_value = sp.Matrix(tuple(bl_default_value)[:2])
else:
default_value = sp.Matrix(tuple(bl_default_value))
else:
default_value = bl_default_value
# Return Parsed Socket Information
## -> Combining directly known and parsed knowledge.
## -> Should contain everything needed to match the Blender socket.
return BLSocketInfo(
has_support=self.has_support,
is_preview=(directive == 'Preview'),
socket_type=self.socket_type,
size=self.size,
physical_type=physical_type,
default_value=default_value,
bl_isocket_identifier=bl_isocket_identifier,
)

View File

@ -316,8 +316,12 @@ class LazyValueFuncFlow:
""" """
func: LazyFunction func: LazyFunction
func_args: list[type] = dataclasses.field(default_factory=list) func_args: list[spux.MathType | spux.PhysicalType] = dataclasses.field(
func_kwargs: dict[str, type] = dataclasses.field(default_factory=dict) default_factory=list
)
func_kwargs: dict[str, spux.MathType | spux.PhysicalType] = dataclasses.field(
default_factory=dict
)
supports_jax: bool = False supports_jax: bool = False
supports_numba: bool = False supports_numba: bool = False
@ -432,9 +436,7 @@ class LazyArrayRangeFlow:
unit: The unit of the generated array values unit: The unit of the generated array values
int_symbols: Set of integer-valued variables from which `start` and/or `stop` are determined. symbols: Set of variables from which `start` and/or `stop` are determined.
real_symbols: Set of real-valued variables from which `start` and/or `stop` are determined.
complex_symbols: Set of complex-valued variables from which `start` and/or `stop` are determined.
""" """
start: spux.ScalarUnitlessComplexExpr start: spux.ScalarUnitlessComplexExpr
@ -444,12 +446,10 @@ class LazyArrayRangeFlow:
unit: spux.Unit | None = None unit: spux.Unit | None = None
int_symbols: set[spux.IntSymbol] = frozenset() symbols: frozenset[spux.IntSymbol] = frozenset()
real_symbols: set[spux.RealSymbol] = frozenset()
complex_symbols: set[spux.ComplexSymbol] = frozenset()
@functools.cached_property @functools.cached_property
def symbols(self) -> list[sp.Symbol]: def sorted_symbols(self) -> list[sp.Symbol]:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name. """Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
The order is guaranteed to be **deterministic**. The order is guaranteed to be **deterministic**.
@ -457,10 +457,7 @@ class LazyArrayRangeFlow:
Returns: Returns:
All symbols valid for use in the expression. All symbols valid for use in the expression.
""" """
return sorted( return sorted(self.symbols, key=lambda sym: sym.name)
self.int_symbols | self.real_symbols | self.complex_symbols,
key=lambda sym: sym.name,
)
@functools.cached_property @functools.cached_property
def mathtype(self) -> spux.MathType: def mathtype(self) -> spux.MathType:
@ -508,9 +505,7 @@ class LazyArrayRangeFlow:
steps=self.steps, steps=self.steps,
scaling=self.scaling, scaling=self.scaling,
unit=corrected_unit, unit=corrected_unit,
int_symbols=self.int_symbols, symbols=self.symbols,
real_symbols=self.real_symbols,
complex_symbols=self.complex_symbols,
) )
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"' msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
@ -530,15 +525,12 @@ class LazyArrayRangeFlow:
""" """
if self.unit is not None: if self.unit is not None:
return LazyArrayRangeFlow( return LazyArrayRangeFlow(
start=spu.convert_to(self.start, unit), start=spu.scale_to_unit(self.start * self.unit, unit),
stop=spu.convert_to(self.stop, unit), stop=spu.scale_to_unit(self.stop * self.unit, unit),
steps=self.steps, steps=self.steps,
scaling=self.scaling, scaling=self.scaling,
unit=unit, unit=unit,
symbols=self.symbols, symbols=self.symbols,
int_symbols=self.int_symbols,
real_symbols=self.real_symbols,
complex_symbols=self.complex_symbols,
) )
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}' msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
@ -549,7 +541,7 @@ class LazyArrayRangeFlow:
#################### ####################
def rescale_bounds( def rescale_bounds(
self, self,
scaler: typ.Callable[ rescale_func: typ.Callable[
[spux.ScalarUnitlessComplexExpr], spux.ScalarUnitlessComplexExpr [spux.ScalarUnitlessComplexExpr], spux.ScalarUnitlessComplexExpr
], ],
reverse: bool = False, reverse: bool = False,
@ -570,18 +562,12 @@ class LazyArrayRangeFlow:
A rescaled `LazyArrayRangeFlow`. A rescaled `LazyArrayRangeFlow`.
""" """
return LazyArrayRangeFlow( return LazyArrayRangeFlow(
start=spu.convert_to( start=rescale_func(self.start if not reverse else self.stop),
scaler(self.start if not reverse else self.stop), self.unit stop=rescale_func(self.stop if not reverse else self.start),
),
stop=spu.convert_to(
scaler(self.stop if not reverse else self.start), self.unit
),
steps=self.steps, steps=self.steps,
scaling=self.scaling, scaling=self.scaling,
unit=self.unit, unit=self.unit,
int_symbols=self.int_symbols, symbols=self.symbols,
real_symbols=self.real_symbols,
complex_symbols=self.complex_symbols,
) )
#################### ####################
@ -650,9 +636,7 @@ class LazyArrayRangeFlow:
""" """
return LazyValueFuncFlow( return LazyValueFuncFlow(
func=self.as_func, func=self.as_func,
func_args=[ func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols],
(sym.name, spux.sympy_to_python_type(sym)) for sym in self.symbols
],
supports_jax=True, supports_jax=True,
) )
@ -709,13 +693,38 @@ class LazyArrayRangeFlow:
#################### ####################
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class ParamsFlow: class ParamsFlow:
func_args: list[typ.Any] = dataclasses.field(default_factory=list) func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
func_kwargs: dict[str, typ.Any] = dataclasses.field(default_factory=dict) func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
####################
# - Scaled Func Args
####################
def scaled_func_args(self, unit_system: spux.UnitSystem):
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
return [
spux.convert_to_unit_system(func_arg, unit_system, use_jax_array=True)
for func_arg in self.func_args
]
def scaled_func_kwargs(self, unit_system: spux.UnitSystem):
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
return {
arg_name: spux.convert_to_unit_system(arg, unit_system, use_jax_array=True)
for arg_name, arg in self.func_args
}
####################
# - Operations
####################
def __or__( def __or__(
self, self,
other: typ.Self, other: typ.Self,
): ):
"""Combine two function parameter lists, such that the LHS will be concatenated with the RHS.
Just like its neighbor in `LazyValueFunc`, this effectively combines two functions with unique parameters.
The next composed function will receive a tuple of two arrays, instead of just one, allowing binary operations to occur.
"""
return ParamsFlow( return ParamsFlow(
func_args=self.func_args + other.func_args, func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs, func_kwargs=self.func_kwargs | other.func_kwargs,
@ -723,10 +732,8 @@ class ParamsFlow:
def compose_within( def compose_within(
self, self,
enclosing_func_args: list[tuple[type]] = (), enclosing_func_args: list[spux.SympyExpr] = (),
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}), enclosing_func_kwargs: dict[str, spux.SympyExpr] = MappingProxyType({}),
enclosing_func_arg_units: dict[str, type] = MappingProxyType({}),
enclosing_func_kwarg_units: dict[str, type] = MappingProxyType({}),
) -> typ.Self: ) -> typ.Self:
return ParamsFlow( return ParamsFlow(
func_args=self.func_args + list(enclosing_func_args), func_args=self.func_args + list(enclosing_func_args),
@ -745,6 +752,8 @@ class InfoFlow:
default_factory=dict default_factory=dict
) ## TODO: Rename to dim_idxs ) ## TODO: Rename to dim_idxs
## TODO: Add PhysicalType
@functools.cached_property @functools.cached_property
def dim_lens(self) -> dict[str, int]: def dim_lens(self) -> dict[str, int]:
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()} return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
@ -769,12 +778,14 @@ class InfoFlow:
] ]
# Output Information # Output Information
## TODO: Add PhysicalType
output_name: str = dataclasses.field(default_factory=list) output_name: str = dataclasses.field(default_factory=list)
output_shape: tuple[int, ...] | None = dataclasses.field(default=None) output_shape: tuple[int, ...] | None = dataclasses.field(default=None)
output_mathtype: spux.MathType = dataclasses.field() output_mathtype: spux.MathType = dataclasses.field()
output_unit: spux.Unit | None = dataclasses.field() output_unit: spux.Unit | None = dataclasses.field()
# Pinned Dimension Information # Pinned Dimension Information
## TODO: Add PhysicalType
pinned_dim_names: list[str] = dataclasses.field(default_factory=list) pinned_dim_names: list[str] = dataclasses.field(default_factory=list)
pinned_dim_values: dict[str, float | complex] = dataclasses.field( pinned_dim_values: dict[str, float | complex] = dataclasses.field(
default_factory=dict default_factory=dict

View File

@ -4,42 +4,13 @@ from .socket_types import SocketType as ST
SOCKET_COLORS = { SOCKET_COLORS = {
# Basic # Basic
ST.Any: (0.9, 0.9, 0.9, 1.0), # Light Grey ST.Any: (0.9, 0.9, 0.9, 1.0), # Light Grey
ST.Data: (0.8, 0.8, 0.8, 1.0), # Light Grey
ST.Bool: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey ST.Bool: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey
ST.String: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey ST.String: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey
ST.FilePath: (0.6, 0.6, 0.6, 1.0), # Medium Grey ST.FilePath: (0.6, 0.6, 0.6, 1.0), # Medium Grey
ST.Expr: (0.5, 0.5, 0.5, 1.0), # Medium Grey ST.Expr: (0.5, 0.5, 0.5, 1.0), # Medium Grey
# Number
ST.IntegerNumber: (0.5, 0.5, 1.0, 1.0), # Light Blue
ST.RationalNumber: (0.4, 0.4, 0.9, 1.0), # Medium Light Blue
ST.RealNumber: (0.3, 0.3, 0.8, 1.0), # Medium Blue
ST.ComplexNumber: (0.2, 0.2, 0.7, 1.0), # Dark Blue
# Vector
ST.Integer2DVector: (0.5, 1.0, 0.5, 1.0), # Light Green
ST.Real2DVector: (0.5, 1.0, 0.5, 1.0), # Light Green
ST.Complex2DVector: (0.4, 0.9, 0.4, 1.0), # Medium Light Green
ST.Integer3DVector: (0.3, 0.8, 0.3, 1.0), # Medium Green
ST.Real3DVector: (0.3, 0.8, 0.3, 1.0), # Medium Green
ST.Complex3DVector: (0.2, 0.7, 0.2, 1.0), # Dark Green
# Physical # Physical
ST.PhysicalUnitSystem: (1.0, 0.5, 0.5, 1.0), # Light Red ST.PhysicalUnitSystem: (1.0, 0.5, 0.5, 1.0), # Light Red
ST.PhysicalTime: (1.0, 0.5, 0.5, 1.0), # Light Red
ST.PhysicalAngle: (0.9, 0.45, 0.45, 1.0), # Medium Light Red
ST.PhysicalLength: (0.8, 0.4, 0.4, 1.0), # Medium Red
ST.PhysicalArea: (0.7, 0.35, 0.35, 1.0), # Medium Dark Red
ST.PhysicalVolume: (0.6, 0.3, 0.3, 1.0), # Dark Red
ST.PhysicalPoint2D: (0.7, 0.35, 0.35, 1.0), # Medium Dark Red
ST.PhysicalPoint3D: (0.6, 0.3, 0.3, 1.0), # Dark Red
ST.PhysicalSize2D: (0.7, 0.35, 0.35, 1.0), # Medium Dark Red
ST.PhysicalSize3D: (0.6, 0.3, 0.3, 1.0), # Dark Red
ST.PhysicalMass: (0.9, 0.6, 0.4, 1.0), # Light Orange
ST.PhysicalSpeed: (0.8, 0.55, 0.35, 1.0), # Medium Light Orange
ST.PhysicalAccelScalar: (0.7, 0.5, 0.3, 1.0), # Medium Orange
ST.PhysicalForceScalar: (0.6, 0.45, 0.25, 1.0), # Medium Dark Orange
ST.PhysicalAccel3D: (0.7, 0.5, 0.3, 1.0), # Medium Orange
ST.PhysicalForce3D: (0.6, 0.45, 0.25, 1.0), # Medium Dark Orange
ST.PhysicalPol: (0.5, 0.4, 0.2, 1.0), # Dark Orange ST.PhysicalPol: (0.5, 0.4, 0.2, 1.0), # Dark Orange
ST.PhysicalFreq: (1.0, 0.7, 0.5, 1.0), # Light Peach
# Blender # Blender
ST.BlenderMaterial: (0.8, 0.6, 1.0, 1.0), # Lighter Purple ST.BlenderMaterial: (0.8, 0.6, 1.0, 1.0), # Lighter Purple
ST.BlenderObject: (0.7, 0.5, 1.0, 1.0), # Light Purple ST.BlenderObject: (0.7, 0.5, 1.0, 1.0), # Light Purple
@ -56,6 +27,7 @@ SOCKET_COLORS = {
ST.MaxwellBoundConds: (0.9, 0.8, 0.5, 1.0), # Light Gold ST.MaxwellBoundConds: (0.9, 0.8, 0.5, 1.0), # Light Gold
ST.MaxwellBoundCond: (0.8, 0.7, 0.45, 1.0), # Medium Light Gold ST.MaxwellBoundCond: (0.8, 0.7, 0.45, 1.0), # Medium Light Gold
ST.MaxwellMonitor: (0.7, 0.6, 0.4, 1.0), # Medium Gold ST.MaxwellMonitor: (0.7, 0.6, 0.4, 1.0), # Medium Gold
ST.MaxwellMonitorData: (0.7, 0.6, 0.4, 1.0), # Medium Gold
ST.MaxwellFDTDSim: (0.6, 0.5, 0.35, 1.0), # Medium Dark Gold ST.MaxwellFDTDSim: (0.6, 0.5, 0.35, 1.0), # Medium Dark Gold
ST.MaxwellFDTDSimData: (0.6, 0.5, 0.35, 1.0), # Medium Dark Gold ST.MaxwellFDTDSimData: (0.6, 0.5, 0.35, 1.0), # Medium Dark Gold
ST.MaxwellSimGrid: (0.5, 0.4, 0.3, 1.0), # Dark Gold ST.MaxwellSimGrid: (0.5, 0.4, 0.3, 1.0), # Dark Gold

View File

@ -5,31 +5,14 @@ from blender_maxwell.utils import blender_type_enum
@blender_type_enum.append_cls_name_to_values @blender_type_enum.append_cls_name_to_values
class SocketType(blender_type_enum.BlenderTypeEnum): class SocketType(blender_type_enum.BlenderTypeEnum):
Expr = enum.auto()
# Base # Base
Any = enum.auto() Any = enum.auto()
Data = enum.auto()
Bool = enum.auto() Bool = enum.auto()
String = enum.auto() String = enum.auto()
FilePath = enum.auto() FilePath = enum.auto()
Color = enum.auto() Color = enum.auto()
Expr = enum.auto()
# Number
IntegerNumber = enum.auto()
RationalNumber = enum.auto()
RealNumber = enum.auto()
ComplexNumber = enum.auto()
# Vector
Integer2DVector = enum.auto()
Real2DVector = enum.auto()
Real2DVectorDir = enum.auto()
Complex2DVector = enum.auto()
Integer3DVector = enum.auto()
Real3DVector = enum.auto()
Real3DVectorDir = enum.auto()
Complex3DVector = enum.auto()
# Blender # Blender
BlenderMaterial = enum.auto() BlenderMaterial = enum.auto()
@ -53,6 +36,7 @@ class SocketType(blender_type_enum.BlenderTypeEnum):
MaxwellStructure = enum.auto() MaxwellStructure = enum.auto()
MaxwellMonitor = enum.auto() MaxwellMonitor = enum.auto()
MaxwellMonitorData = enum.auto()
MaxwellFDTDSim = enum.auto() MaxwellFDTDSim = enum.auto()
MaxwellFDTDSimData = enum.auto() MaxwellFDTDSimData = enum.auto()
@ -66,76 +50,5 @@ class SocketType(blender_type_enum.BlenderTypeEnum):
# Physical # Physical
PhysicalUnitSystem = enum.auto() PhysicalUnitSystem = enum.auto()
PhysicalTime = enum.auto()
PhysicalAngle = enum.auto()
PhysicalSolidAngle = enum.auto()
PhysicalRot2D = enum.auto()
PhysicalRot3D = enum.auto()
PhysicalFreq = enum.auto()
PhysicalAngFreq = enum.auto()
## Cartesian
PhysicalLength = enum.auto()
PhysicalArea = enum.auto()
PhysicalVolume = enum.auto()
PhysicalDisp2D = enum.auto()
PhysicalDisp3D = enum.auto()
PhysicalPoint1D = enum.auto()
PhysicalPoint2D = enum.auto()
PhysicalPoint3D = enum.auto()
PhysicalSize2D = enum.auto()
PhysicalSize3D = enum.auto()
## Mechanical
PhysicalMass = enum.auto()
PhysicalSpeed = enum.auto()
PhysicalVel2D = enum.auto()
PhysicalVel3D = enum.auto()
PhysicalAccelScalar = enum.auto()
PhysicalAccel2D = enum.auto()
PhysicalAccel3D = enum.auto()
PhysicalForceScalar = enum.auto()
PhysicalForce2D = enum.auto()
PhysicalForce3D = enum.auto()
PhysicalPressure = enum.auto()
## Energetic
PhysicalEnergy = enum.auto()
PhysicalPower = enum.auto()
PhysicalTemp = enum.auto()
## Electrodynamical
PhysicalCurr = enum.auto()
PhysicalCurrDens2D = enum.auto()
PhysicalCurrDens3D = enum.auto()
PhysicalCharge = enum.auto()
PhysicalVoltage = enum.auto()
PhysicalCapacitance = enum.auto()
PhysicalResistance = enum.auto()
PhysicalConductance = enum.auto()
PhysicalMagFlux = enum.auto()
PhysicalMagFluxDens = enum.auto()
PhysicalInductance = enum.auto()
PhysicalEField2D = enum.auto()
PhysicalEField3D = enum.auto()
PhysicalHField2D = enum.auto()
PhysicalHField3D = enum.auto()
## Luminal
PhysicalLumIntensity = enum.auto()
PhysicalLumFlux = enum.auto()
PhysicalIlluminance = enum.auto()
## Optical ## Optical
PhysicalPolJones = enum.auto()
PhysicalPol = enum.auto() PhysicalPol = enum.auto()

View File

@ -1,287 +0,0 @@
import sympy.physics.units as spu
from blender_maxwell.utils import extra_sympy_units as spux
from .socket_types import SocketType as ST # noqa: N817
SOCKET_UNITS = {
ST.PhysicalTime: {
'default': 'PS',
'values': {
'FS': spux.femtosecond,
'PS': spu.picosecond,
'NS': spu.nanosecond,
'MS': spu.microsecond,
'MLSEC': spu.millisecond,
'SEC': spu.second,
'MIN': spu.minute,
'HOUR': spu.hour,
'DAY': spu.day,
},
},
ST.PhysicalAngle: {
'default': 'RADIAN',
'values': {
'RADIAN': spu.radian,
'DEGREE': spu.degree,
'STERAD': spu.steradian,
'ANGMIL': spu.angular_mil,
},
},
ST.PhysicalLength: {
'default': 'UM',
'values': {
'PM': spu.picometer,
'A': spu.angstrom,
'NM': spu.nanometer,
'UM': spu.micrometer,
'MM': spu.millimeter,
'CM': spu.centimeter,
'M': spu.meter,
'INCH': spu.inch,
'FOOT': spu.foot,
'YARD': spu.yard,
'MILE': spu.mile,
},
},
ST.PhysicalArea: {
'default': 'UM_SQ',
'values': {
'PM_SQ': spu.picometer**2,
'A_SQ': spu.angstrom**2,
'NM_SQ': spu.nanometer**2,
'UM_SQ': spu.micrometer**2,
'MM_SQ': spu.millimeter**2,
'CM_SQ': spu.centimeter**2,
'M_SQ': spu.meter**2,
'INCH_SQ': spu.inch**2,
'FOOT_SQ': spu.foot**2,
'YARD_SQ': spu.yard**2,
'MILE_SQ': spu.mile**2,
},
},
ST.PhysicalVolume: {
'default': 'UM_CB',
'values': {
'PM_CB': spu.picometer**3,
'A_CB': spu.angstrom**3,
'NM_CB': spu.nanometer**3,
'UM_CB': spu.micrometer**3,
'MM_CB': spu.millimeter**3,
'CM_CB': spu.centimeter**3,
'M_CB': spu.meter**3,
'ML': spu.milliliter,
'L': spu.liter,
'INCH_CB': spu.inch**3,
'FOOT_CB': spu.foot**3,
'YARD_CB': spu.yard**3,
'MILE_CB': spu.mile**3,
},
},
ST.PhysicalPoint2D: {
'default': 'UM',
'values': {
'PM': spu.picometer,
'A': spu.angstrom,
'NM': spu.nanometer,
'UM': spu.micrometer,
'MM': spu.millimeter,
'CM': spu.centimeter,
'M': spu.meter,
'INCH': spu.inch,
'FOOT': spu.foot,
'YARD': spu.yard,
'MILE': spu.mile,
},
},
ST.PhysicalPoint3D: {
'default': 'UM',
'values': {
'PM': spu.picometer,
'A': spu.angstrom,
'NM': spu.nanometer,
'UM': spu.micrometer,
'MM': spu.millimeter,
'CM': spu.centimeter,
'M': spu.meter,
'INCH': spu.inch,
'FOOT': spu.foot,
'YARD': spu.yard,
'MILE': spu.mile,
},
},
ST.PhysicalSize2D: {
'default': 'UM',
'values': {
'PM': spu.picometer,
'A': spu.angstrom,
'NM': spu.nanometer,
'UM': spu.micrometer,
'MM': spu.millimeter,
'CM': spu.centimeter,
'M': spu.meter,
'INCH': spu.inch,
'FOOT': spu.foot,
'YARD': spu.yard,
'MILE': spu.mile,
},
},
ST.PhysicalSize3D: {
'default': 'UM',
'values': {
'PM': spu.picometer,
'A': spu.angstrom,
'NM': spu.nanometer,
'UM': spu.micrometer,
'MM': spu.millimeter,
'CM': spu.centimeter,
'M': spu.meter,
'INCH': spu.inch,
'FOOT': spu.foot,
'YARD': spu.yard,
'MILE': spu.mile,
},
},
ST.PhysicalMass: {
'default': 'UG',
'values': {
'E_REST': spu.electron_rest_mass,
'DAL': spu.dalton,
'UG': spu.microgram,
'MG': spu.milligram,
'G': spu.gram,
'KG': spu.kilogram,
'TON': spu.metric_ton,
},
},
ST.PhysicalSpeed: {
'default': 'UM_S',
'values': {
'PM_S': spu.picometer / spu.second,
'NM_S': spu.nanometer / spu.second,
'UM_S': spu.micrometer / spu.second,
'MM_S': spu.millimeter / spu.second,
'M_S': spu.meter / spu.second,
'KM_S': spu.kilometer / spu.second,
'KM_H': spu.kilometer / spu.hour,
'FT_S': spu.feet / spu.second,
'MI_H': spu.mile / spu.hour,
},
},
ST.PhysicalAccelScalar: {
'default': 'UM_S_SQ',
'values': {
'PM_S_SQ': spu.picometer / spu.second**2,
'NM_S_SQ': spu.nanometer / spu.second**2,
'UM_S_SQ': spu.micrometer / spu.second**2,
'MM_S_SQ': spu.millimeter / spu.second**2,
'M_S_SQ': spu.meter / spu.second**2,
'KM_S_SQ': spu.kilometer / spu.second**2,
'FT_S_SQ': spu.feet / spu.second**2,
},
},
ST.PhysicalForceScalar: {
'default': 'UNEWT',
'values': {
'KG_M_S_SQ': spu.kg * spu.m / spu.second**2,
'NNEWT': spux.nanonewton,
'UNEWT': spux.micronewton,
'MNEWT': spux.millinewton,
'NEWT': spu.newton,
},
},
ST.PhysicalAccel3D: {
'default': 'UM_S_SQ',
'values': {
'PM_S_SQ': spu.picometer / spu.second**2,
'NM_S_SQ': spu.nanometer / spu.second**2,
'UM_S_SQ': spu.micrometer / spu.second**2,
'MM_S_SQ': spu.millimeter / spu.second**2,
'M_S_SQ': spu.meter / spu.second**2,
'KM_S_SQ': spu.kilometer / spu.second**2,
'FT_S_SQ': spu.feet / spu.second**2,
},
},
ST.PhysicalForce3D: {
'default': 'UNEWT',
'values': {
'KG_M_S_SQ': spu.kg * spu.m / spu.second**2,
'NNEWT': spux.nanonewton,
'UNEWT': spux.micronewton,
'MNEWT': spux.millinewton,
'NEWT': spu.newton,
},
},
ST.PhysicalFreq: {
'default': 'THZ',
'values': {
'HZ': spu.hertz,
'KHZ': spux.kilohertz,
'MHZ': spux.megahertz,
'GHZ': spux.gigahertz,
'THZ': spux.terahertz,
'PHZ': spux.petahertz,
'EHZ': spux.exahertz,
},
},
ST.PhysicalPol: {
'default': 'RADIAN',
'values': {
'RADIAN': spu.radian,
'DEGREE': spu.degree,
'STERAD': spu.steradian,
'ANGMIL': spu.angular_mil,
},
},
ST.MaxwellMedium: {
'default': 'NM',
'values': {
'PM': spu.picometer, ## c(vac) = wl*freq
'A': spu.angstrom,
'NM': spu.nanometer,
'UM': spu.micrometer,
'MM': spu.millimeter,
'CM': spu.centimeter,
'M': spu.meter,
},
},
ST.MaxwellMonitor: {
'default': 'NM',
'values': {
'PM': spu.picometer, ## c(vac) = wl*freq
'A': spu.angstrom,
'NM': spu.nanometer,
'UM': spu.micrometer,
'MM': spu.millimeter,
'CM': spu.centimeter,
'M': spu.meter,
},
},
}
def unit_to_socket_type(
unit: spux.Unit | None, fallback_mathtype: spux.MathType | None = None
) -> ST:
"""Returns a SocketType that accepts the given unit.
Only the unit-compatibility is taken into account; in the case of overlap, several the ordering of `SOCKET_UNITS` determines which is returned.
This isn't super clean, but it's good enough for our needs right now.
Returns:
**The first `SocketType` in `SOCKET_UNITS`, which contains the given unit as a valid possibility.
"""
if unit is None and fallback_mathtype is not None:
return {
spux.MathType.Integer: ST.IntegerNumber,
spux.MathType.Rational: ST.RationalNumber,
spux.MathType.Real: ST.RealNumber,
spux.MathType.Complex: ST.ComplexNumber,
}[fallback_mathtype]
for socket_type, _units in SOCKET_UNITS.items():
if unit in _units['values'].values():
return socket_type
msg = f"Unit {unit} doesn't have an obvious SocketType."
raise ValueError(msg)

View File

@ -1,59 +1,69 @@
"""Specifies unit systems for use in the node tree.
Attributes:
UNITS_BLENDER: A unit system that serves as a reasonable default for the a 3D workspace that interprets the results of electromagnetic simulations.
**NOTE**: The addon _specifically_ neglects to respect Blender's builtin units.
In testing, Blender's system was found to be extremely brittle when "going small" like this; in particular, `picosecond`-order time units were impossible to specify functionally.
UNITS_TIDY3D: A unit system that aligns with Tidy3D's simulator.
See <https://docs.flexcompute.com/projects/tidy3d/en/latest/faq/docs/faq/What-are-the-units-used-in-the-simulation.html>
"""
import typing as typ import typing as typ
import sympy.physics.units as spu import sympy.physics.units as spu
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils.pydantic_sympy import SympyExpr
from .socket_types import SocketType as ST # noqa: N817
from .socket_units import SOCKET_UNITS
def _socket_units(socket_type):
return SOCKET_UNITS[socket_type]['values']
UnitSystem: typ.TypeAlias = dict[ST, SympyExpr]
#################### ####################
# - Unit Systems # - Unit Systems
#################### ####################
UNITS_BLENDER: UnitSystem = { _PT: typ.TypeAlias = spux.PhysicalType
ST.PhysicalTime: spu.picosecond, UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | {
ST.PhysicalAngle: spu.radian, # Global
ST.PhysicalLength: spu.micrometer, _PT.Time: spu.picosecond,
ST.PhysicalArea: spu.micrometer**2, _PT.Freq: spux.terahertz,
ST.PhysicalVolume: spu.micrometer**3, _PT.AngFreq: spu.radian * spux.terahertz,
ST.PhysicalPoint2D: spu.micrometer, # Cartesian
ST.PhysicalPoint3D: spu.micrometer, _PT.Length: spu.micrometer,
ST.PhysicalSize2D: spu.micrometer, _PT.Area: spu.micrometer**2,
ST.PhysicalSize3D: spu.micrometer, _PT.Volume: spu.micrometer**3,
ST.PhysicalMass: spu.microgram, # Energy
ST.PhysicalSpeed: spu.um / spu.second, _PT.PowerFlux: spu.watt / spu.um**2,
ST.PhysicalAccelScalar: spu.um / spu.second**2, # Electrodynamics
ST.PhysicalForceScalar: spux.micronewton, _PT.CurrentDensity: spu.ampere / spu.um**2,
ST.PhysicalAccel3D: spu.um / spu.second**2, _PT.Conductivity: spu.siemens / spu.um,
ST.PhysicalForce3D: spux.micronewton, _PT.PoyntingVector: spu.watt / spu.um**2,
ST.PhysicalFreq: spux.terahertz, _PT.EField: spu.volt / spu.um,
ST.PhysicalPol: spu.radian, _PT.HField: spu.ampere / spu.um,
# Mechanical
_PT.Vel: spu.um / spu.second,
_PT.Accel: spu.um / spu.second,
_PT.Mass: spu.microgram,
_PT.Force: spux.micronewton,
# Luminal
# Optics
_PT.PoyntingVector: spu.watt / spu.um**2,
} ## TODO: Load (dynamically?) from addon preferences } ## TODO: Load (dynamically?) from addon preferences
UNITS_TIDY3D: UnitSystem = { UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | {
## https://docs.flexcompute.com/projects/tidy3d/en/latest/faq/docs/faq/What-are-the-units-used-in-the-simulation.html # Global
ST.PhysicalTime: spu.second, # Cartesian
ST.PhysicalAngle: spu.radian, _PT.Length: spu.um,
ST.PhysicalLength: spu.micrometer, _PT.Area: spu.um**2,
ST.PhysicalArea: spu.micrometer**2, _PT.Volume: spu.um**3,
ST.PhysicalVolume: spu.micrometer**3, # Mechanical
ST.PhysicalPoint2D: spu.micrometer, _PT.Vel: spu.um / spu.second,
ST.PhysicalPoint3D: spu.micrometer, _PT.Accel: spu.um / spu.second,
ST.PhysicalSize2D: spu.micrometer, # Energy
ST.PhysicalSize3D: spu.micrometer, _PT.PowerFlux: spu.watt / spu.um**2,
ST.PhysicalMass: spu.microgram, # Electrodynamics
ST.PhysicalSpeed: spu.um / spu.second, _PT.CurrentDensity: spu.ampere / spu.um**2,
ST.PhysicalAccelScalar: spu.um / spu.second**2, _PT.Conductivity: spu.siemens / spu.um,
ST.PhysicalForceScalar: spux.micronewton, _PT.PoyntingVector: spu.watt / spu.um**2,
ST.PhysicalAccel3D: spu.um / spu.second**2, _PT.EField: spu.volt / spu.um,
ST.PhysicalForce3D: spux.micronewton, _PT.HField: spu.ampere / spu.um,
ST.PhysicalFreq: spu.hertz, # Luminal
ST.PhysicalPol: spu.radian, # Optics
_PT.PoyntingVector: spu.watt / spu.um**2,
## NOTE: w/o source normalization, EField/HField/Modal amps are * 1/Hz
} }

View File

@ -101,6 +101,10 @@ class ManagedBLMesh(base.ManagedObj):
#################### ####################
# - Methods # - Methods
#################### ####################
@property
def exists(self) -> bool:
return bpy.data.objects.get(self.name) is not None
def show_preview(self) -> None: def show_preview(self) -> None:
"""Moves the managed Blender object to the preview collection. """Moves the managed Blender object to the preview collection.

View File

@ -4,7 +4,8 @@ import typing as typ
import bpy import bpy
from blender_maxwell.utils import analyze_geonodes, logger from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
from .. import bl_socket_map from .. import bl_socket_map
from .. import contracts as ct from .. import contracts as ct
@ -25,13 +26,12 @@ class ModifierAttrsNODES(typ.TypedDict):
node_group: The GeoNodes group to use in the modifier. node_group: The GeoNodes group to use in the modifier.
unit_system: The unit system used by the GeoNodes output. unit_system: The unit system used by the GeoNodes output.
Generally, `ct.UNITS_BLENDER` is a good choice. Generally, `ct.UNITS_BLENDER` is a good choice.
inputs: Values to associate with each GeoNodes interface socket. inputs: Values to associate with each GeoNodes interface socket name.
Use `analyze_geonodes.interface(..., direc='INPUT')` to determine acceptable values.
""" """
node_group: bpy.types.GeometryNodeTree node_group: bpy.types.GeometryNodeTree
unit_system: UnitSystem unit_system: UnitSystem
inputs: dict[ct.BLNodeTreeInterfaceID, typ.Any] inputs: dict[ct.SocketName, typ.Any]
class ModifierAttrsARRAY(typ.TypedDict): class ModifierAttrsARRAY(typ.TypedDict):
@ -47,7 +47,7 @@ MODIFIER_NAMES = {
#################### ####################
# - Read Modifier Information # - Read Modifier
#################### ####################
def read_modifier(bl_modifier: bpy.types.Modifier) -> ModifierAttrs: def read_modifier(bl_modifier: bpy.types.Modifier) -> ModifierAttrs:
if bl_modifier.type == 'NODES': if bl_modifier.type == 'NODES':
@ -62,7 +62,7 @@ def read_modifier(bl_modifier: bpy.types.Modifier) -> ModifierAttrs:
#################### ####################
# - Write Modifier Information # - Write Modifier: GeoNodes
#################### ####################
def write_modifier_geonodes( def write_modifier_geonodes(
bl_modifier: bpy.types.NodesModifier, bl_modifier: bpy.types.NodesModifier,
@ -78,6 +78,7 @@ def write_modifier_geonodes(
True if the modifier was altered. True if the modifier was altered.
""" """
modifier_altered = False modifier_altered = False
# Alter GeoNodes Group # Alter GeoNodes Group
if bl_modifier.node_group != modifier_attrs['node_group']: if bl_modifier.node_group != modifier_attrs['node_group']:
log.info( log.info(
@ -89,53 +90,22 @@ def write_modifier_geonodes(
modifier_altered = True modifier_altered = True
# Alter GeoNodes Modifier Inputs # Alter GeoNodes Modifier Inputs
## First we retrieve the interface items by-Socket Name socket_infos = bl_socket_map.info_from_geonodes(bl_modifier.node_group)
geonodes_interface = analyze_geonodes.interface(
bl_modifier.node_group, direc='INPUT' for socket_name in modifier_attrs['inputs']:
iface_id = socket_infos[socket_name].bl_isocket_identifier
bl_modifier[iface_id] = spux.scale_to_unit_system(
modifier_attrs['inputs'][socket_name], modifier_attrs['unit_system']
) )
for (
socket_name,
value,
) in modifier_attrs['inputs'].items():
# Compute Writable BL Socket Value
## Analyzes the socket and unitsys to prep a ready-to-write value.
## Write directly to the modifier dict.
bl_socket_value = bl_socket_map.writable_bl_socket_value(
geonodes_interface[socket_name],
value,
unit_system=modifier_attrs['unit_system'],
allow_unit_not_in_unit_system=True,
)
# Compute Interface ID from Socket Name
## We can't index the modifier by socket name; only by Interface ID.
## Still, we require that socket names are unique.
iface_id = geonodes_interface[socket_name].identifier
# IF List-Like: Alter Differing Elements
if isinstance(bl_socket_value, tuple):
for i, bl_socket_subvalue in enumerate(bl_socket_value):
if bl_modifier[iface_id][i] != bl_socket_subvalue:
bl_modifier[iface_id][i] = bl_socket_subvalue
modifier_altered = True modifier_altered = True
## TODO: More fine-grained alterations
# IF int/float Mismatch: Assign Float-Cast of Integer return modifier_altered # noqa: RET504
## Blender is strict; only floats can set float vals.
## We are less strict; if the user passes an int, that's okay.
elif isinstance(bl_socket_value, int) and isinstance(
bl_modifier[iface_id],
float,
):
bl_modifier[iface_id] = float(bl_socket_value)
modifier_altered = True
else:
## TODO: Whitelist what can be here. I'm done with the TypeErrors.
bl_modifier[iface_id] = bl_socket_value
modifier_altered = True
return modifier_altered
####################
# - Write Modifier
####################
def write_modifier( def write_modifier(
bl_modifier: bpy.types.Modifier, bl_modifier: bpy.types.Modifier,
modifier_attrs: ModifierAttrs, modifier_attrs: ModifierAttrs,
@ -184,8 +154,11 @@ class ManagedBLModifier(base.ManagedObj):
def __init__(self, name: str): def __init__(self, name: str):
self.name = name self.name = name
def bl_select(self) -> None: pass def bl_select(self) -> None:
def hide_preview(self) -> None: pass pass
def hide_preview(self) -> None:
pass
#################### ####################
# - Deallocation # - Deallocation
@ -255,7 +228,8 @@ class ManagedBLModifier(base.ManagedObj):
type=modifier_type, type=modifier_type,
) )
if modifier_altered := write_modifier(bl_modifier, modifier_attrs): modifier_altered = write_modifier(bl_modifier, modifier_attrs)
if modifier_altered:
bl_object.data.update() bl_object.data.update()
return bl_modifier return bl_modifier

View File

@ -41,18 +41,17 @@ class ExtractDataNode(base.MaxwellSimNode):
input_socket_sets: typ.ClassVar = { input_socket_sets: typ.ClassVar = {
'Sim Data': {'Sim Data': sockets.MaxwellFDTDSimDataSocketDef()}, 'Sim Data': {'Sim Data': sockets.MaxwellFDTDSimDataSocketDef()},
'Monitor Data': {'Monitor Data': sockets.DataSocketDef(format='monitor_data')}, 'Monitor Data': {'Monitor Data': sockets.MaxwellMonitorDataSocketDef()},
} }
output_socket_sets: typ.ClassVar = { output_socket_sets: typ.ClassVar = {
'Sim Data': {'Monitor Data': sockets.DataSocketDef(format='monitor_data')}, 'Sim Data': {'Monitor Data': sockets.MaxwellMonitorDataSocketDef()},
'Monitor Data': {'Data': sockets.DataSocketDef(format='jax')}, 'Monitor Data': {'Expr': sockets.ExprSocketDef()},
} }
#################### ####################
# - Properties # - Properties
#################### ####################
extract_filter: enum.Enum = bl_cache.BLField( extract_filter: enum.Enum = bl_cache.BLField(
None,
prop_ui=True, prop_ui=True,
enum_cb=lambda self, _: self.search_extract_filters(), enum_cb=lambda self, _: self.search_extract_filters(),
) )
@ -62,7 +61,7 @@ class ExtractDataNode(base.MaxwellSimNode):
#################### ####################
@property @property
def sim_data(self) -> td.SimulationData | None: def sim_data(self) -> td.SimulationData | None:
"""Computes the (cached) simulation data from the input socket. """Extracts the simulation data from the input socket.
Return: Return:
Either the simulation data, if available, or None. Either the simulation data, if available, or None.
@ -70,14 +69,15 @@ class ExtractDataNode(base.MaxwellSimNode):
sim_data = self._compute_input( sim_data = self._compute_input(
'Sim Data', kind=ct.FlowKind.Value, optional=True 'Sim Data', kind=ct.FlowKind.Value, optional=True
) )
if not ct.FlowSignal.check(sim_data): has_sim_data = not ct.FlowSignal.check(sim_data)
if has_sim_data:
return sim_data return sim_data
return None return None
@bl_cache.cached_bl_property() @bl_cache.cached_bl_property()
def sim_data_monitor_nametype(self) -> dict[str, str] | None: def sim_data_monitor_nametype(self) -> dict[str, str] | None:
"""For simulation data, computes and and caches a map from name to "type". """For simulation data, deduces a map from the monitor name to the monitor "type".
Return: Return:
The name to type of monitors in the simulation data. The name to type of monitors in the simulation data.
@ -95,7 +95,7 @@ class ExtractDataNode(base.MaxwellSimNode):
#################### ####################
@property @property
def monitor_data(self) -> TDMonitorData | None: def monitor_data(self) -> TDMonitorData | None:
"""Computes the (cached) monitor data from the input socket. """Extracts the monitor data from the input socket.
Return: Return:
Either the monitor data, if available, or None. Either the monitor data, if available, or None.
@ -103,17 +103,26 @@ class ExtractDataNode(base.MaxwellSimNode):
monitor_data = self._compute_input( monitor_data = self._compute_input(
'Monitor Data', kind=ct.FlowKind.Value, optional=True 'Monitor Data', kind=ct.FlowKind.Value, optional=True
) )
if not ct.FlowSignal.check(monitor_data): has_monitor_data = not ct.FlowSignal.check(monitor_data)
if has_monitor_data:
return monitor_data return monitor_data
return None return None
@bl_cache.cached_bl_property() @bl_cache.cached_bl_property()
def monitor_data_type(self) -> str | None: def monitor_data_type(self) -> str | None:
"""For monitor data, computes and caches the monitor "type". r"""For monitor data, deduces the monitor "type".
- **Field(Time)**: A monitor storing values/pixels/voxels with electromagnetic field values, on the time or frequency domain.
- **Permittivity**: A monitor storing values/pixels/voxels containing the diagonal of the relative permittivity tensor.
- **Flux(Time)**: A monitor storing the directional flux on the time or frequency domain.
For planes, an explicit direction is defined.
For volumes, the the integral of all outgoing energy is stored.
- **FieldProjection(...)**: A monitor storing the spherical-coordinate electromagnetic field components of a near-to-far-field projection.
- **Diffraction**: A monitor storing a near-to-far-field projection by diffraction order.
Notes: Notes:
Should be invalidated with (before) `self.monitor_data_components`. Should be invalidated with (before) `self.monitor_data_attrs`.
Return: Return:
The "type" of the monitor, if available, else None. The "type" of the monitor, if available, else None.
@ -124,10 +133,10 @@ class ExtractDataNode(base.MaxwellSimNode):
return None return None
@bl_cache.cached_bl_property() @bl_cache.cached_bl_property()
def monitor_data_components(self) -> list[str] | None: def monitor_data_attrs(self) -> list[str] | None:
r"""For monitor data, computes and caches the component sof the monitor. r"""For monitor data, deduces the valid data-containing attributes.
The output depends entirely on the output of `self.monitor_data`. The output depends entirely on the output of `self.monitor_data_type`, since the valid attributes of each monitor type is well-defined without needing to perform dynamic lookups.
- **Field(Time)**: Whichever `[E|H][x|y|z]` are not `None` on the monitor. - **Field(Time)**: Whichever `[E|H][x|y|z]` are not `None` on the monitor.
- **Permittivity**: Specifically `['xx', 'yy', 'zz']`. - **Permittivity**: Specifically `['xx', 'yy', 'zz']`.
@ -183,7 +192,7 @@ class ExtractDataNode(base.MaxwellSimNode):
"""Compute valid values for `self.extract_filter`, for a dynamic `EnumProperty`. """Compute valid values for `self.extract_filter`, for a dynamic `EnumProperty`.
Notes: Notes:
Should be reset (via `self.extract_filter`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_components`, and (implicitly) `self.monitor_type`. Should be reset (via `self.extract_filter`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`.
See `bl_cache.BLField` for more on dynamic `EnumProperty`. See `bl_cache.BLField` for more on dynamic `EnumProperty`.
@ -198,16 +207,56 @@ class ExtractDataNode(base.MaxwellSimNode):
) )
] ]
if self.monitor_data_components is not None: if self.monitor_data_attrs is not None:
# Field/FieldTime
if self.monitor_data_type in ['Field', 'FieldTime']:
return [ return [
( (
component_name, monitor_attr,
component_name, monitor_attr,
f' {component_name[1]}-polarization of the {"electric" if component_name[0] == "E" else "magnetic"} field', f' {monitor_attr[1]}-polarization of the {"electric" if monitor_attr[0] == "E" else "magnetic"} field',
'', '',
i, i,
) )
for i, component_name in enumerate(self.monitor_data_components) for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
# Permittivity
if self.monitor_data_type == 'Permittivity':
return [
(monitor_attr, monitor_attr, f' ε_{monitor_attr}', '', i)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
# Flux/FluxTime
if self.monitor_data_type in ['Flux', 'FluxTime']:
return [
(
monitor_attr,
monitor_attr,
'Power flux integral through the plane / out of the volume',
'',
i,
)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
# FieldProjection(Angle/Cartesian/KSpace)/Diffraction
if self.monitor_data_type in [
'FieldProjectionAngle',
'FieldProjectionCartesian',
'FieldProjectionKSpace',
'Diffraction',
]:
return [
(
monitor_attr,
monitor_attr,
f' {monitor_attr[1]}-component of the spherical {"electric" if monitor_attr[0] == "E" else "magnetic"} field',
'',
i,
)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
] ]
return [] return []
@ -215,9 +264,14 @@ class ExtractDataNode(base.MaxwellSimNode):
#################### ####################
# - UI # - UI
#################### ####################
def draw_label(self): def draw_label(self) -> None:
"""Show the extracted data (if any) in the node's header label.
Notes:
Called by Blender to determine the text to place in the node's header.
"""
has_sim_data = self.sim_data_monitor_nametype is not None has_sim_data = self.sim_data_monitor_nametype is not None
has_monitor_data = self.monitor_data_components is not None has_monitor_data = self.monitor_data_attrs is not None
if has_sim_data or has_monitor_data: if has_sim_data or has_monitor_data:
return f'Extract: {self.extract_filter}' return f'Extract: {self.extract_filter}'
@ -245,11 +299,11 @@ class ExtractDataNode(base.MaxwellSimNode):
"""Invalidate the cached properties for sim data / monitor data, and reset the extraction filter.""" """Invalidate the cached properties for sim data / monitor data, and reset the extraction filter."""
self.sim_data_monitor_nametype = bl_cache.Signal.InvalidateCache self.sim_data_monitor_nametype = bl_cache.Signal.InvalidateCache
self.monitor_data_type = bl_cache.Signal.InvalidateCache self.monitor_data_type = bl_cache.Signal.InvalidateCache
self.monitor_data_components = bl_cache.Signal.InvalidateCache self.monitor_data_attrs = bl_cache.Signal.InvalidateCache
self.extract_filter = bl_cache.Signal.ResetEnumItems self.extract_filter = bl_cache.Signal.ResetEnumItems
#################### ####################
# - Output: Sim Data -> Monitor Data # - Output (Value): Sim Data -> Monitor Data
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
# Trigger # Trigger
@ -262,96 +316,84 @@ class ExtractDataNode(base.MaxwellSimNode):
def compute_monitor_data( def compute_monitor_data(
self, props: dict, input_sockets: dict self, props: dict, input_sockets: dict
) -> TDMonitorData | ct.FlowSignal: ) -> TDMonitorData | ct.FlowSignal:
"""Compute `Monitor Data` by querying an attribute of `Sim Data`. """Compute `Monitor Data` by querying the attribute of `Sim Data` referenced by the property `self.extract_filter`.
Notes:
The attribute to query is read directly from `self.extract_filter`.
This is also the mechanism that protects from trying to reference an invalid attribute.
Returns: Returns:
Monitor data, if available, else `ct.FlowSignal.FlowPending`. Monitor data, if available, else `ct.FlowSignal.FlowPending`.
""" """
extract_filter = props['extract_filter']
sim_data = input_sockets['Sim Data'] sim_data = input_sockets['Sim Data']
has_sim_data = not ct.FlowSignal.check(sim_data) has_sim_data = not ct.FlowSignal.check(sim_data)
if has_sim_data and props['extract_filter'] != 'NONE': if has_sim_data and extract_filter is not None:
return input_sockets['Sim Data'].monitor_data[props['extract_filter']] return sim_data.monitor_data[extract_filter]
# Propagate NoFlow
if ct.FlowSignal.check_single(sim_data, ct.FlowSignal.NoFlow):
return ct.FlowSignal.NoFlow
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
#################### ####################
# - Output: Monitor Data -> Data # - Output (Array): Monitor Data -> Expr
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
# Trigger # Trigger
'Data', 'Expr',
kind=ct.FlowKind.Array, kind=ct.FlowKind.Array,
# Loaded # Loaded
props={'extract_filter'}, props={'extract_filter'},
input_sockets={'Monitor Data'}, input_sockets={'Monitor Data'},
input_socket_kinds={'Monitor Data': ct.FlowKind.Value}, input_socket_kinds={'Monitor Data': ct.FlowKind.Value},
) )
def compute_data( def compute_expr(
self, props: dict, input_sockets: dict self, props: dict, input_sockets: dict
) -> jax.Array | ct.FlowSignal: ) -> jax.Array | ct.FlowSignal:
"""Compute `Data:Array` by querying an array-like attribute of `Monitor Data`, then constructing an `ct.ArrayFlow`. """Compute `Expr:Array` by querying an array-like attribute of `Monitor Data`, then constructing an `ct.ArrayFlow` around it.
Uses the internal `xarray` data returned by Tidy3D. Uses the internal `xarray` data returned by Tidy3D.
By using `np.array` on the `.data` attribute of the `xarray`, instead of the usual JAX array constructor, we should save a (possibly very big) copy. By using `np.array` on the `.data` attribute of the `xarray`, instead of the usual JAX array constructor, we should save a (possibly very big) copy.
Notes:
The attribute to query is read directly from `self.extract_filter`.
This is also the mechanism that protects from trying to reference an invalid attribute.
Used as the first part of the `LazyFuncValue` chain used for further array manipulations with Math nodes.
Returns: Returns:
The data array, if available, else `ct.FlowSignal.FlowPending`. The data array, if available, else `ct.FlowSignal.FlowPending`.
""" """
has_monitor_data = not ct.FlowSignal.check(input_sockets['Monitor Data']) extract_filter = props['extract_filter']
monitor_data = input_sockets['Monitor Data']
has_monitor_data = not ct.FlowSignal.check(monitor_data)
if has_monitor_data and props['extract_filter'] != 'NONE': if has_monitor_data and extract_filter is not None:
xarray_data = getattr( xarray_data = getattr(monitor_data, extract_filter)
input_sockets['Monitor Data'], props['extract_filter']
)
return ct.ArrayFlow(values=np.array(xarray_data.data), unit=None) return ct.ArrayFlow(values=np.array(xarray_data.data), unit=None)
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
@events.computes_output_socket( @events.computes_output_socket(
# Trigger # Trigger
'Data', 'Expr',
kind=ct.FlowKind.LazyValueFunc, kind=ct.FlowKind.LazyValueFunc,
# Loaded # Loaded
output_sockets={'Data'}, output_sockets={'Expr'},
output_socket_kinds={'Data': ct.FlowKind.Array}, output_socket_kinds={'Expr': ct.FlowKind.Array},
) )
def compute_extracted_data_lazy( def compute_extracted_data_lazy(
self, output_sockets: dict self, output_sockets: dict
) -> ct.LazyValueFuncFlow | None: ) -> ct.LazyValueFuncFlow | None:
"""Declare `Data:LazyValueFunc` by creating a simple function that directly wraps `Data:Array`. """Declare `Expr:LazyValueFunc` by creating a simple function that directly wraps `Expr:Array`.
Returns: Returns:
The composable function array, if available, else `ct.FlowSignal.FlowPending`. The composable function array, if available, else `ct.FlowSignal.FlowPending`.
""" """
has_output_data = not ct.FlowSignal.check(output_sockets['Data']) output_expr = output_sockets['Expr']
has_output_expr = not ct.FlowSignal.check(output_expr)
if has_output_data: if has_output_expr:
return ct.LazyValueFuncFlow( return ct.LazyValueFuncFlow(
func=lambda: output_sockets['Data'].values, supports_jax=True func=lambda: output_expr.values, supports_jax=True
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
#################### ####################
# - Auxiliary: Monitor Data -> Data # - Auxiliary (Params): Monitor Data -> Expr
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Expr',
kind=ct.FlowKind.Params, kind=ct.FlowKind.Params,
) )
def compute_data_params(self) -> ct.ParamsFlow: def compute_data_params(self) -> ct.ParamsFlow:
@ -362,6 +404,9 @@ class ExtractDataNode(base.MaxwellSimNode):
""" """
return ct.ParamsFlow() return ct.ParamsFlow()
####################
# - Auxiliary (Info): Monitor Data -> Expr
####################
@events.computes_output_socket( @events.computes_output_socket(
# Trigger # Trigger
'Data', 'Data',
@ -380,20 +425,21 @@ class ExtractDataNode(base.MaxwellSimNode):
Returns: Returns:
Information describing the `Data:LazyValueFunc`, if available, else `ct.FlowSignal.FlowPending`. Information describing the `Data:LazyValueFunc`, if available, else `ct.FlowSignal.FlowPending`.
""" """
has_monitor_data = not ct.FlowSignal.check(input_sockets['Monitor Data']) monitor_data = input_sockets['Monitor Data']
monitor_data_type = props['monitor_data_type']
extract_filter = props['extract_filter']
has_monitor_data = not ct.FlowSignal.check(monitor_data)
# Retrieve XArray # Retrieve XArray
if has_monitor_data and props['extract_filter'] != 'NONE': if has_monitor_data and extract_filter is not None:
xarr = getattr(input_sockets['Monitor Data'], props['extract_filter']) xarr = getattr(monitor_data, extract_filter)
else: else:
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
info_output_name = props['extract_filter']
info_output_shape = None
# Compute InfoFlow from XArray # Compute InfoFlow from XArray
## XYZF: Field / Permittivity / FieldProjectionCartesian ## XYZF: Field / Permittivity / FieldProjectionCartesian
if props['monitor_data_type'] in { if monitor_data_type in {
'Field', 'Field',
'Permittivity', 'Permittivity',
#'FieldProjectionCartesian', #'FieldProjectionCartesian',
@ -413,18 +459,16 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True, is_sorted=True,
), ),
}, },
output_name=props['extract_filter'], output_name=extract_filter,
output_shape=None, output_shape=None,
output_mathtype=spux.MathType.Complex, output_mathtype=spux.MathType.Complex,
output_unit=( output_unit=(
spu.volt / spu.micrometer spu.volt / spu.micrometer if monitor_data_type == 'Field' else None
if props['monitor_data_type'] == 'Field'
else None
), ),
) )
## XYZT: FieldTime ## XYZT: FieldTime
if props['monitor_data_type'] == 'FieldTime': if monitor_data_type == 'FieldTime':
return ct.InfoFlow( return ct.InfoFlow(
dim_names=['x', 'y', 'z', 't'], dim_names=['x', 'y', 'z', 't'],
dim_idx={ dim_idx={
@ -440,18 +484,16 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True, is_sorted=True,
), ),
}, },
output_name=props['extract_filter'], output_name=extract_filter,
output_shape=None, output_shape=None,
output_mathtype=spux.MathType.Complex, output_mathtype=spux.MathType.Complex,
output_unit=( output_unit=(
spu.volt / spu.micrometer spu.volt / spu.micrometer if monitor_data_type == 'Field' else None
if props['monitor_data_type'] == 'Field'
else None
), ),
) )
## F: Flux ## F: Flux
if props['monitor_data_type'] == 'Flux': if monitor_data_type == 'Flux':
return ct.InfoFlow( return ct.InfoFlow(
dim_names=['f'], dim_names=['f'],
dim_idx={ dim_idx={
@ -461,14 +503,14 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True, is_sorted=True,
), ),
}, },
output_name=props['extract_filter'], output_name=extract_filter,
output_shape=None, output_shape=None,
output_mathtype=spux.MathType.Real, output_mathtype=spux.MathType.Real,
output_unit=spu.watt, output_unit=spu.watt,
) )
## T: FluxTime ## T: FluxTime
if props['monitor_data_type'] == 'FluxTime': if monitor_data_type == 'FluxTime':
return ct.InfoFlow( return ct.InfoFlow(
dim_names=['t'], dim_names=['t'],
dim_idx={ dim_idx={
@ -478,14 +520,14 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True, is_sorted=True,
), ),
}, },
output_name=props['extract_filter'], output_name=extract_filter,
output_shape=None, output_shape=None,
output_mathtype=spux.MathType.Real, output_mathtype=spux.MathType.Real,
output_unit=spu.watt, output_unit=spu.watt,
) )
## RThetaPhiF: FieldProjectionAngle ## RThetaPhiF: FieldProjectionAngle
if props['monitor_data_type'] == 'FieldProjectionAngle': if monitor_data_type == 'FieldProjectionAngle':
return ct.InfoFlow( return ct.InfoFlow(
dim_names=['r', 'theta', 'phi', 'f'], dim_names=['r', 'theta', 'phi', 'f'],
dim_idx={ dim_idx={
@ -508,18 +550,18 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True, is_sorted=True,
), ),
}, },
output_name=props['extract_filter'], output_name=extract_filter,
output_shape=None, output_shape=None,
output_mathtype=spux.MathType.Real, output_mathtype=spux.MathType.Real,
output_unit=( output_unit=(
spu.volt / spu.micrometer spu.volt / spu.micrometer
if props['extract_filter'].startswith('E') if extract_filter.startswith('E')
else spu.ampere / spu.micrometer else spu.ampere / spu.micrometer
), ),
) )
## UxUyRF: FieldProjectionKSpace ## UxUyRF: FieldProjectionKSpace
if props['monitor_data_type'] == 'FieldProjectionKSpace': if monitor_data_type == 'FieldProjectionKSpace':
return ct.InfoFlow( return ct.InfoFlow(
dim_names=['ux', 'uy', 'r', 'f'], dim_names=['ux', 'uy', 'r', 'f'],
dim_idx={ dim_idx={
@ -540,18 +582,18 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True, is_sorted=True,
), ),
}, },
output_name=props['extract_filter'], output_name=extract_filter,
output_shape=None, output_shape=None,
output_mathtype=spux.MathType.Real, output_mathtype=spux.MathType.Real,
output_unit=( output_unit=(
spu.volt / spu.micrometer spu.volt / spu.micrometer
if props['extract_filter'].startswith('E') if extract_filter.startswith('E')
else spu.ampere / spu.micrometer else spu.ampere / spu.micrometer
), ),
) )
## OrderxOrderyF: Diffraction ## OrderxOrderyF: Diffraction
if props['monitor_data_type'] == 'Diffraction': if monitor_data_type == 'Diffraction':
return ct.InfoFlow( return ct.InfoFlow(
dim_names=['orders_x', 'orders_y', 'f'], dim_names=['orders_x', 'orders_y', 'f'],
dim_idx={ dim_idx={
@ -569,17 +611,17 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True, is_sorted=True,
), ),
}, },
output_name=props['extract_filter'], output_name=extract_filter,
output_shape=None, output_shape=None,
output_mathtype=spux.MathType.Real, output_mathtype=spux.MathType.Real,
output_unit=( output_unit=(
spu.volt / spu.micrometer spu.volt / spu.micrometer
if props['extract_filter'].startswith('E') if extract_filter.startswith('E')
else spu.ampere / spu.micrometer else spu.ampere / spu.micrometer
), ),
) )
msg = f'Unsupported Monitor Data Type {props["monitor_data_type"]} in "FlowKind.Info" of "{self.bl_label}"' msg = f'Unsupported Monitor Data Type {monitor_data_type} in "FlowKind.Info" of "{self.bl_label}"'
raise RuntimeError(msg) raise RuntimeError(msg)

View File

@ -1,16 +1,16 @@
from . import filter_math, map_math, operate_math, reduce_math, transform_math from . import filter_math, map_math, operate_math # , #reduce_math, transform_math
BL_REGISTER = [ BL_REGISTER = [
*operate_math.BL_REGISTER,
*map_math.BL_REGISTER, *map_math.BL_REGISTER,
*filter_math.BL_REGISTER, *filter_math.BL_REGISTER,
*reduce_math.BL_REGISTER, # *reduce_math.BL_REGISTER,
*operate_math.BL_REGISTER, # *transform_math.BL_REGISTER,
*transform_math.BL_REGISTER,
] ]
BL_NODES = { BL_NODES = {
**operate_math.BL_NODES,
**map_math.BL_NODES, **map_math.BL_NODES,
**filter_math.BL_NODES, **filter_math.BL_NODES,
**reduce_math.BL_NODES, # **reduce_math.BL_NODES,
**operate_math.BL_NODES, # **transform_math.BL_NODES,
**transform_math.BL_NODES,
} }

View File

@ -4,10 +4,10 @@ import enum
import typing as typ import typing as typ
import bpy import bpy
import jax
import jax.numpy as jnp import jax.numpy as jnp
from blender_maxwell.utils import bl_cache, logger from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct from .... import contracts as ct
from .... import sockets from .... import sockets
@ -16,6 +16,78 @@ from ... import base, events
log = logger.get(__name__) log = logger.get(__name__)
class FilterOperation(enum.StrEnum):
"""Valid operations for the `FilterMathNode`.
Attributes:
DimToVec: Shift last dimension to output.
DimsToMat: Shift last 2 dimensions to output.
PinLen1: Remove a len(1) dimension.
Pin: Remove a len(n) dimension by selecting a particular index.
Swap: Swap the positions of two dimensions.
"""
# Dimensions
PinLen1 = enum.auto()
Pin = enum.auto()
Swap = enum.auto()
# Interpret
DimToVec = enum.auto()
DimsToMat = enum.auto()
@staticmethod
def to_name(value: typ.Self) -> str:
FO = FilterOperation
return {
# Dimensions
FO.PinLen1: 'pinₐ =1',
FO.Pin: 'pinₐ ≈v',
FO.Swap: 'a₁ ↔ a₂',
# Interpret
FO.DimToVec: '→ Vector',
FO.DimsToMat: '→ Matrix',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> str:
return ''
def are_dims_valid(self, dim_0: int | None, dim_1: int | None):
return not (
(
dim_0 is None
and self
in [FilterOperation.PinLen1, FilterOperation.Pin, FilterOperation.Swap]
)
or (dim_1 is None and self == FilterOperation.Swap)
)
def jax_func(self, axis_0: int | None, axis_1: int | None):
return {
# Interpret
FilterOperation.DimToVec: lambda data: data,
FilterOperation.DimsToMat: lambda data: data,
# Dimensions
FilterOperation.PinLen1: lambda data: jnp.squeeze(data, axis_0),
FilterOperation.Pin: lambda data, fixed_axis_idx: jnp.take(
data, fixed_axis_idx, axis=axis_0
),
FilterOperation.Swap: lambda data: jnp.swapaxes(data, axis_0, axis_1),
}[self]
def transform_info(self, info: ct.InfoFlow, dim_0: str, dim_1: str):
return {
# Interpret
FilterOperation.DimToVec: lambda: info.shift_last_input,
FilterOperation.DimsToMat: lambda: info.shift_last_input.shift_last_input,
# Dimensions
FilterOperation.PinLen1: lambda: info.delete_dimension(dim_0),
FilterOperation.Pin: lambda: info.delete_dimension(dim_0),
FilterOperation.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
}[self]()
class FilterMathNode(base.MaxwellSimNode): class FilterMathNode(base.MaxwellSimNode):
r"""Applies a function that operates on the shape of the array. r"""Applies a function that operates on the shape of the array.
@ -38,21 +110,18 @@ class FilterMathNode(base.MaxwellSimNode):
bl_label = 'Filter Math' bl_label = 'Filter Math'
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Data': sockets.DataSocketDef(format='jax'), 'Expr': sockets.ExprSocketDef(),
}
input_socket_sets: typ.ClassVar = {
'Interpret': {},
'Dimensions': {},
} }
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Data': sockets.DataSocketDef(format='jax'), 'Expr': sockets.ExprSocketDef(),
} }
#################### ####################
# - Properties # - Properties
#################### ####################
operation: enum.Enum = bl_cache.BLField( operation: FilterOperation = bl_cache.BLField(
prop_ui=True, enum_cb=lambda self, _: self.search_operations() FilterOperation.PinLen1,
prop_ui=True,
) )
# Dimension Selection # Dimension Selection
@ -68,49 +137,26 @@ class FilterMathNode(base.MaxwellSimNode):
#################### ####################
@property @property
def data_info(self) -> ct.InfoFlow | None: def data_info(self) -> ct.InfoFlow | None:
info = self._compute_input('Data', kind=ct.FlowKind.Info) info = self._compute_input('Expr', kind=ct.FlowKind.Info)
if not ct.FlowSignal.check(info): if not ct.FlowSignal.check(info):
return info return info
return None return None
#################### ####################
# - Operation Search # - Search Dimensions
####################
def search_operations(self) -> list[tuple[str, str, str]]:
items = []
if self.active_socket_set == 'Interpret':
items += [
('DIM_TO_VEC', '→ Vector', 'Shift last dimension to output.'),
('DIMS_TO_MAT', '→ Matrix', 'Shift last 2 dimensions to output.'),
]
elif self.active_socket_set == 'Dimensions':
items += [
('PIN_LEN_ONE', 'pinₐ =1', 'Remove a len(1) dimension'),
(
'PIN',
'pinₐ ≈v',
'Remove a len(n) dimension by selecting an index',
),
('SWAP', 'a₁ ↔ a₂', 'Swap the position of two dimensions'),
]
return [(*item, '', i) for i, item in enumerate(items)]
####################
# - Dimensions Search
#################### ####################
def search_dims(self) -> list[ct.BLEnumElement]: def search_dims(self) -> list[ct.BLEnumElement]:
if self.data_info is None: if self.data_info is None:
return [] return []
if self.operation == 'PIN_LEN_ONE': if self.operation == FilterOperation.PinLen1:
dims = [ dims = [
(dim_name, dim_name, f'Dimension "{dim_name}" of length 1') (dim_name, dim_name, f'Dimension "{dim_name}" of length 1')
for dim_name in self.data_info.dim_names for dim_name in self.data_info.dim_names
if self.data_info.dim_lens[dim_name] == 1 if self.data_info.dim_lens[dim_name] == 1
] ]
elif self.operation in ['PIN', 'SWAP']: elif self.operation in [FilterOperation.Pin, FilterOperation.Swap]:
dims = [ dims = [
(dim_name, dim_name, f'Dimension "{dim_name}"') (dim_name, dim_name, f'Dimension "{dim_name}"')
for dim_name in self.data_info.dim_names for dim_name in self.data_info.dim_names
@ -124,12 +170,13 @@ class FilterMathNode(base.MaxwellSimNode):
# - UI # - UI
#################### ####################
def draw_label(self): def draw_label(self):
FO = FilterOperation
labels = { labels = {
'PIN_LEN_ONE': lambda: f'Filter: Pin {self.dim_0} (len=1)', FO.PinLen1: lambda: f'Filter: Pin {self.dim_0} (len=1)',
'PIN': lambda: f'Filter: Pin {self.dim_0}', FO.Pin: lambda: f'Filter: Pin {self.dim_0}',
'SWAP': lambda: f'Filter: Swap {self.dim_0}|{self.dim_1}', FO.Swap: lambda: f'Filter: Swap {self.dim_0}|{self.dim_1}',
'DIM_TO_VEC': lambda: 'Filter: -> Vector', FO.DimToVec: lambda: 'Filter: -> Vector',
'DIMS_TO_MAT': lambda: 'Filter: -> Matrix', FO.DimsToMat: lambda: 'Filter: -> Matrix',
} }
if (label := labels.get(self.operation)) is not None: if (label := labels.get(self.operation)) is not None:
@ -141,10 +188,10 @@ class FilterMathNode(base.MaxwellSimNode):
layout.prop(self, self.blfields['operation'], text='') layout.prop(self, self.blfields['operation'], text='')
if self.active_socket_set == 'Dimensions': if self.active_socket_set == 'Dimensions':
if self.operation in ['PIN_LEN_ONE', 'PIN']: if self.operation in [FilterOperation.PinLen1, FilterOperation.Pin]:
layout.prop(self, self.blfields['dim_0'], text='') layout.prop(self, self.blfields['dim_0'], text='')
if self.operation == 'SWAP': if self.operation == FilterOperation.Swap:
row = layout.row(align=True) row = layout.row(align=True)
row.prop(self, self.blfields['dim_0'], text='') row.prop(self, self.blfields['dim_0'], text='')
row.prop(self, self.blfields['dim_1'], text='') row.prop(self, self.blfields['dim_1'], text='')
@ -152,215 +199,199 @@ class FilterMathNode(base.MaxwellSimNode):
#################### ####################
# - Events # - Events
#################### ####################
@events.on_value_changed(
prop_name='active_socket_set',
run_on_init=True,
)
def on_socket_set_changed(self):
self.operation = bl_cache.Signal.ResetEnumItems
@events.on_value_changed( @events.on_value_changed(
# Trigger # Trigger
socket_name='Data', socket_name='Expr',
prop_name={'active_socket_set', 'operation'}, prop_name={'operation'},
run_on_init=True, run_on_init=True,
# Loaded
props={'operation'},
) )
def on_any_change(self, props: dict) -> None: def on_input_changed(self) -> None:
self.dim_0 = bl_cache.Signal.ResetEnumItems self.dim_0 = bl_cache.Signal.ResetEnumItems
self.dim_1 = bl_cache.Signal.ResetEnumItems self.dim_1 = bl_cache.Signal.ResetEnumItems
@events.on_value_changed( @events.on_value_changed(
socket_name='Data', # Trigger
socket_name='Expr',
prop_name={'dim_0', 'dim_1', 'operation'}, prop_name={'dim_0', 'dim_1', 'operation'},
## run_on_init: Implicitly triggered. run_on_init=True,
# Loaded
props={'operation', 'dim_0', 'dim_1'}, props={'operation', 'dim_0', 'dim_1'},
input_sockets={'Data'}, input_sockets={'Expr'},
input_socket_kinds={'Data': ct.FlowKind.Info}, input_socket_kinds={'Expr': ct.FlowKind.Info},
) )
def on_dim_change(self, props: dict, input_sockets: dict): def on_pin_changed(self, props: dict, input_sockets: dict):
has_data = not ct.FlowSignal.check(input_sockets['Data']) info = input_sockets['Expr']
if not has_data: has_info = not ct.FlowSignal.check(info)
if not has_info:
return return
# "Dimensions"|"PIN": Add/Remove Input Socket # "Dimensions"|"PIN": Add/Remove Input Socket
if props['operation'] == 'PIN' and props['dim_0'] != 'NONE': if props['operation'] == FilterOperation.Pin and props['dim_0'] is not None:
pinned_unit = info.dim_units[props['dim_0']]
pinned_mathtype = info.dim_mathtypes[props['dim_0']]
pinned_physical_type = spux.PhysicalType.from_unit(pinned_unit)
wanted_mathtype = (
spux.MathType.Complex
if pinned_mathtype == spux.MathType.Complex
and spux.MathType.Complex in pinned_physical_type.valid_mathtypes
else spux.MathType.Real
)
# Get Current and Wanted Socket Defs # Get Current and Wanted Socket Defs
current_bl_socket = self.loose_input_sockets.get('Value') current_bl_socket = self.loose_input_sockets.get('Value')
wanted_socket_def = sockets.SOCKET_DEFS[
ct.unit_to_socket_type(
input_sockets['Data'].dim_idx[props['dim_0']].unit
)
]
# Determine Whether to Declare New Loose Input SOcket # Determine Whether to Declare New Loose Input SOcket
if ( if (
current_bl_socket is None current_bl_socket is None
or sockets.SOCKET_DEFS[current_bl_socket.socket_type] or current_bl_socket.shape is not None
!= wanted_socket_def or current_bl_socket.physical_type != pinned_physical_type
or current_bl_socket.mathtype != wanted_mathtype
): ):
self.loose_input_sockets = { self.loose_input_sockets = {
'Value': wanted_socket_def(), 'Value': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Value,
shape=None,
physical_type=pinned_physical_type,
mathtype=wanted_mathtype,
default_unit=pinned_unit,
),
} }
elif self.loose_input_sockets: elif self.loose_input_sockets:
self.loose_input_sockets = {} self.loose_input_sockets = {}
#################### ####################
# - Compute: LazyValueFunc / Array # - Output
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Expr',
kind=ct.FlowKind.LazyValueFunc, kind=ct.FlowKind.LazyValueFunc,
props={'operation', 'dim_0', 'dim_1'}, props={'operation', 'dim_0', 'dim_1'},
input_sockets={'Data'}, input_sockets={'Expr'},
input_socket_kinds={'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}}, input_socket_kinds={'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}},
) )
def compute_data(self, props: dict, input_sockets: dict): def compute_lazy_value_func(self, props: dict, input_sockets: dict):
lazy_value_func = input_sockets['Data'][ct.FlowKind.LazyValueFunc]
info = input_sockets['Data'][ct.FlowKind.Info]
# Check Flow
if any(ct.FlowSignal.check(inp) for inp in [info, lazy_value_func]):
return ct.FlowSignal.FlowPending
# Compute Function Arguments
operation = props['operation'] operation = props['operation']
if operation == 'NONE': lazy_value_func = input_sockets['Expr'][ct.FlowKind.LazyValueFunc]
return ct.FlowSignal.FlowPending info = input_sockets['Expr'][ct.FlowKind.Info]
## Dimension(s) has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
has_info = not ct.FlowSignal.check(info)
# Dimension(s)
dim_0 = props['dim_0'] dim_0 = props['dim_0']
dim_1 = props['dim_1'] dim_1 = props['dim_1']
if operation in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE': if (
return ct.FlowSignal.FlowPending has_lazy_value_func
if operation == 'SWAP' and dim_1 == 'NONE': and has_info
return ct.FlowSignal.FlowPending and operation is not None
and operation.are_dims_valid(dim_0, dim_1)
## Axis/Axes ):
axis_0 = info.dim_names.index(dim_0) if dim_0 != 'NONE' else None axis_0 = info.dim_names.index(dim_0) if dim_0 is not None else None
axis_1 = info.dim_names.index(dim_1) if dim_1 != 'NONE' else None axis_1 = info.dim_names.index(dim_1) if dim_1 is not None else None
# Compose Output Function
filter_func = {
# Dimensions
'PIN_LEN_ONE': lambda data: jnp.squeeze(data, axis_0),
'PIN': lambda data, fixed_axis_idx: jnp.take(
data, fixed_axis_idx, axis=axis_0
),
'SWAP': lambda data: jnp.swapaxes(data, axis_0, axis_1),
# Interpret
'DIM_TO_VEC': lambda data: data,
'DIMS_TO_MAT': lambda data: data,
}[props['operation']]
return lazy_value_func.compose_within( return lazy_value_func.compose_within(
filter_func, operation.jax_func(axis_0, axis_1),
enclosing_func_args=[int] if operation == 'PIN' else [], enclosing_func_args=[int] if operation == 'PIN' else [],
supports_jax=True, supports_jax=True,
) )
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Array,
output_sockets={'Data'},
output_socket_kinds={
'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
},
)
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
params = output_sockets['Data'][ct.FlowKind.Params]
# Check Flow
if any(ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]):
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
return ct.ArrayFlow( @events.computes_output_socket(
values=lazy_value_func.func_jax(*params.func_args, **params.func_kwargs), 'Expr',
unit=None, kind=ct.FlowKind.Array,
output_sockets={'Expr'},
output_socket_kinds={
'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
},
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
) )
def compute_array(self, output_sockets, unit_systems) -> ct.ArrayFlow:
lazy_value_func = output_sockets['Expr'][ct.FlowKind.LazyValueFunc]
params = output_sockets['Expr'][ct.FlowKind.Params]
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
has_params = not ct.FlowSignal.check(params)
if has_lazy_value_func and has_params:
unit_system = unit_systems['BlenderUnits']
return ct.ArrayFlow(
values=lazy_value_func.func_jax(
*params.scaled_func_args(unit_system),
**params.scaled_func_kwargs(unit_system),
),
)
return ct.FlowSignal.FlowPending
#################### ####################
# - Compute Auxiliary: Info # - Auxiliary: Info
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Expr',
kind=ct.FlowKind.Info, kind=ct.FlowKind.Info,
props={'dim_0', 'dim_1', 'operation'}, props={'dim_0', 'dim_1', 'operation'},
input_sockets={'Data'}, input_sockets={'Expr'},
input_socket_kinds={'Data': ct.FlowKind.Info}, input_socket_kinds={'Expr': ct.FlowKind.Info},
) )
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow: def compute_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
info = input_sockets['Data'] operation = props['operation']
info = input_sockets['Expr']
# Check Flow has_info = not ct.FlowSignal.check(info)
if ct.FlowSignal.check(info):
return ct.FlowSignal.FlowPending
# Collect Information # Dimension(s)
dim_0 = props['dim_0'] dim_0 = props['dim_0']
dim_1 = props['dim_1'] dim_1 = props['dim_1']
if (
has_info
and operation is not None
and operation.are_dims_valid(dim_0, dim_1)
):
return operation.transform_info(info, dim_0, dim_1)
if props['operation'] in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE':
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
if props['operation'] == 'SWAP' and dim_1 == 'NONE':
return ct.FlowSignal.FlowPending
return {
# Dimensions
'PIN_LEN_ONE': lambda: info.delete_dimension(dim_0),
'PIN': lambda: info.delete_dimension(dim_0),
'SWAP': lambda: info.swap_dimensions(dim_0, dim_1),
# Interpret
'DIM_TO_VEC': lambda: info.shift_last_input,
'DIMS_TO_MAT': lambda: info.shift_last_input.shift_last_input,
}[props['operation']]()
#################### ####################
# - Compute Auxiliary: Info # - Auxiliary: Params
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Expr',
kind=ct.FlowKind.Params, kind=ct.FlowKind.Params,
props={'dim_0', 'dim_1', 'operation'}, props={'dim_0', 'dim_1', 'operation'},
input_sockets={'Data', 'Value'}, input_sockets={'Expr', 'Value'},
input_socket_kinds={'Data': {ct.FlowKind.Info, ct.FlowKind.Params}}, input_socket_kinds={'Expr': {ct.FlowKind.Info, ct.FlowKind.Params}},
input_sockets_optional={'Value': True}, input_sockets_optional={'Value': True},
) )
def compute_composed_params( def compute_params(self, props: dict, input_sockets: dict) -> ct.ParamsFlow:
self, props: dict, input_sockets: dict operation = props['operation']
) -> ct.ParamsFlow: info = input_sockets['Expr'][ct.FlowKind.Info]
info = input_sockets['Data'][ct.FlowKind.Info] params = input_sockets['Expr'][ct.FlowKind.Params]
params = input_sockets['Data'][ct.FlowKind.Params]
# Check Flow has_info = not ct.FlowSignal.check(info)
if any(ct.FlowSignal.check(inp) for inp in [info, params]): has_params = not ct.FlowSignal.check(params)
return ct.FlowSignal.FlowPending
# Collect Information # Dimension(s)
## Dimensions
dim_0 = props['dim_0'] dim_0 = props['dim_0']
dim_1 = props['dim_1'] dim_1 = props['dim_1']
if (
if props['operation'] in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE': has_info
return ct.FlowSignal.FlowPending and has_params
if props['operation'] == 'SWAP' and dim_1 == 'NONE': and operation is not None
return ct.FlowSignal.FlowPending and operation.are_dims_valid(dim_0, dim_1)
):
## Pinned Value ## Pinned Value
pinned_value = input_sockets['Value'] pinned_value = input_sockets['Value']
has_pinned_value = not ct.FlowSignal.check(pinned_value) has_pinned_value = not ct.FlowSignal.check(pinned_value)
if props['operation'] == 'PIN' and has_pinned_value: if props['operation'] == 'PIN' and has_pinned_value:
# Compute IDX Corresponding to Dimension Index
nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of( nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of(
input_sockets['Value'], require_sorted=True pinned_value, require_sorted=True
) )
return params.compose_within(enclosing_func_args=[nearest_idx_to_value]) return params.compose_within(enclosing_func_args=[nearest_idx_to_value])
return params return params
return ct.FlowSignal.FlowPending
#################### ####################

View File

@ -20,6 +20,248 @@ log = logger.get(__name__)
X_COMPLEX = sp.Symbol('x', complex=True) X_COMPLEX = sp.Symbol('x', complex=True)
class MapOperation(enum.StrEnum):
"""Valid operations for the `MapMathNode`.
Attributes:
UserExpr: Use a user-provided mapping expression.
Real: Compute the real part of the input.
Imag: Compute the imaginary part of the input.
Abs: Compute the absolute value of the input.
Sq: Square the input.
Sqrt: Compute the (principal) square root of the input.
InvSqrt: Compute the inverse square root of the input.
Cos: Compute the cosine of the input.
Sin: Compute the sine of the input.
Tan: Compute the tangent of the input.
Acos: Compute the inverse cosine of the input.
Asin: Compute the inverse sine of the input.
Atan: Compute the inverse tangent of the input.
Norm2: Compute the 2-norm (aka. length) of the input vector.
Det: Compute the determinant of the input matrix.
Cond: Compute the condition number of the input matrix.
NormFro: Compute the frobenius norm of the input matrix.
Rank: Compute the rank of the input matrix.
Diag: Compute the diagonal vector of the input matrix.
EigVals: Compute the eigenvalues vector of the input matrix.
SvdVals: Compute the singular values vector of the input matrix.
Inv: Compute the inverse matrix of the input matrix.
Tra: Compute the transpose matrix of the input matrix.
Qr: Compute the QR-factorized matrices of the input matrix.
Chol: Compute the Cholesky-factorized matrices of the input matrix.
Svd: Compute the SVD-factorized matrices of the input matrix.
"""
# By User Expression
UserExpr = enum.auto()
# By Number
Real = enum.auto()
Imag = enum.auto()
Abs = enum.auto()
Sq = enum.auto()
Sqrt = enum.auto()
InvSqrt = enum.auto()
Cos = enum.auto()
Sin = enum.auto()
Tan = enum.auto()
Acos = enum.auto()
Asin = enum.auto()
Atan = enum.auto()
Sinc = enum.auto()
# By Vector
Norm2 = enum.auto()
# By Matrix
Det = enum.auto()
Cond = enum.auto()
NormFro = enum.auto()
Rank = enum.auto()
Diag = enum.auto()
EigVals = enum.auto()
SvdVals = enum.auto()
Inv = enum.auto()
Tra = enum.auto()
Qr = enum.auto()
Chol = enum.auto()
Svd = enum.auto()
@staticmethod
def to_name(value: typ.Self) -> str:
MO = MapOperation
return {
# By User Expression
MO.UserExpr: '*',
# By Number
MO.Real: '(v)',
MO.Imag: 'Im(v)',
MO.Abs: '|v|',
MO.Sq: '',
MO.Sqrt: '√v',
MO.InvSqrt: '1/√v',
MO.Cos: 'cos v',
MO.Sin: 'sin v',
MO.Tan: 'tan v',
MO.Acos: 'acos v',
MO.Asin: 'asin v',
MO.Atan: 'atan v',
MO.Sinc: 'sinc v',
# By Vector
MO.Norm2: '||v||₂',
# By Matrix
MO.Det: 'det V',
MO.Cond: 'κ(V)',
MO.NormFro: '||V||_F',
MO.Rank: 'rank V',
MO.Diag: 'diag V',
MO.EigVals: 'eigvals V',
MO.SvdVals: 'svdvals V',
MO.Inv: 'V⁻¹',
MO.Tra: 'Vt',
MO.Qr: 'qr V',
MO.Chol: 'chol V',
MO.Svd: 'svd V',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> str:
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
MO = MapOperation
return (
str(self),
MO.to_name(self),
MO.to_name(self),
MO.to_icon(self),
i,
)
@staticmethod
def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]:
MO = MapOperation
# By Number
if shape is None:
return [
MO.Real,
MO.Imag,
MO.Abs,
MO.Sq,
MO.Sqrt,
MO.InvSqrt,
MO.Cos,
MO.Sin,
MO.Tan,
MO.Acos,
MO.Asin,
MO.Atan,
MO.Sinc,
]
# By Vector
if len(shape) == 1:
return [
MO.Norm2,
]
# By Matrix
if len(shape) == 2:
return [
MO.Det,
MO.Cond,
MO.NormFro,
MO.Rank,
MO.Diag,
MO.EigVals,
MO.SvdVals,
MO.Inv,
MO.Tra,
MO.Qr,
MO.Chol,
MO.Svd,
]
return []
def jax_func(self, user_expr_func: ct.LazyValueFuncFlow | None = None):
MO = MapOperation
if self == MO.UserExpr and user_expr_func is not None:
return lambda data: user_expr_func.func(data)
return {
# By Number
MO.Real: lambda data: jnp.real(data),
MO.Imag: lambda data: jnp.imag(data),
MO.Abs: lambda data: jnp.abs(data),
MO.Sq: lambda data: jnp.square(data),
MO.Sqrt: lambda data: jnp.sqrt(data),
MO.InvSqrt: lambda data: 1 / jnp.sqrt(data),
MO.Cos: lambda data: jnp.cos(data),
MO.Sin: lambda data: jnp.sin(data),
MO.Tan: lambda data: jnp.tan(data),
MO.Acos: lambda data: jnp.acos(data),
MO.Asin: lambda data: jnp.asin(data),
MO.Atan: lambda data: jnp.atan(data),
MO.Sinc: lambda data: jnp.sinc(data),
# By Vector
# Vector -> Number
MO.Norm2: lambda data: jnp.linalg.norm(data, ord=2, axis=-1),
# By Matrix
# Matrix -> Number
MO.Det: lambda data: jnp.linalg.det(data),
MO.Cond: lambda data: jnp.linalg.cond(data),
MO.NormFro: lambda data: jnp.linalg.matrix_norm(data, ord='fro'),
MO.Rank: lambda data: jnp.linalg.matrix_rank(data),
# Matrix -> Vec
MO.Diag: lambda data: jnp.diagonal(data, axis1=-2, axis2=-1),
MO.EigVals: lambda data: jnp.linalg.eigvals(data),
MO.SvdVals: lambda data: jnp.linalg.svdvals(data),
# Matrix -> Matrix
MO.Inv: lambda data: jnp.linalg.inv(data),
MO.Tra: lambda data: jnp.matrix_transpose(data),
# Matrix -> Matrices
MO.Qr: lambda data: jnp.linalg.qr(data),
MO.Chol: lambda data: jnp.linalg.cholesky(data),
MO.Svd: lambda data: jnp.linalg.svd(data),
}[self]
def transform_info(self, info: ct.InfoFlow):
MO = MapOperation
return {
# By User Expression
MO.UserExpr: '*',
# By Number
MO.Real: lambda: info.set_output_mathtype(spux.MathType.Real),
MO.Imag: lambda: info.set_output_mathtype(spux.MathType.Real),
MO.Abs: lambda: info.set_output_mathtype(spux.MathType.Real),
# By Vector
MO.Norm2: lambda: info.collapse_output(
collapsed_name=MO.to_name(self).replace('v', info.output_name),
collapsed_mathtype=spux.MathType.Real,
collapsed_unit=info.output_unit,
),
# By Matrix
MO.Det: lambda: info.collapse_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name),
collapsed_mathtype=info.output_mathtype,
collapsed_unit=info.output_unit,
),
MO.Cond: lambda: info.collapse_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name),
collapsed_mathtype=spux.MathType.Real,
collapsed_unit=None,
),
MO.NormFro: lambda: info.collapse_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name),
collapsed_mathtype=spux.MathType.Real,
collapsed_unit=info.output_unit,
),
MO.Rank: lambda: info.collapse_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name),
collapsed_mathtype=spux.MathType.Integer,
collapsed_unit=None,
),
## TODO: Matrix -> Vec
## TODO: Matrix -> Matrices
}.get(self, info)
class MapMathNode(base.MaxwellSimNode): class MapMathNode(base.MaxwellSimNode):
r"""Applies a function by-structure to the data. r"""Applies a function by-structure to the data.
@ -104,120 +346,46 @@ class MapMathNode(base.MaxwellSimNode):
bl_label = 'Map Math' bl_label = 'Map Math'
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Data': sockets.DataSocketDef(format='jax'), 'Expr': sockets.ExprSocketDef(),
}
input_socket_sets: typ.ClassVar = {
'By Element': {},
'By Vector': {},
'By Matrix': {},
'Expr': {
'Mapper': sockets.ExprSocketDef(
complex_symbols=[X_COMPLEX],
default_expr=X_COMPLEX,
),
},
} }
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Data': sockets.DataSocketDef(format='jax'), 'Expr': sockets.ExprSocketDef(),
} }
#################### ####################
# - Properties # - Properties
#################### ####################
operation: enum.Enum = bl_cache.BLField( operation: MapOperation = bl_cache.BLField(
prop_ui=True, enum_cb=lambda self, _: self.search_operations() prop_ui=True, enum_cb=lambda self, _: self.search_operations()
) )
def search_operations(self) -> list[ct.BLEnumElement]: @property
if self.active_socket_set == 'By Element': def expr_output_shape(self) -> ct.InfoFlow | None:
items = [ info = self._compute_input('Expr', kind=ct.FlowKind.Info)
# General has_info = not ct.FlowSignal.check(info)
('REAL', '(v)', 'real(v) (by el)'), if has_info:
('IMAG', 'Im(v)', 'imag(v) (by el)'), return info.output_shape
('ABS', '|v|', 'abs(v) (by el)'),
('SQ', '', 'v^2 (by el)'),
('SQRT', '√v', 'sqrt(v) (by el)'),
('INV_SQRT', '1/√v', '1/sqrt(v) (by el)'),
None,
# Trigonometry
('COS', 'cos v', 'cos(v) (by el)'),
('SIN', 'sin v', 'sin(v) (by el)'),
('TAN', 'tan v', 'tan(v) (by el)'),
('ACOS', 'acos v', 'acos(v) (by el)'),
('ASIN', 'asin v', 'asin(v) (by el)'),
('ATAN', 'atan v', 'atan(v) (by el)'),
]
elif self.active_socket_set in 'By Vector':
items = [
# Vector -> Number
('NORM_2', '||v||₂', 'norm(v, 2) (by Vec)'),
]
elif self.active_socket_set == 'By Matrix':
items = [
# Matrix -> Number
('DET', 'det V', 'det(V) (by Mat)'),
('COND', 'κ(V)', 'cond(V) (by Mat)'),
('NORM_FRO', '||V||_F', 'norm(V, frobenius) (by Mat)'),
('RANK', 'rank V', 'rank(V) (by Mat)'),
None,
# Matrix -> Array
('DIAG', 'diag V', 'diag(V) (by Mat)'),
('EIG_VALS', 'eigvals V', 'eigvals(V) (by Mat)'),
('SVD_VALS', 'svdvals V', 'diag(svd(V)) (by Mat)'),
None,
# Matrix -> Matrix
('INV', 'V⁻¹', 'V^(-1) (by Mat)'),
('TRA', 'Vt', 'V^T (by Mat)'),
None,
# Matrix -> Matrices
('QR', 'qr V', 'qr(V) -> Q·R (by Mat)'),
('CHOL', 'chol V', 'cholesky(V) -> V·V† (by Mat)'),
('SVD', 'svd V', 'svd(V) -> U·Σ·V† (by Mat)'),
]
elif self.active_socket_set == 'Expr':
items = [('EXPR_EL', 'By Element', 'Expression-defined (by el)')]
else:
msg = f'Active socket set {self.active_socket_set} is unknown'
raise RuntimeError(msg)
return None
output_shape: tuple[int, ...] | None = bl_cache.BLField(None)
def search_operations(self) -> list[ct.BLEnumElement]:
if self.expr_output_shape is not None:
return [ return [
(*item, '', i) if item is not None else None for i, item in enumerate(items) operation.bl_enum_element(i)
for i, operation in enumerate(
MapOperation.by_element_shape(self.expr_output_shape)
)
] ]
return []
#################### ####################
# - UI # - UI
#################### ####################
def draw_label(self): def draw_label(self):
labels = { if self.operation is not None:
'REAL': '(v)', return 'Map: ' + MapOperation.to_name(self.operation)
'IMAG': 'Im(v)',
'ABS': '|v|',
'SQ': '',
'SQRT': '√v',
'INV_SQRT': '1/√v',
'COS': 'cos v',
'SIN': 'sin v',
'TAN': 'tan v',
'ACOS': 'acos v',
'ASIN': 'asin v',
'ATAN': 'atan v',
'NORM_2': '||v||₂',
'DET': 'det V',
'COND': 'κ(V)',
'NORM_FRO': '||V||_F',
'RANK': 'rank V',
'DIAG': 'diag V',
'EIG_VALS': 'eigvals V',
'SVD_VALS': 'svdvals V',
'INV': 'V⁻¹',
'TRA': 'Vt',
'QR': 'qr V',
'CHOL': 'chol V',
'SVD': 'svd V',
}
if (label := labels.get(self.operation)) is not None:
return 'Map: ' + label
return self.bl_label return self.bl_label
@ -228,106 +396,98 @@ class MapMathNode(base.MaxwellSimNode):
# - Events # - Events
#################### ####################
@events.on_value_changed( @events.on_value_changed(
prop_name='active_socket_set', # Trigger
socket_name='Expr',
run_on_init=True, run_on_init=True,
) )
def on_socket_set_changed(self): def on_input_changed(self):
if self.operation not in MapOperation.by_element_shape(self.expr_output_shape):
self.operation = bl_cache.Signal.ResetEnumItems self.operation = bl_cache.Signal.ResetEnumItems
@events.on_value_changed(
# Trigger
prop_name={'operation'},
run_on_init=True,
# Loaded
props={'operation'},
)
def on_operation_changed(self, props: dict) -> None:
operation = props['operation']
# UserExpr: Add/Remove Input Socket
if operation == MapOperation.UserExpr:
current_bl_socket = self.loose_input_sockets.get('Mapper')
if current_bl_socket is None:
self.loose_input_sockets = {
'Mapper': sockets.ExprSocketDef(
symbols={X_COMPLEX},
default_value=X_COMPLEX,
mathtype=spux.MathType.Complex,
),
}
elif self.loose_input_sockets:
self.loose_input_sockets = {}
#################### ####################
# - Compute: LazyValueFunc / Array # - Compute: LazyValueFunc / Array
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Expr',
kind=ct.FlowKind.LazyValueFunc, kind=ct.FlowKind.LazyValueFunc,
props={'active_socket_set', 'operation'}, props={'operation'},
input_sockets={'Data', 'Mapper'}, input_sockets={'Expr', 'Mapper'},
input_socket_kinds={ input_socket_kinds={
'Data': ct.FlowKind.LazyValueFunc, 'Expr': ct.FlowKind.LazyValueFunc,
'Mapper': ct.FlowKind.LazyValueFunc, 'Mapper': ct.FlowKind.LazyValueFunc,
}, },
input_sockets_optional={'Mapper': True}, input_sockets_optional={'Mapper': True},
) )
def compute_data(self, props: dict, input_sockets: dict): def compute_data(self, props: dict, input_sockets: dict):
has_data = not ct.FlowSignal.check(input_sockets['Data']) operation = props['operation']
if ( expr = input_sockets['Expr']
not has_data mapper = input_sockets['Mapper']
or props['operation'] == 'NONE'
or (
props['active_socket_set'] == 'Expr'
and ct.FlowSignal.check(input_sockets['Mapper'])
)
):
return ct.FlowSignal.FlowPending
mapping_func: typ.Callable[[jax.Array], jax.Array] = { has_expr = not ct.FlowSignal.check(expr)
'By Element': { has_mapper = not ct.FlowSignal.check(expr)
'REAL': lambda data: jnp.real(data),
'IMAG': lambda data: jnp.imag(data),
'ABS': lambda data: jnp.abs(data),
'SQ': lambda data: jnp.square(data),
'SQRT': lambda data: jnp.sqrt(data),
'INV_SQRT': lambda data: 1 / jnp.sqrt(data),
'COS': lambda data: jnp.cos(data),
'SIN': lambda data: jnp.sin(data),
'TAN': lambda data: jnp.tan(data),
'ACOS': lambda data: jnp.acos(data),
'ASIN': lambda data: jnp.asin(data),
'ATAN': lambda data: jnp.atan(data),
'SINC': lambda data: jnp.sinc(data),
},
'By Vector': {
'NORM_2': lambda data: jnp.linalg.norm(data, ord=2, axis=-1),
},
'By Matrix': {
# Matrix -> Number
'DET': lambda data: jnp.linalg.det(data),
'COND': lambda data: jnp.linalg.cond(data),
'NORM_FRO': lambda data: jnp.linalg.matrix_norm(data, ord='fro'),
'RANK': lambda data: jnp.linalg.matrix_rank(data),
# Matrix -> Vec
'DIAG': lambda data: jnp.diagonal(data, axis1=-2, axis2=-1),
'EIG_VALS': lambda data: jnp.linalg.eigvals(data),
'SVD_VALS': lambda data: jnp.linalg.svdvals(data),
# Matrix -> Matrix
'INV': lambda data: jnp.linalg.inv(data),
'TRA': lambda data: jnp.matrix_transpose(data),
# Matrix -> Matrices
'QR': lambda data: jnp.linalg.qr(data),
'CHOL': lambda data: jnp.linalg.cholesky(data),
'SVD': lambda data: jnp.linalg.svd(data),
},
'Expr': {
'EXPR_EL': lambda data: input_sockets['Mapper'].func(data),
},
}[props['active_socket_set']][props['operation']]
# Compose w/Lazy Root Function Data if has_expr and operation is not None:
return input_sockets['Data'].compose_within( if not has_mapper:
mapping_func, return expr.compose_within(
operation.jax_func(),
supports_jax=True, supports_jax=True,
) )
if operation == MapOperation.UserExpr and has_mapper:
return expr.compose_within(
operation.jax_func(user_expr_func=mapper),
supports_jax=True,
)
return ct.FlowSignal.FlowPending
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Expr',
kind=ct.FlowKind.Array, kind=ct.FlowKind.Array,
output_sockets={'Data'}, output_sockets={'Expr'},
output_socket_kinds={ output_socket_kinds={
'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params}, 'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
}, },
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
) )
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow: def compute_array(self, output_sockets, unit_systems) -> ct.ArrayFlow:
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc] lazy_value_func = output_sockets['Expr'][ct.FlowKind.LazyValueFunc]
params = output_sockets['Data'][ct.FlowKind.Params] params = output_sockets['Expr'][ct.FlowKind.Params]
if all(not ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]): has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
has_params = not ct.FlowSignal.check(params)
if has_lazy_value_func and has_params:
unit_system = unit_systems['BlenderUnits']
return ct.ArrayFlow( return ct.ArrayFlow(
values=lazy_value_func.func_jax( values=lazy_value_func.func_jax(
*params.func_args, **params.func_kwargs *params.scaled_func_args(unit_system),
**params.scaled_func_kwargs(unit_system),
), ),
unit=None,
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
#################### ####################
@ -341,60 +501,16 @@ class MapMathNode(base.MaxwellSimNode):
input_socket_kinds={'Data': ct.FlowKind.Info}, input_socket_kinds={'Data': ct.FlowKind.Info},
) )
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow: def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
info = input_sockets['Data'] operation = props['operation']
if ct.FlowSignal.check(info): info = input_sockets['Expr']
has_info = not ct.FlowSignal.check(info)
if has_info and operation is not None:
return operation.transform_info(info)
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
# Complex -> Real
if props['active_socket_set'] == 'By Element' and props['operation'] in [
'REAL',
'IMAG',
'ABS',
]:
return info.set_output_mathtype(spux.MathType.Real)
if props['active_socket_set'] == 'By Vector' and props['operation'] in [
'NORM_2'
]:
return {
'NORM_2': lambda: info.collapse_output(
collapsed_name=f'||{info.output_name}||₂',
collapsed_mathtype=spux.MathType.Real,
collapsed_unit=info.output_unit,
)
}[props['operation']]()
if props['active_socket_set'] == 'By Matrix' and props['operation'] in [
'DET',
'COND',
'NORM_FRO',
'RANK',
]:
return {
'DET': lambda: info.collapse_output(
collapsed_name=f'det {info.output_name}',
collapsed_mathtype=info.output_mathtype,
collapsed_unit=info.output_unit,
),
'COND': lambda: info.collapse_output(
collapsed_name=f'κ({info.output_name})',
collapsed_mathtype=spux.MathType.Real,
collapsed_unit=None,
),
'NORM_FRO': lambda: info.collapse_output(
collapsed_name=f'||({info.output_name}||_F',
collapsed_mathtype=spux.MathType.Real,
collapsed_unit=info.output_unit,
),
'RANK': lambda: info.collapse_output(
collapsed_name=f'rank {info.output_name}',
collapsed_mathtype=spux.MathType.Integer,
collapsed_unit=None,
),
}[props['operation']]()
return info
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Data',
kind=ct.FlowKind.Params, kind=ct.FlowKind.Params,

View File

@ -14,6 +14,29 @@ from ... import base, events
log = logger.get(__name__) log = logger.get(__name__)
FUNCS = {
'ADD': lambda exprs: exprs[0] + exprs[1],
'SUB': lambda exprs: exprs[0] - exprs[1],
'MUL': lambda exprs: exprs[0] * exprs[1],
'DIV': lambda exprs: exprs[0] / exprs[1],
'POW': lambda exprs: exprs[0] ** exprs[1],
}
SP_FUNCS = FUNCS
JAX_FUNCS = FUNCS | {
# Number | *
'ATAN2': lambda exprs: jnp.atan2(exprs[1], exprs[0]),
# Vector | Vector
'VEC_VEC_DOT': lambda exprs: jnp.matmul(exprs[0], exprs[1]),
'CROSS': lambda exprs: jnp.cross(exprs[0], exprs[1]),
# Matrix | Vector
'MAT_VEC_DOT': lambda exprs: jnp.matmul(exprs[0], exprs[1]),
'LIN_SOLVE': lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]),
'LSQ_SOLVE': lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]),
# Matrix | Matrix
'MAT_MAT_DOT': lambda exprs: jnp.matmul(exprs[0], exprs[1]),
}
class OperateMathNode(base.MaxwellSimNode): class OperateMathNode(base.MaxwellSimNode):
r"""Applies a function that depends on two inputs. r"""Applies a function that depends on two inputs.
@ -28,40 +51,12 @@ class OperateMathNode(base.MaxwellSimNode):
node_type = ct.NodeType.OperateMath node_type = ct.NodeType.OperateMath
bl_label = 'Operate Math' bl_label = 'Operate Math'
input_socket_sets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Expr | Expr': { 'Expr L': sockets.ExprSocketDef(show_info_columns=False),
'Expr L': sockets.ExprSocketDef(), 'Expr R': sockets.ExprSocketDef(show_info_columns=False),
'Expr R': sockets.ExprSocketDef(),
},
'Data | Data': {
'Data L': sockets.DataSocketDef(
format='jax', default_show_info_columns=False
),
'Data R': sockets.DataSocketDef(
format='jax', default_show_info_columns=False
),
},
'Expr | Data': {
'Expr L': sockets.ExprSocketDef(),
'Data R': sockets.DataSocketDef(
format='jax', default_show_info_columns=False
),
},
} }
output_socket_sets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Expr | Expr': {
'Expr': sockets.ExprSocketDef(), 'Expr': sockets.ExprSocketDef(),
},
'Data | Data': {
'Data': sockets.DataSocketDef(
format='jax', default_show_info_columns=False
),
},
'Expr | Data': {
'Data': sockets.DataSocketDef(
format='jax', default_show_info_columns=False
),
},
} }
#################### ####################
@ -77,15 +72,15 @@ class OperateMathNode(base.MaxwellSimNode):
def search_categories(self) -> list[ct.BLEnumElement]: def search_categories(self) -> list[ct.BLEnumElement]:
"""Deduce and return a list of valid categories for the current socket set and input data.""" """Deduce and return a list of valid categories for the current socket set and input data."""
data_l_info = self._compute_input( expr_l_info = self._compute_input(
'Data L', kind=ct.FlowKind.Info, optional=True 'Expr L', kind=ct.FlowKind.Info, optional=True
) )
data_r_info = self._compute_input( expr_r_info = self._compute_input(
'Data R', kind=ct.FlowKind.Info, optional=True 'Expr R', kind=ct.FlowKind.Info, optional=True
) )
has_data_l_info = not ct.FlowSignal.check(data_l_info) has_expr_l_info = not ct.FlowSignal.check(expr_l_info)
has_data_r_info = not ct.FlowSignal.check(data_r_info) has_expr_r_info = not ct.FlowSignal.check(expr_r_info)
# Categories by Socket Set # Categories by Socket Set
NUMBER_NUMBER = ( NUMBER_NUMBER = (
@ -120,64 +115,45 @@ class OperateMathNode(base.MaxwellSimNode):
) )
categories = [] categories = []
## Expr | Expr if has_expr_l_info and has_expr_r_info:
if self.active_socket_set == 'Expr | Expr':
return [NUMBER_NUMBER]
## Data | Data
if (
self.active_socket_set == 'Data | Data'
and has_data_l_info
and has_data_r_info
):
# Check Valid Broadcasting # Check Valid Broadcasting
## Number | Number ## Number | Number
if data_l_info.output_shape is None and data_r_info.output_shape is None: if expr_l_info.output_shape is None and expr_r_info.output_shape is None:
categories = [NUMBER_NUMBER] categories = [NUMBER_NUMBER]
## Number | Vector ## Number | Vector
elif ( elif (
data_l_info.output_shape is None and len(data_r_info.output_shape) == 1 expr_l_info.output_shape is None and len(expr_r_info.output_shape) == 1
): ):
categories = [NUMBER_VECTOR] categories = [NUMBER_VECTOR]
## Number | Matrix ## Number | Matrix
elif ( elif (
data_l_info.output_shape is None and len(data_r_info.output_shape) == 2 expr_l_info.output_shape is None and len(expr_r_info.output_shape) == 2
): # noqa: PLR2004 ):
categories = [NUMBER_MATRIX] categories = [NUMBER_MATRIX]
## Vector | Vector ## Vector | Vector
elif ( elif (
len(data_l_info.output_shape) == 1 len(expr_l_info.output_shape) == 1
and len(data_r_info.output_shape) == 1 and len(expr_r_info.output_shape) == 1
): ):
categories = [VECTOR_VECTOR] categories = [VECTOR_VECTOR]
## Matrix | Vector ## Matrix | Vector
elif ( elif (
len(data_l_info.output_shape) == 2 # noqa: PLR2004 len(expr_l_info.output_shape) == 2 # noqa: PLR2004
and len(data_r_info.output_shape) == 1 and len(expr_r_info.output_shape) == 1
): ):
categories = [MATRIX_VECTOR] categories = [MATRIX_VECTOR]
## Matrix | Matrix ## Matrix | Matrix
elif ( elif (
len(data_l_info.output_shape) == 2 # noqa: PLR2004 len(expr_l_info.output_shape) == 2 # noqa: PLR2004
and len(data_r_info.output_shape) == 2 # noqa: PLR2004 and len(expr_r_info.output_shape) == 2 # noqa: PLR2004
): ):
categories = [MATRIX_MATRIX] categories = [MATRIX_MATRIX]
## Expr | Data
if self.active_socket_set == 'Expr | Data' and has_data_r_info:
if data_r_info.output_shape is None:
categories = [NUMBER_NUMBER]
else:
categories = {
1: [NUMBER_NUMBER, NUMBER_VECTOR],
2: [NUMBER_NUMBER, NUMBER_MATRIX],
}[len(data_r_info.output_shape)]
return [ return [
(*category, '', i) if category is not None else None (*category, '', i) if category is not None else None
for i, category in enumerate(categories) for i, category in enumerate(categories)
@ -248,11 +224,10 @@ class OperateMathNode(base.MaxwellSimNode):
#################### ####################
@events.on_value_changed( @events.on_value_changed(
# Trigger # Trigger
socket_name={'Expr L', 'Expr R', 'Data L', 'Data R'}, socket_name={'Expr L', 'Expr R'},
prop_name='active_socket_set',
run_on_init=True, run_on_init=True,
) )
def on_socket_set_changed(self) -> None: def on_socket_changed(self) -> None:
# Recompute Valid Categories # Recompute Valid Categories
self.category = bl_cache.Signal.ResetEnumItems self.category = bl_cache.Signal.ResetEnumItems
self.operation = bl_cache.Signal.ResetEnumItems self.operation = bl_cache.Signal.ResetEnumItems
@ -272,224 +247,135 @@ class OperateMathNode(base.MaxwellSimNode):
kind=ct.FlowKind.Value, kind=ct.FlowKind.Value,
props={'operation'}, props={'operation'},
input_sockets={'Expr L', 'Expr R'}, input_sockets={'Expr L', 'Expr R'},
input_socket_kinds={
'Expr L': ct.FlowKind.Value,
'Expr R': ct.FlowKind.Value,
},
) )
def compute_expr(self, props: dict, input_sockets: dict): def compute_value(self, props: dict, input_sockets: dict):
operation = props['operation']
expr_l = input_sockets['Expr L'] expr_l = input_sockets['Expr L']
expr_r = input_sockets['Expr R'] expr_r = input_sockets['Expr R']
return { has_expr_l_value = not ct.FlowSignal.check(expr_l)
'ADD': lambda: expr_l + expr_r, has_expr_r_value = not ct.FlowSignal.check(expr_r)
'SUB': lambda: expr_l - expr_r,
'MUL': lambda: expr_l * expr_r, if has_expr_l_value and has_expr_r_value and operation is not None:
'DIV': lambda: expr_l / expr_r, return SP_FUNCS[operation]([expr_l, expr_r])
'POW': lambda: expr_l**expr_r,
'ATAN2': lambda: sp.atan2(expr_r, expr_l), return ct.Flowsignal.FlowPending
}[props['operation']]()
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Expr',
kind=ct.FlowKind.LazyValueFunc, kind=ct.FlowKind.LazyValueFunc,
props={'operation'}, props={'operation'},
input_sockets={'Data L', 'Data R'}, input_sockets={'Expr L', 'Expr R'},
input_socket_kinds={ input_socket_kinds={
'Data L': ct.FlowKind.LazyValueFunc, 'Expr L': ct.FlowKind.LazyValueFunc,
'Data R': ct.FlowKind.LazyValueFunc, 'Expr R': ct.FlowKind.LazyValueFunc,
},
input_sockets_optional={
'Data L': True,
'Data R': True,
}, },
) )
def compute_data(self, props: dict, input_sockets: dict): def compose_func(self, props: dict, input_sockets: dict):
data_l = input_sockets['Data L'] operation = props['operation']
data_r = input_sockets['Data R'] if operation is None:
has_data_l = not ct.FlowSignal.check(data_l) return ct.FlowSignal.FlowPending
mapping_func = { expr_l = input_sockets['Expr L']
# Number | * expr_r = input_sockets['Expr R']
'ADD': lambda datas: datas[0] + datas[1],
'SUB': lambda datas: datas[0] - datas[1],
'MUL': lambda datas: datas[0] * datas[1],
'DIV': lambda datas: datas[0] / datas[1],
'POW': lambda datas: datas[0] ** datas[1],
'ATAN2': lambda datas: jnp.atan2(datas[1], datas[0]),
# Vector | Vector
'VEC_VEC_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
'CROSS': lambda datas: jnp.cross(datas[0], datas[1]),
# Matrix | Vector
'MAT_VEC_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
'LIN_SOLVE': lambda datas: jnp.linalg.solve(datas[0], datas[1]),
'LSQ_SOLVE': lambda datas: jnp.linalg.lstsq(datas[0], datas[1]),
# Matrix | Matrix
'MAT_MAT_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
}[props['operation']]
# Compose by Socket Set has_expr_l = not ct.FlowSignal.check(expr_l)
## Data | Data has_expr_r = not ct.FlowSignal.check(expr_r)
if has_data_l:
return (data_l | data_r).compose_within( if has_expr_l and has_expr_r:
mapping_func, return (expr_l | expr_r).compose_within(
supports_jax=True, JAX_FUNCS[operation],
)
## Expr | Data
expr_l_lazy_value_func = ct.LazyValueFuncFlow(
func=lambda expr_l_value: expr_l_value,
func_args=[typ.Any],
supports_jax=True,
)
return (expr_l_lazy_value_func | data_r).compose_within(
mapping_func,
supports_jax=True, supports_jax=True,
) )
return ct.FlowSignal.FlowPending
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Expr',
kind=ct.FlowKind.Array, kind=ct.FlowKind.Array,
output_sockets={'Data'}, output_sockets={'Expr'},
output_socket_kinds={ output_socket_kinds={
'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params}, 'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
}, },
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
) )
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow: def compute_array(self, output_sockets, unit_systems) -> ct.ArrayFlow:
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc] lazy_value_func = output_sockets['Expr'][ct.FlowKind.LazyValueFunc]
params = output_sockets['Data'][ct.FlowKind.Params] params = output_sockets['Expr'][ct.FlowKind.Params]
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func) has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
has_params = not ct.FlowSignal.check(params) has_params = not ct.FlowSignal.check(params)
if has_lazy_value_func and has_params: if has_lazy_value_func and has_params:
unit_system = unit_systems['BlenderUnits']
return ct.ArrayFlow( return ct.ArrayFlow(
values=lazy_value_func.func_jax( values=lazy_value_func.func_jax(
*params.func_args, **params.func_kwargs *params.scaled_func_args(unit_system),
**params.scaled_func_kwargs(unit_system),
), ),
unit=None, unit=None,
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
####################
# - Auxiliary: Params
####################
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Params,
props={'operation'},
input_sockets={'Expr L', 'Data L', 'Data R'},
input_socket_kinds={
'Expr L': ct.FlowKind.Value,
'Data L': {ct.FlowKind.Info, ct.FlowKind.Params},
'Data R': {ct.FlowKind.Info, ct.FlowKind.Params},
},
input_sockets_optional={
'Expr L': True,
'Data L': True,
'Data R': True,
},
)
def compute_data_params(
self, props, input_sockets
) -> ct.ParamsFlow | ct.FlowSignal:
expr_l = input_sockets['Expr L']
data_l_info = input_sockets['Data L'][ct.FlowKind.Info]
data_l_params = input_sockets['Data L'][ct.FlowKind.Params]
data_r_info = input_sockets['Data R'][ct.FlowKind.Info]
data_r_params = input_sockets['Data R'][ct.FlowKind.Params]
has_expr_l = not ct.FlowSignal.check(expr_l)
has_data_l_info = not ct.FlowSignal.check(data_l_info)
has_data_l_params = not ct.FlowSignal.check(data_l_params)
has_data_r_info = not ct.FlowSignal.check(data_r_info)
has_data_r_params = not ct.FlowSignal.check(data_r_params)
# Compose by Socket Set
## Data | Data
if (
has_data_l_info
and has_data_l_params
and has_data_r_info
and has_data_r_params
):
return data_l_params | data_r_params
## Expr | Data
if has_expr_l and has_data_r_info and has_data_r_params:
operation = props['operation']
data_unit = data_r_info.output_unit
# By Operation
## Add/Sub: Scale to Output Unit
if operation in ['ADD', 'SUB', 'MUL', 'DIV']:
if not spux.uses_units(expr_l):
value = spux.sympy_to_python(expr_l)
else:
value = spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
return data_r_params.compose_within(
enclosing_func_args=[value],
)
## Pow: Doesn't Exist (?)
## -> See https://math.stackexchange.com/questions/4326081/units-of-the-exponential-function
if operation == 'POW':
return ct.FlowSignal.FlowPending
## atan2(): Only Length
## -> Implicitly presume that Data L/R use length units.
if operation == 'ATAN2':
if not spux.uses_units(expr_l):
value = spux.sympy_to_python(expr_l)
else:
value = spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
return data_r_params.compose_within(
enclosing_func_args=[value],
)
return data_r_params.compose_within(
enclosing_func_args=[
spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
]
)
return ct.FlowSignal.FlowPending
#################### ####################
# - Auxiliary: Info # - Auxiliary: Info
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Expr',
kind=ct.FlowKind.Info, kind=ct.FlowKind.Info,
input_sockets={'Expr L', 'Data L', 'Data R'}, props={'operation'},
input_sockets={'Expr L', 'Expr R'},
input_socket_kinds={ input_socket_kinds={
'Expr L': ct.FlowKind.Value, 'Expr L': ct.FlowKind.Info,
'Data L': ct.FlowKind.Info, 'Expr R': ct.FlowKind.Info,
'Data R': ct.FlowKind.Info,
},
input_sockets_optional={
'Expr L': True,
'Data L': True,
'Data R': True,
}, },
) )
def compute_data_info(self, input_sockets: dict) -> ct.InfoFlow: def compute_info(self, props, input_sockets) -> ct.InfoFlow:
expr_l = input_sockets['Expr L'] operation = props['operation']
data_l_info = input_sockets['Data L'] info_l = input_sockets['Expr L']
data_r_info = input_sockets['Data R'] info_r = input_sockets['Expr R']
has_expr_l = not ct.FlowSignal.check(expr_l) has_info_l = not ct.FlowSignal.check(info_l)
has_data_l_info = not ct.FlowSignal.check(data_l_info) has_info_r = not ct.FlowSignal.check(info_r)
has_data_r_info = not ct.FlowSignal.check(data_r_info)
# Info by Socket Set # Return Info of RHS
## Data | Data ## -> Fundamentall, this is why 'category' only has the given options.
if has_data_l_info and has_data_r_info: ## -> Via 'category', we enforce that the operated-on structure is always RHS.
return data_r_info ## -> That makes it super duper easy to track info changes.
if has_info_l and has_info_r and operation is not None:
return info_r
## Expr | Data return ct.FlowSignal.FlowPending
if has_expr_l and has_data_r_info:
return data_r_info
####################
# - Auxiliary: Params
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
props={'operation'},
input_sockets={'Expr L', 'Expr R'},
input_socket_kinds={
'Expr L': ct.FlowKind.Params,
'Expr R': ct.FlowKind.Params,
},
)
def compute_params(
self, props, input_sockets
) -> ct.ParamsFlow | ct.FlowSignal:
operation = props['operation']
params_l = input_sockets['Expr L']
params_r = input_sockets['Expr R']
has_params_l = not ct.FlowSignal.check(params_l)
has_params_r = not ct.FlowSignal.check(params_r)
if has_params_l and has_params_r and operation is not None:
return params_l | params_r
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending

View File

@ -147,7 +147,7 @@ class VizTarget(enum.StrEnum):
@staticmethod @staticmethod
def valid_targets_for(viz_mode: VizMode) -> list[typ.Self] | None: def valid_targets_for(viz_mode: VizMode) -> list[typ.Self] | None:
return { return {
'NONE': [], None: [],
VizMode.Hist1D: [VizTarget.Plot2D], VizMode.Hist1D: [VizTarget.Plot2D],
VizMode.BoxPlot1D: [VizTarget.Plot2D], VizMode.BoxPlot1D: [VizTarget.Plot2D],
VizMode.Curve2D: [VizTarget.Plot2D], VizMode.Curve2D: [VizTarget.Plot2D],
@ -192,7 +192,7 @@ class VizNode(base.MaxwellSimNode):
# - Sockets # - Sockets
#################### ####################
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Data': sockets.DataSocketDef(format='jax'), 'Expr': sockets.ExprSocketDef(),
} }
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Preview': sockets.AnySocketDef(), 'Preview': sockets.AnySocketDef(),
@ -222,7 +222,7 @@ class VizNode(base.MaxwellSimNode):
##################### #####################
@property @property
def data_info(self) -> ct.InfoFlow: def data_info(self) -> ct.InfoFlow:
return self._compute_input('Data', kind=ct.FlowKind.Info) return self._compute_input('Expr', kind=ct.FlowKind.Info)
def search_modes(self) -> list[ct.BLEnumElement]: def search_modes(self) -> list[ct.BLEnumElement]:
if not ct.FlowSignal.check(self.data_info): if not ct.FlowSignal.check(self.data_info):
@ -243,7 +243,7 @@ class VizNode(base.MaxwellSimNode):
## - Target Searcher ## - Target Searcher
##################### #####################
def search_targets(self) -> list[ct.BLEnumElement]: def search_targets(self) -> list[ct.BLEnumElement]:
if self.viz_mode != 'NONE': if self.viz_mode is not None:
return [ return [
( (
viz_target, viz_target,
@ -271,15 +271,15 @@ class VizNode(base.MaxwellSimNode):
# - Events # - Events
#################### ####################
@events.on_value_changed( @events.on_value_changed(
socket_name='Data', socket_name='Expr',
input_sockets={'Data'}, input_sockets={'Expr'},
run_on_init=True, run_on_init=True,
input_socket_kinds={'Data': ct.FlowKind.Info}, input_socket_kinds={'Expr': ct.FlowKind.Info},
input_sockets_optional={'Data': True}, input_sockets_optional={'Expr': True},
) )
def on_any_changed(self, input_sockets: dict): def on_any_changed(self, input_sockets: dict):
if not ct.FlowSignal.check_single( if not ct.FlowSignal.check_single(
input_sockets['Data'], ct.FlowSignal.FlowPending input_sockets['Expr'], ct.FlowSignal.FlowPending
): ):
self.viz_mode = bl_cache.Signal.ResetEnumItems self.viz_mode = bl_cache.Signal.ResetEnumItems
self.viz_target = bl_cache.Signal.ResetEnumItems self.viz_target = bl_cache.Signal.ResetEnumItems
@ -297,8 +297,8 @@ class VizNode(base.MaxwellSimNode):
@events.on_show_plot( @events.on_show_plot(
managed_objs={'plot'}, managed_objs={'plot'},
props={'viz_mode', 'viz_target', 'colormap'}, props={'viz_mode', 'viz_target', 'colormap'},
input_sockets={'Data'}, input_sockets={'Expr'},
input_socket_kinds={'Data': {ct.FlowKind.Array, ct.FlowKind.Info}}, input_socket_kinds={'Expr': {ct.FlowKind.Array, ct.FlowKind.Info}},
stop_propagation=True, stop_propagation=True,
) )
def on_show_plot( def on_show_plot(
@ -308,14 +308,14 @@ class VizNode(base.MaxwellSimNode):
props: dict, props: dict,
): ):
# Retrieve Inputs # Retrieve Inputs
array_flow = input_sockets['Data'][ct.FlowKind.Array] array_flow = input_sockets['Expr'][ct.FlowKind.Array]
info = input_sockets['Data'][ct.FlowKind.Info] info = input_sockets['Expr'][ct.FlowKind.Info]
# Check Flow # Check Flow
if ( if (
any(ct.FlowSignal.check(inp) for inp in [array_flow, info]) any(ct.FlowSignal.check(inp) for inp in [array_flow, info])
or props['viz_mode'] == 'NONE' or props['viz_mode'] is None
or props['viz_target'] == 'NONE' or props['viz_target'] is None
): ):
return return

View File

@ -1,7 +1,7 @@
from . import ( from . import (
constants, constants,
file_importers, file_importers,
unit_system, #unit_system,
wave_constant, wave_constant,
web_importers, web_importers,
) )
@ -10,14 +10,14 @@ from . import (
BL_REGISTER = [ BL_REGISTER = [
*wave_constant.BL_REGISTER, *wave_constant.BL_REGISTER,
*unit_system.BL_REGISTER, #*unit_system.BL_REGISTER,
*constants.BL_REGISTER, *constants.BL_REGISTER,
*web_importers.BL_REGISTER, *web_importers.BL_REGISTER,
*file_importers.BL_REGISTER, *file_importers.BL_REGISTER,
] ]
BL_NODES = { BL_NODES = {
**wave_constant.BL_NODES, **wave_constant.BL_NODES,
**unit_system.BL_NODES, #**unit_system.BL_NODES,
**constants.BL_NODES, **constants.BL_NODES,
**web_importers.BL_NODES, **web_importers.BL_NODES,
**file_importers.BL_NODES, **file_importers.BL_NODES,

View File

@ -1,32 +1,65 @@
import enum
import typing as typ import typing as typ
import bpy
from blender_maxwell.utils import bl_cache
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct from .... import contracts as ct
from .... import sockets from .... import sockets
from ... import base, events from ... import base, events
class NumberConstantNode(base.MaxwellSimNode): class NumberConstantNode(base.MaxwellSimNode):
"""A unitless number of configurable math type ex. integer, real, etc. .
Attributes:
mathtype: The math type to specify the number as.
"""
node_type = ct.NodeType.NumberConstant node_type = ct.NodeType.NumberConstant
bl_label = 'Numerical Constant' bl_label = 'Numerical Constant'
input_socket_sets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Integer': { 'Value': sockets.ExprSocketDef(),
'Value': sockets.IntegerNumberSocketDef(), }
}, output_sockets: typ.ClassVar = {
'Rational': { 'Value': sockets.ExprSocketDef(),
'Value': sockets.RationalNumberSocketDef(),
},
'Real': {
'Value': sockets.RealNumberSocketDef(),
},
'Complex': {
'Value': sockets.ComplexNumberSocketDef(),
},
} }
output_socket_sets = input_socket_sets
#################### ####################
# - Callbacks # - Properties
####################
mathtype: spux.MathType = bl_cache.BLField(
spux.MathType.Integer,
prop_ui=True,
)
size: spux.NumberSize1D = bl_cache.BLField(
spux.NumberSize1D.Scalar,
prop_ui=True,
)
####################
# - UI
####################
def draw_value(self, col: bpy.types.UILayout) -> None:
row = col.row(align=True)
row.prop(self, self.blfields['mathtype'], text='')
row.prop(self, self.blfields['size'], text='')
####################
# - Events
####################
@events.on_value_changed(prop_name={'mathtype', 'size'}, props={'mathtype', 'size'})
def on_mathtype_size_changed(self, props) -> None:
"""Change the input/output expression sockets to match the mathtype declared in the node."""
self.inputs['Value'].mathtype = props['mathtype']
self.inputs['Value'].shape = props['mathtype'].shape
####################
# - FlowKind
#################### ####################
@events.computes_output_socket('Value', input_sockets={'Value'}) @events.computes_output_socket('Value', input_sockets={'Value'})
def compute_value(self, input_sockets) -> typ.Any: def compute_value(self, input_sockets) -> typ.Any:

View File

@ -1,54 +1,91 @@
import enum
import typing as typ import typing as typ
import sympy as sp import sympy as sp
from blender_maxwell.utils import bl_cache
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts, sockets from .... import contracts, sockets
from ... import base, events from ... import base, events
class PhysicalConstantNode(base.MaxwellSimTreeNode): class PhysicalConstantNode(base.MaxwellSimTreeNode):
"""A number of configurable unit dimension, ex. time, length, etc. .
Attributes:
physical_type: The physical type to specify.
size: The size of the physical type, if it can be a vector.
"""
node_type = contracts.NodeType.PhysicalConstant node_type = contracts.NodeType.PhysicalConstant
bl_label = 'Physical Constant' bl_label = 'Physical Constant'
input_socket_sets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'time': { 'Value': sockets.ExprSocketDef(),
'value': sockets.PhysicalTimeSocketDef(
label='Time',
),
},
'angle': {
'value': sockets.PhysicalAngleSocketDef(
label='Angle',
),
},
'length': {
'value': sockets.PhysicalLengthSocketDef(
label='Length',
),
},
'area': {
'value': sockets.PhysicalAreaSocketDef(
label='Area',
),
},
'volume': {
'value': sockets.PhysicalVolumeSocketDef(
label='Volume',
),
},
'point_3d': {
'value': sockets.PhysicalPoint3DSocketDef(
label='3D Point',
),
},
'size_3d': {
'value': sockets.PhysicalSize3DSocketDef(
label='3D Size',
),
},
## I got bored so maybe the rest later
} }
output_socket_sets: typ.ClassVar = input_socket_sets output_sockets: typ.ClassVar = {
'Value': sockets.ExprSocketDef(),
}
####################
# - Properties
####################
physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.Time,
prop_ui=True,
)
mathtype: enum.Enum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_mathtypes(),
prop_ui=True,
)
size: enum.Enum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_sizes(),
prop_ui=True,
)
####################
# - Searchers
####################
def search_mathtypes(self):
return [
mathtype.bl_enum_element(i)
for i, mathtype in enumerate(self.physical_type.valid_mathtypes)
]
def search_sizes(self):
return [
spux.NumberSize1D.from_shape(shape).bl_enum_element(i)
for i, shape in enumerate(self.physical_type.valid_shapes)
if spux.NumberSize1D.supports_shape(shape)
]
####################
# - Events
####################
@events.on_value_changed(
prop_name={'physical_type', 'mathtype', 'size'},
props={'physical_type', 'mathtype', 'size'},
)
def on_mathtype_or_size_changed(self, props) -> None:
"""Change the input/output expression sockets to match the mathtype and size declared in the node."""
shape = spux.NumberSize1D(props['size']).shape
# Set Input Socket Physical Type
if self.inputs['Value'].physical_type != props['physical_type']:
self.inputs['Value'].physical_type = props['physical_type']
self.search_mathtypes = bl_cache.Signal.ResetEnumItems
self.search_sizes = bl_cache.Signal.ResetEnumItems
# Set Input Socket Math Type
if self.inputs['Value'].mathtype != props['mathtype']:
self.inputs['Value'].mathtype = props['mathtype']
# Set Input Socket Shape
if self.inputs['Value'].shape != shape:
self.inputs['Value'].shape = shape
#################### ####################
# - Callbacks # - Callbacks

View File

@ -33,8 +33,7 @@ class WaveConstantNode(base.MaxwellSimNode):
input_socket_sets: typ.ClassVar = { input_socket_sets: typ.ClassVar = {
'Wavelength': { 'Wavelength': {
'WL': sockets.ExprSocketDef( 'WL': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Value, physical_type=spux.PhysicalType.Length,
unit_dimension=spux.unit_dims.length,
# Defaults # Defaults
default_unit=spu.nm, default_unit=spu.nm,
default_value=500, default_value=500,
@ -46,7 +45,7 @@ class WaveConstantNode(base.MaxwellSimNode):
'Frequency': { 'Frequency': {
'Freq': sockets.ExprSocketDef( 'Freq': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Value, active_kind=ct.FlowKind.Value,
unit_dimension=spux.unit_dims.frequency, physical_type=spux.PhysicalType.Freq,
# Defaults # Defaults
default_unit=spux.THz, default_unit=spux.THz,
default_value=1, default_value=1,
@ -59,11 +58,11 @@ class WaveConstantNode(base.MaxwellSimNode):
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'WL': sockets.ExprSocketDef( 'WL': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Value, active_kind=ct.FlowKind.Value,
unit_dimension=spux.unit_dims.length, unit_dimension=spux.Dims.length,
), ),
'Freq': sockets.ExprSocketDef( 'Freq': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Value, active_kind=ct.FlowKind.Value,
unit_dimension=spux.unit_dims.frequency, unit_dimension=spux.Dims.frequency,
), ),
} }

View File

@ -1,6 +1,7 @@
import typing as typ import typing as typ
import sympy as sp import sympy as sp
import sympy.physics.units as spu
import tidy3d as td import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
@ -25,22 +26,44 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
# - Sockets # - Sockets
#################### ####################
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Center': sockets.PhysicalPoint3DSocketDef(), 'Center': sockets.ExprSocketDef(
'Size': sockets.PhysicalSize3DSocketDef(), shape=(3,),
'Samples/Space': sockets.Integer3DVectorSocketDef( physical_type=spux.PhysicalType.Length,
default_value=sp.Matrix([10, 10, 10])
), ),
'Size': sockets.ExprSocketDef(
shape=(3,),
physical_type=spux.PhysicalType.Length,
),
'Spatial Subdivs': sockets.ExprSocketDef(
shape=(3,),
mathtype=spux.MathType.Integer,
default_value=sp.Matrix([10, 10, 10]),
),
## TODO: Pass a grid instead of size and resolution
## TODO: 1D (line), 2D (plane), 3D modes
} }
input_socket_sets: typ.ClassVar = { input_socket_sets: typ.ClassVar = {
'Freq Domain': { 'Freq Domain': {
'Freqs': sockets.PhysicalFreqSocketDef( 'Freqs': sockets.ExprSocketDef(
is_array=True, active_kind=ct.FlowKind.LazyArrayRange,
physical_type=spux.PhysicalType.Freq,
default_unit=spux.THz,
default_min=374.7406, ## 800nm
default_max=1498.962, ## 200nm
default_steps=100,
), ),
}, },
'Time Domain': { 'Time Domain': {
'Rec Start': sockets.PhysicalTimeSocketDef(), 'Time Range': sockets.ExprSocketDef(
'Rec Stop': sockets.PhysicalTimeSocketDef(default_value=200 * spux.fs), active_kind=ct.FlowKind.LazyArrayRange,
'Samples/Time': sockets.IntegerNumberSocketDef( physical_type=spux.PhysicalType.Time,
default_unit=spu.picosecond,
default_min=0,
default_max=10,
default_steps=2,
),
'Temporal Subdivs': sockets.ExprSocketDef(
mathtype=spux.MathType.Integer,
default_value=100, default_value=100,
), ),
}, },
@ -56,7 +79,7 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
} }
#################### ####################
# - Output Sockets # - Output
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Freq Monitor', 'Freq Monitor',
@ -64,7 +87,7 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
input_sockets={ input_sockets={
'Center', 'Center',
'Size', 'Size',
'Samples/Space', 'Spatial Subdivs',
'Freqs', 'Freqs',
}, },
input_socket_kinds={ input_socket_kinds={
@ -93,12 +116,12 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
center=input_sockets['Center'], center=input_sockets['Center'],
size=input_sockets['Size'], size=input_sockets['Size'],
name=props['sim_node_name'], name=props['sim_node_name'],
interval_space=tuple(input_sockets['Samples/Space']), interval_space=tuple(input_sockets['Spatial Subdivs']),
freqs=input_sockets['Freqs'].realize().values, freqs=input_sockets['Freqs'].realize().values,
) )
#################### ####################
# - Preview - Changes to Input Sockets # - Preview
#################### ####################
@events.on_value_changed( @events.on_value_changed(
socket_name={'Center', 'Size'}, socket_name={'Center', 'Size'},

View File

@ -1,6 +1,7 @@
import typing as typ import typing as typ
import sympy as sp import sympy as sp
import sympy.physics.units as spu
import tidy3d as td import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
@ -23,23 +24,43 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
# - Sockets # - Sockets
#################### ####################
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Center': sockets.PhysicalPoint3DSocketDef(), 'Center': sockets.ExprSocketDef(
'Size': sockets.PhysicalSize3DSocketDef(), shape=(3,),
'Samples/Space': sockets.Integer3DVectorSocketDef( physical_type=spux.PhysicalType.Length,
default_value=sp.Matrix([10, 10, 10]) ),
'Size': sockets.ExprSocketDef(
shape=(3,),
physical_type=spux.PhysicalType.Length,
),
'Samples/Space': sockets.ExprSocketDef(
shape=(3,),
mathtype=spux.MathType.Integer,
default_value=sp.Matrix([10, 10, 10]),
), ),
'Direction': sockets.BoolSocketDef(), 'Direction': sockets.BoolSocketDef(),
} }
input_socket_sets: typ.ClassVar = { input_socket_sets: typ.ClassVar = {
'Freq Domain': { 'Freq Domain': {
'Freqs': sockets.PhysicalFreqSocketDef( 'Freqs': sockets.ExprSocketDef(
is_array=True, active_kind=ct.FlowKind.LazyArrayRange,
physical_type=spux.PhysicalType.Freq,
default_unit=spux.THz,
default_min=374.7406, ## 800nm
default_max=1498.962, ## 200nm
default_steps=100,
), ),
}, },
'Time Domain': { 'Time Domain': {
'Rec Start': sockets.PhysicalTimeSocketDef(), 'Time Range': sockets.ExprSocketDef(
'Rec Stop': sockets.PhysicalTimeSocketDef(default_value=200 * spux.fs), active_kind=ct.FlowKind.LazyArrayRange,
'Samples/Time': sockets.IntegerNumberSocketDef( physical_type=spux.PhysicalType.Time,
default_unit=spu.picosecond,
default_min=0,
default_max=10,
default_steps=2,
),
'Samples/Time': sockets.ExprSocketDef(
mathtype=spux.MathType.Integer,
default_value=100, default_value=100,
), ),
}, },

View File

@ -2,7 +2,7 @@ import typing as typ
import tidy3d as td import tidy3d as td
from blender_maxwell.utils import analyze_geonodes, logger from blender_maxwell.utils import bl_cache, logger
from ... import bl_socket_map, managed_objs, sockets from ... import bl_socket_map, managed_objs, sockets
from ... import contracts as ct from ... import contracts as ct
@ -20,9 +20,9 @@ class GeoNodesStructureNode(base.MaxwellSimNode):
# - Sockets # - Sockets
#################### ####################
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'GeoNodes': sockets.BlenderGeoNodesSocketDef(),
'Medium': sockets.MaxwellMediumSocketDef(), 'Medium': sockets.MaxwellMediumSocketDef(),
'Center': sockets.PhysicalPoint3DSocketDef(), 'Center': sockets.PhysicalPoint3DSocketDef(),
'GeoNodes': sockets.BlenderGeoNodesSocketDef(),
} }
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Structure': sockets.MaxwellStructureSocketDef(), 'Structure': sockets.MaxwellStructureSocketDef(),
@ -34,7 +34,7 @@ class GeoNodesStructureNode(base.MaxwellSimNode):
} }
#################### ####################
# - Event Methods # - Output
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Structure', 'Structure',
@ -46,9 +46,10 @@ class GeoNodesStructureNode(base.MaxwellSimNode):
input_sockets: dict, input_sockets: dict,
managed_objs: dict, managed_objs: dict,
) -> td.Structure: ) -> td.Structure:
"""Computes a triangle-mesh based Tidy3D structure, by manually copying mesh data from Blender to a `td.TriangleMesh`."""
# Simulate Input Value Change # Simulate Input Value Change
## This ensures that the mesh has been re-computed. ## This ensures that the mesh has been re-computed.
self.on_input_changed() self.on_input_socket_changed()
## TODO: mesh_as_arrays might not take the Center into account. ## TODO: mesh_as_arrays might not take the Center into account.
## - Alternatively, Tidy3D might have a way to transform? ## - Alternatively, Tidy3D might have a way to transform?
@ -62,96 +63,109 @@ class GeoNodesStructureNode(base.MaxwellSimNode):
) )
#################### ####################
# - Event Methods # - Events: Preview Active Changed
#################### ####################
@events.on_value_changed( @events.on_value_changed(
socket_name={'GeoNodes', 'Center'},
prop_name='preview_active', prop_name='preview_active',
any_loose_input_socket=True,
run_on_init=True,
# Pass Data
props={'preview_active'}, props={'preview_active'},
input_sockets={'Center'},
managed_objs={'mesh'},
)
def on_preview_changed(self, props, input_sockets) -> None:
"""Enables/disables previewing of the GeoNodes-driven mesh, regardless of whether a particular GeoNodes tree is chosen."""
mesh = managed_objs['mesh']
# No Mesh: Create Empty Object
## Ensures that when there is mesh data, it'll be correctly previewed.
## Bit of a workaround - the idea is usually to make the MObj as needed.
if not mesh.exists:
center = input_sockets['Center']
_ = mesh.bl_object(location=center)
# Push Preview State to Managed Mesh
if props['preview_active']:
mesh.show_preview()
else:
mesh.hide_preview()
####################
# - Events: GN Input Changed
####################
@events.on_value_changed(
socket_name={'Center'},
any_loose_input_socket=True,
# Pass Data
managed_objs={'mesh', 'modifier'}, managed_objs={'mesh', 'modifier'},
input_sockets={'Center', 'GeoNodes'}, input_sockets={'Center', 'GeoNodes'},
all_loose_input_sockets=True, all_loose_input_sockets=True,
unit_systems={'BlenderUnits': ct.UNITS_BLENDER}, unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
scale_input_sockets={'Center': 'BlenderUnits'}, scale_input_sockets={'Center': 'BlenderUnits'},
) )
def on_input_socket_changed(
self, input_sockets, loose_input_sockets, unit_systems
) -> None:
"""Pushes any change in GeoNodes-bound input sockets to the GeoNodes modifier.
Also pushes the `Center:Value` socket to govern the object's center in 3D space.
"""
geonodes = input_sockets['GeoNodes']
has_geonodes = not ct.FlowSignal.check(geonodes)
if has_geonodes:
mesh = managed_objs['mesh']
modifier = managed_objs['modifier']
center = input_sockets['Center']
unit_system = unit_systems['BlenderUnits']
# Push Loose Input Values to GeoNodes Modifier
modifier.bl_modifier(
mesh.bl_object(location=center),
'NODES',
{
'node_group': geonodes,
'inputs': loose_input_sockets,
'unit_system': unit_system,
},
)
####################
# - Events: GN Tree Changed
####################
@events.on_value_changed(
socket_name={'GeoNodes'},
# Pass Data
managed_objs={'mesh', 'modifier'},
input_sockets={'GeoNodes', 'Center'},
)
def on_input_changed( def on_input_changed(
self, self,
props: dict,
managed_objs: dict, managed_objs: dict,
input_sockets: dict, input_sockets: dict,
loose_input_sockets: dict,
unit_systems: dict,
) -> None: ) -> None:
# No GeoNodes: Remove Modifier (if any) """Declares new loose input sockets in response to a new GeoNodes tree (if any)."""
if (geonodes := input_sockets['GeoNodes']) is None: geonodes = input_sockets['GeoNodes']
if ( has_geonodes = not ct.FlowSignal.check(geonodes)
managed_objs['modifier'].name
in managed_objs['mesh'].bl_object().modifiers.keys().copy()
):
managed_objs['modifier'].free_from_bl_object(
managed_objs['mesh'].bl_object()
)
# Reset Loose Input Sockets if has_geonodes:
self.loose_input_sockets = {} mesh = managed_objs['mesh']
return modifier = managed_objs['modifier']
# No Loose Input Sockets: Create from GeoNodes Interface
## TODO: Other reasons to trigger re-filling loose_input_sockets.
if not loose_input_sockets:
# Retrieve the GeoNodes Interface
geonodes_interface = analyze_geonodes.interface(
input_sockets['GeoNodes'], direc='INPUT'
)
# Fill the Loose Input Sockets # Fill the Loose Input Sockets
## -> The SocketDefs contain the default values from the interface.
log.info( log.info(
'Initializing GeoNodes Structure Node "%s" from GeoNodes Group "%s"', 'Initializing GeoNodes Structure Node "%s" from GeoNodes Group "%s"',
self.bl_label, self.bl_label,
str(geonodes), str(geonodes),
) )
self.loose_input_sockets = { self.loose_input_sockets = bl_socket_map.sockets_from_geonodes(geonodes)
socket_name: bl_socket_map.socket_def_from_bl_socket(iface_socket)()
for socket_name, iface_socket in geonodes_interface.items()
}
# Set Loose Input Sockets to Interface (Default) Values ## -> The loose socket creation triggers 'on_input_socket_changed'
## Changing socket.value invokes recursion of this function.
## The else: below ensures that only one push occurs. elif self.loose_input_sockets:
## (well, one push per .value set, which simplifies to one push) self.loose_input_sockets = {}
log.info(
'Setting Loose Input Sockets of "%s" to GeoNodes Defaults', if modifier.name in mesh.bl_object().modifiers.keys().copy():
self.bl_label, modifier.free_from_bl_object(mesh.bl_object())
)
for socket_name in self.loose_input_sockets:
socket = self.inputs[socket_name]
socket.value = bl_socket_map.read_bl_socket_default_value(
geonodes_interface[socket_name],
unit_systems['BlenderUnits'],
allow_unit_not_in_unit_system=True,
)
log.info(
'Set Loose Input Sockets of "%s" to: %s',
self.bl_label,
str(self.loose_input_sockets),
)
else:
# Push Loose Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
managed_objs['mesh'].bl_object(location=input_sockets['Center']),
'NODES',
{
'node_group': input_sockets['GeoNodes'],
'unit_system': unit_systems['BlenderUnits'],
'inputs': loose_input_sockets,
},
)
# Push Preview State
if props['preview_active']:
managed_objs['mesh'].show_preview()
#################### ####################

View File

@ -1,11 +1,11 @@
from blender_maxwell.utils import logger from blender_maxwell.utils import logger
from .. import contracts as ct from .. import contracts as ct
from . import basic, blender, maxwell, physical, tidy3d from . import basic, blender, expr, maxwell, physical, tidy3d
from .scan_socket_defs import scan_for_socket_defs from .scan_socket_defs import scan_for_socket_defs
log = logger.get(__name__) log = logger.get(__name__)
sockets_modules = [basic, physical, blender, maxwell, tidy3d] sockets_modules = [basic, blender, expr, maxwell, physical, tidy3d]
#################### ####################
# - Scan for SocketDefs # - Scan for SocketDefs
@ -28,21 +28,24 @@ for socket_type in ct.SocketType:
): ):
log.warning('Missing SocketDef for %s', socket_type.value) log.warning('Missing SocketDef for %s', socket_type.value)
#################### ####################
# - Exports # - Exports
#################### ####################
BL_REGISTER = [ BL_REGISTER = [
*basic.BL_REGISTER, *basic.BL_REGISTER,
*physical.BL_REGISTER,
*blender.BL_REGISTER, *blender.BL_REGISTER,
*expr.BL_REGISTER,
*maxwell.BL_REGISTER, *maxwell.BL_REGISTER,
*physical.BL_REGISTER,
*tidy3d.BL_REGISTER, *tidy3d.BL_REGISTER,
] ]
__all__ = [ __all__ = [
'basic', 'basic',
'physical',
'blender', 'blender',
'expr',
'maxwell', 'maxwell',
'physical',
'tidy3d', 'tidy3d',
] + [socket_def_type.__name__ for socket_def_type in SOCKET_DEFS.values()] ] + [socket_def_type.__name__ for socket_def_type in SOCKET_DEFS.values()]

View File

@ -5,10 +5,8 @@ from types import MappingProxyType
import bpy import bpy
import pydantic as pyd import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import bl_cache, logger, serialize from blender_maxwell.utils import bl_cache, logger, serialize
from blender_maxwell.utils import extra_sympy_units as spux
from .. import contracts as ct from .. import contracts as ct
@ -126,7 +124,6 @@ class MaxwellSimSocket(bpy.types.NodeSocket):
socket_color: tuple socket_color: tuple
# Options # Options
use_units: bool = False
use_prelock: bool = False use_prelock: bool = False
use_info_draw: bool = False use_info_draw: bool = False
@ -210,35 +207,6 @@ class MaxwellSimSocket(bpy.types.NodeSocket):
'active_kind', bpy.props.StringProperty, default=str(ct.FlowKind.Value) 'active_kind', bpy.props.StringProperty, default=str(ct.FlowKind.Value)
) )
# Configure Use of Units
if cls.use_units:
if not (socket_units := ct.SOCKET_UNITS.get(cls.socket_type)):
msg = f'{cls.socket_type}: Tried to define "use_units", but there is no unit for {cls.socket_type} defined in "contracts.SOCKET_UNITS"'
raise RuntimeError(msg)
cls.set_prop(
'active_unit',
bpy.props.EnumProperty,
name='Unit',
items=[
(unit_name, spux.sp_to_str(unit_value), sp.srepr(unit_value))
for unit_name, unit_value in socket_units['values'].items()
],
default=socket_units['default'],
)
cls.set_prop(
'prev_active_unit',
bpy.props.StringProperty,
default=socket_units['default'],
)
####################
# - Units
####################
@property
def prev_unit(self) -> sp.Expr:
return self.possible_units[self.prev_active_unit]
#################### ####################
# - Property Event: On Update # - Property Event: On Update
#################### ####################
@ -250,36 +218,47 @@ class MaxwellSimSocket(bpy.types.NodeSocket):
""" """
self.display_shape = ( self.display_shape = (
'SQUARE' if self.active_kind == ct.FlowKind.LazyValueRange else 'CIRCLE' 'SQUARE' if self.active_kind == ct.FlowKind.LazyValueRange else 'CIRCLE'
) + ('_DOT' if self.use_units else '') ) # + ('_DOT' if self.use_units else '')
## TODO: Valid Active Kinds should be a subset/subenum(?) of FlowKind ## TODO: Valid Active Kinds should be a subset/subenum(?) of FlowKind
def on_socket_prop_changed(self, prop_name: str) -> None:
"""Called when a property has been updated.
Notes:
Can be overridden if a socket needs to respond to a property change.
**Always prefer using node events when possible**.
Think very carefully before using this, and use it with the greatest of care.
Attributes:
prop_name: The name of the property that was changed.
"""
def on_prop_changed(self, prop_name: str, _: bpy.types.Context) -> None: def on_prop_changed(self, prop_name: str, _: bpy.types.Context) -> None:
"""Called when a property has been updated. """Called when a property has been updated.
Contrary to `node.on_prop_changed()`, socket-specific callbacks are baked into this function: Contrary to `node.on_prop_changed()`, socket-specific callbacks are baked into this function:
- **Active Kind** (`self.active_kind`): Sets the socket shape to reflect the active `FlowKind`. - **Active Kind** (`self.active_kind`): Sets the socket shape to reflect the active `FlowKind`.
- **Unit** (`self.unit`): Corrects the internal `FlowKind` representation to match the new unit.
Attributes: Attributes:
prop_name: The name of the property that was changed. prop_name: The name of the property that was changed.
""" """
# Property: Active Kind if hasattr(self, prop_name):
if prop_name == 'active_kind':
self._on_active_kind_changed()
elif prop_name == 'unit':
self._on_unit_changed()
# Valid Properties
elif hasattr(self, prop_name):
# Invalidate UI BLField Caches # Invalidate UI BLField Caches
if prop_name in self.ui_blfields: if prop_name in self.ui_blfields:
setattr(self, prop_name, bl_cache.Signal.InvalidateCache) setattr(self, prop_name, bl_cache.Signal.InvalidateCache)
# Property Callbacks: Active Kind
if prop_name == 'active_kind':
self._on_active_kind_changed()
# Property Callbacks: Per-Socket
self.on_socket_prop_changed(prop_name)
# Trigger Event # Trigger Event
self.trigger_event(ct.FlowEvent.DataChanged) self.trigger_event(ct.FlowEvent.DataChanged)
# Undefined Properties
else: else:
msg = f'Property {prop_name} not defined on socket {self.bl_label} ({self.socket_type})' msg = f'Property {prop_name} not defined on socket {self.bl_label} ({self.socket_type})'
raise RuntimeError(msg) raise RuntimeError(msg)
@ -760,7 +739,6 @@ class MaxwellSimSocket(bpy.types.NodeSocket):
- **Locked** (`self.locked`): The UI will be unusable. - **Locked** (`self.locked`): The UI will be unusable.
- **Linked** (`self.is_linked`): Only the socket label will display. - **Linked** (`self.is_linked`): Only the socket label will display.
- **Use Units** (`self.use_units`): The currently active unit will display as a dropdown menu.
- **Use Prelock** (`self.use_prelock`): The "prelock" UI drawn with `self.draw_prelock()`, which shows **regardless of `self.locked`**. - **Use Prelock** (`self.use_prelock`): The "prelock" UI drawn with `self.draw_prelock()`, which shows **regardless of `self.locked`**.
- **FlowKind**: The `FlowKind`-specific UI corresponding to the current `self.active_kind`. - **FlowKind**: The `FlowKind`-specific UI corresponding to the current `self.active_kind`.
@ -786,16 +764,6 @@ class MaxwellSimSocket(bpy.types.NodeSocket):
## Link Check ## Link Check
if self.is_linked: if self.is_linked:
self.draw_input_label_row(row, text) self.draw_input_label_row(row, text)
else:
# User Label Row (incl. Units)
if self.use_units:
split = row.split(factor=0.6, align=True)
_row = split.row(align=True)
self.draw_label_row(_row, text)
_col = split.column(align=True)
_col.prop(self, 'active_unit', text='')
else: else:
self.draw_label_row(row, text) self.draw_label_row(row, text)

View File

@ -1,20 +1,16 @@
from . import any as any_socket from . import any as any_socket
from . import bool as bool_socket from . import bool as bool_socket
from . import expr, file_path, string, data from . import file_path, string
AnySocketDef = any_socket.AnySocketDef AnySocketDef = any_socket.AnySocketDef
DataSocketDef = data.DataSocketDef
BoolSocketDef = bool_socket.BoolSocketDef BoolSocketDef = bool_socket.BoolSocketDef
StringSocketDef = string.StringSocketDef StringSocketDef = string.StringSocketDef
FilePathSocketDef = file_path.FilePathSocketDef FilePathSocketDef = file_path.FilePathSocketDef
ExprSocketDef = expr.ExprSocketDef
BL_REGISTER = [ BL_REGISTER = [
*any_socket.BL_REGISTER, *any_socket.BL_REGISTER,
*data.BL_REGISTER,
*bool_socket.BL_REGISTER, *bool_socket.BL_REGISTER,
*string.BL_REGISTER, *string.BL_REGISTER,
*file_path.BL_REGISTER, *file_path.BL_REGISTER,
*expr.BL_REGISTER,
] ]

View File

@ -70,8 +70,8 @@ class BlenderGeoNodesBLSocket(base.MaxwellSimSocket):
# - Default Value # - Default Value
#################### ####################
@property @property
def value(self) -> bpy.types.NodeTree | None: def value(self) -> bpy.types.NodeTree | ct.FlowSignal:
return self.raw_value return self.raw_value if self.raw_value is not None else ct.FlowSignal.NoFlow
@value.setter @value.setter
def value(self, value: bpy.types.NodeTree) -> None: def value(self, value: bpy.types.NodeTree) -> None:

View File

@ -7,16 +7,14 @@ import sympy as sp
from blender_maxwell.utils import bl_cache, logger from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from ... import contracts as ct from .. import contracts as ct
from .. import base from . import base
## TODO: This is a big node, and there's a lot to get right. ## TODO: This is a big node, and there's a lot to get right.
## - Dynamically adjust the value when the user changes the unit in the UI.
## - Dynamically adjust socket color in response to, especially, the unit dimension. ## - Dynamically adjust socket color in response to, especially, the unit dimension.
## - Iron out the meaning of display shapes. ## - Iron out the meaning of display shapes.
## - Generally pay attention to validity checking; it's make or break. ## - Generally pay attention to validity checking; it's make or break.
## - For array generation, it may pay to have both a symbolic expression (producing output according to `size` as usual) denoting how to actually make values, and how many. Enables ex. easy symbolic ## - For array generation, it may pay to have both a symbolic expression (producing output according to `shape` as usual) denoting how to actually make values, and how many. Enables ex. easy symbolic plots.
## - For array generation, it may pay to have both a symbolic expression (producing output according to `size` as usual)
log = logger.get(__name__) log = logger.get(__name__)
@ -69,24 +67,15 @@ class ExprBLSocket(base.MaxwellSimSocket):
#################### ####################
# - Properties # - Properties
#################### ####################
size: typ.Literal[None, 2, 3] = bl_cache.BLField(None, prop_ui=True) shape: tuple[int, ...] | None = bl_cache.BLField(None)
mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real, prop_ui=True) mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real, prop_ui=True)
physical_type: spux.PhysicalType | None = bl_cache.BLField(None)
symbols: frozenset[spux.Symbol] = bl_cache.BLField(frozenset()) symbols: frozenset[spux.Symbol] = bl_cache.BLField(frozenset())
## Units
unit_dim: spux.UnitDimension | None = bl_cache.BLField(None)
active_unit: enum.Enum = bl_cache.BLField( active_unit: enum.Enum = bl_cache.BLField(
None, enum_cb=lambda self, _: self.search_units(), prop_ui=True None, enum_cb=lambda self, _: self.search_units(), prop_ui=True
) )
## Info Display
show_info_columns: bool = bl_cache.BLField(False, prop_ui=True)
info_columns: InfoDisplayCol = bl_cache.BLField(
{InfoDisplayCol.MathType, InfoDisplayCol.Unit},
prop_ui=True,
enum_many=True,
)
# UI: Value # UI: Value
## Expression ## Expression
raw_value_spstr: str = bl_cache.BLField('', prop_ui=True) raw_value_spstr: str = bl_cache.BLField('', prop_ui=True)
@ -94,7 +83,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
raw_value_int: int = bl_cache.BLField(0, prop_ui=True) raw_value_int: int = bl_cache.BLField(0, prop_ui=True)
raw_value_rat: Int2 = bl_cache.BLField((0, 1), prop_ui=True) raw_value_rat: Int2 = bl_cache.BLField((0, 1), prop_ui=True)
raw_value_float: float = bl_cache.BLField(0.0, float_prec=4, prop_ui=True) raw_value_float: float = bl_cache.BLField(0.0, float_prec=4, prop_ui=True)
raw_value_complex: Float2 = bl_cache.BLField((0, 1), float_prec=4, prop_ui=True) raw_value_complex: Float2 = bl_cache.BLField((0.0, 0.0), float_prec=4, prop_ui=True)
## 2D ## 2D
raw_value_int2: Int2 = bl_cache.BLField((0, 0), prop_ui=True) raw_value_int2: Int2 = bl_cache.BLField((0, 0), prop_ui=True)
raw_value_rat2: Int22 = bl_cache.BLField(((0, 1), (0, 1)), prop_ui=True) raw_value_rat2: Int22 = bl_cache.BLField(((0, 1), (0, 1)), prop_ui=True)
@ -105,7 +94,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
## 3D ## 3D
raw_value_int3: Int3 = bl_cache.BLField((0, 0, 0), prop_ui=True) raw_value_int3: Int3 = bl_cache.BLField((0, 0, 0), prop_ui=True)
raw_value_rat3: Int32 = bl_cache.BLField(((0, 1), (0, 1), (0, 1)), prop_ui=True) raw_value_rat3: Int32 = bl_cache.BLField(((0, 1), (0, 1), (0, 1)), prop_ui=True)
raw_value_float3: Float3 = bl_cache.BLField((0.0, 0.0), float_prec=4, prop_ui=True) raw_value_float3: Float3 = bl_cache.BLField(
(0.0, 0.0, 0.0), float_prec=4, prop_ui=True
)
raw_value_complex3: Float32 = bl_cache.BLField( raw_value_complex3: Float32 = bl_cache.BLField(
((0.0, 0.0), (0.0, 0.0), (0.0, 0.0)), float_prec=4, prop_ui=True ((0.0, 0.0), (0.0, 0.0), (0.0, 0.0)), float_prec=4, prop_ui=True
) )
@ -123,6 +114,14 @@ class ExprBLSocket(base.MaxwellSimSocket):
((0.0, 0.0), (1.0, 1.0)), float_prec=4, prop_ui=True ((0.0, 0.0), (1.0, 1.0)), float_prec=4, prop_ui=True
) )
# UI: Info
show_info_columns: bool = bl_cache.BLField(False, prop_ui=True)
info_columns: InfoDisplayCol = bl_cache.BLField(
{InfoDisplayCol.MathType, InfoDisplayCol.Unit},
prop_ui=True,
enum_many=True,
)
#################### ####################
# - Computed: Raw Expressions # - Computed: Raw Expressions
#################### ####################
@ -142,35 +141,63 @@ class ExprBLSocket(base.MaxwellSimSocket):
# - Computed: Units # - Computed: Units
#################### ####################
def search_units(self, _: bpy.types.Context) -> list[ct.BLEnumElement]: def search_units(self, _: bpy.types.Context) -> list[ct.BLEnumElement]:
if self.unit_dim is not None: if self.physical_type is not None:
return [ return [
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i) (sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
for i, unit in enumerate(spux.unit_dim_units(self.unit_dim)) for i, unit in enumerate(self.physical_type.valid_units)
] ]
return [] return []
@property @bl_cache.cached_bl_property()
def unit(self) -> spux.Unit | None: def unit(self) -> spux.Unit | None:
if self.active_unit != 'NONE': """Gets the current active unit.
Returns:
The current active `sympy` unit.
If the socket expression is unitless, this returns `None`.
"""
if self.active_unit is not None:
return spux.unit_str_to_unit(self.active_unit) return spux.unit_str_to_unit(self.active_unit)
return None return None
@unit.setter @unit.setter
def unit(self, unit: spux.Unit) -> None: def unit(self, unit: spux.Unit) -> None:
valid_units = spux.unit_dim_units(self.unit_dim) """Set the unit, without touching the `raw_*` UI properties.
if unit in valid_units:
Notes:
To set a new unit, **and** convert the `raw_*` UI properties to the new unit, use `self.convert_unit()` instead.
"""
if unit in self.physical_type.valid_units:
self.active_unit = sp.sstr(unit) self.active_unit = sp.sstr(unit)
msg = f'Tried to set invalid unit {unit} (unit dim "{self.unit_dim}" only supports "{valid_units}")' msg = f'Tried to set invalid unit {unit} (physical type "{self.physical_type}" only supports "{self.physical_type.valid_units}")'
raise ValueError(msg) raise ValueError(msg)
def convert_unit(self, unit_to: spux.Unit) -> None:
if self.active_kind == ct.FlowKind.Value:
current_value = self.value
self.unit = unit_to
self.value = current_value
elif self.active_kind == ct.FlowKind.LazyArrayRange:
current_lazy_array_range = self.lazy_array_range
self.unit = unit_to
self.lazy_array_range = current_lazy_array_range
####################
# - Property Callback
####################
def on_socket_prop_changed(self, prop_name: str) -> None:
if prop_name == 'unit' and self.active_unit is not None:
self.convert_unit(spux.unit_str_to_unit(self.active_unit))
#################### ####################
# - Methods # - Methods
#################### ####################
def _parse_expr_info( def _parse_expr_info(
self, expr: spux.SympyExpr self, expr: spux.SympyExpr
) -> tuple[spux.MathType, typ.Literal[None, 2, 3], spux.UnitDimension]: ) -> tuple[spux.MathType, tuple[int, ...] | None, spux.UnitDimension]:
# Parse MathType # Parse MathType
mathtype = spux.MathType.from_expr(expr) mathtype = spux.MathType.from_expr(expr)
if self.mathtype != mathtype: if self.mathtype != mathtype:
@ -188,18 +215,12 @@ class ExprBLSocket(base.MaxwellSimSocket):
raise ValueError(msg) raise ValueError(msg)
# Parse Dimensions # Parse Dimensions
size = spux.parse_size(expr) shape = spux.parse_shape(expr)
if size != self.size: if shape != self.shape:
msg = f'Expr {expr} has {size} dimensions, which is incompatible with the expr socket ({self.size} dimensions)' msg = f'Expr {expr} has shape {shape}, which is incompatible with the expr socket (shape {self.shape})'
raise ValueError(msg) raise ValueError(msg)
# Parse Unit Dimension return mathtype, shape
unit_dim = spux.parse_unit_dim(expr)
if unit_dim != self.unit_dim:
msg = f'Expr {expr} has unit dimension {unit_dim}, which is incompatible with socket unit dimension {self.unit_dim}'
raise ValueError(msg)
return mathtype, size, unit_dim
def _to_raw_value(self, expr: spux.SympyExpr): def _to_raw_value(self, expr: spux.SympyExpr):
if self.unit is not None: if self.unit is not None:
@ -212,7 +233,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
locals={sym.name: sym for sym in self.symbols}, locals={sym.name: sym for sym in self.symbols},
strict=False, strict=False,
convert_xor=True, convert_xor=True,
).subs(spux.ALL_UNIT_SYMBOLS) * (self.unit if self.unit is not None else 1) ).subs(spux.UNIT_BY_SYMBOL) * (self.unit if self.unit is not None else 1)
# Try Parsing and Returning the Expression # Try Parsing and Returning the Expression
try: try:
@ -245,7 +266,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
Return: Return:
The expression defined by the socket, in the socket's unit. The expression defined by the socket, in the socket's unit.
""" """
if self.symbols: if self.symbols or self.shape not in [None, (2,), (3,)]:
expr = self.raw_value_sp expr = self.raw_value_sp
if expr is None: if expr is None:
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
@ -266,7 +287,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
self.raw_value_complex[0] + sp.I * self.raw_value_complex[1] self.raw_value_complex[0] + sp.I * self.raw_value_complex[1]
), ),
}, },
2: { (2,): {
MT_Z: lambda: sp.Matrix([Z(i) for i in self.raw_value_int2]), MT_Z: lambda: sp.Matrix([Z(i) for i in self.raw_value_int2]),
MT_Q: lambda: sp.Matrix([Q(q[0], q[1]) for q in self.raw_value_rat2]), MT_Q: lambda: sp.Matrix([Q(q[0], q[1]) for q in self.raw_value_rat2]),
MT_R: lambda: sp.Matrix([R(r) for r in self.raw_value_float2]), MT_R: lambda: sp.Matrix([R(r) for r in self.raw_value_float2]),
@ -274,7 +295,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
[c[0] + sp.I * c[1] for c in self.raw_value_complex2] [c[0] + sp.I * c[1] for c in self.raw_value_complex2]
), ),
}, },
3: { (3,): {
MT_Z: lambda: sp.Matrix([Z(i) for i in self.raw_value_int3]), MT_Z: lambda: sp.Matrix([Z(i) for i in self.raw_value_int3]),
MT_Q: lambda: sp.Matrix([Q(q[0], q[1]) for q in self.raw_value_rat3]), MT_Q: lambda: sp.Matrix([Q(q[0], q[1]) for q in self.raw_value_rat3]),
MT_R: lambda: sp.Matrix([R(r) for r in self.raw_value_float3]), MT_R: lambda: sp.Matrix([R(r) for r in self.raw_value_float3]),
@ -282,7 +303,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
[c[0] + sp.I * c[1] for c in self.raw_value_complex3] [c[0] + sp.I * c[1] for c in self.raw_value_complex3]
), ),
}, },
}[self.size][self.mathtype]() * (self.unit if self.unit is not None else 1) }[self.shape][self.mathtype]() * (self.unit if self.unit is not None else 1)
@value.setter @value.setter
def value(self, expr: spux.SympyExpr) -> None: def value(self, expr: spux.SympyExpr) -> None:
@ -291,8 +312,8 @@ class ExprBLSocket(base.MaxwellSimSocket):
Notes: Notes:
Called to set the internal `FlowKind.Value` of this socket. Called to set the internal `FlowKind.Value` of this socket.
""" """
mathtype, size, unit_dim = self._parse_expr_info(expr) mathtype, shape = self._parse_expr_info(expr)
if self.symbols: if self.symbols or self.shape not in [None, (2,), (3,)]:
self.raw_value_spstr = sp.sstr(expr) self.raw_value_spstr = sp.sstr(expr)
else: else:
@ -300,7 +321,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
MT_Q = spux.MathType.Rational MT_Q = spux.MathType.Rational
MT_R = spux.MathType.Real MT_R = spux.MathType.Real
MT_C = spux.MathType.Complex MT_C = spux.MathType.Complex
if size is None: if shape is None:
if mathtype == MT_Z: if mathtype == MT_Z:
self.raw_value_int = self._to_raw_value(expr) self.raw_value_int = self._to_raw_value(expr)
elif mathtype == MT_Q: elif mathtype == MT_Q:
@ -309,7 +330,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
self.raw_value_float = self._to_raw_value(expr) self.raw_value_float = self._to_raw_value(expr)
elif mathtype == MT_C: elif mathtype == MT_C:
self.raw_value_complex = self._to_raw_value(expr) self.raw_value_complex = self._to_raw_value(expr)
elif size == 2: elif shape == (2,):
if mathtype == MT_Z: if mathtype == MT_Z:
self.raw_value_int2 = self._to_raw_value(expr) self.raw_value_int2 = self._to_raw_value(expr)
elif mathtype == MT_Q: elif mathtype == MT_Q:
@ -318,7 +339,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
self.raw_value_float2 = self._to_raw_value(expr) self.raw_value_float2 = self._to_raw_value(expr)
elif mathtype == MT_C: elif mathtype == MT_C:
self.raw_value_complex2 = self._to_raw_value(expr) self.raw_value_complex2 = self._to_raw_value(expr)
elif size == 3: elif shape == (3,):
if mathtype == MT_Z: if mathtype == MT_Z:
self.raw_value_int3 = self._to_raw_value(expr) self.raw_value_int3 = self._to_raw_value(expr)
elif mathtype == MT_Q: elif mathtype == MT_Q:
@ -341,9 +362,35 @@ class ExprBLSocket(base.MaxwellSimSocket):
Return: Return:
The range of lengths, which uses no symbols. The range of lengths, which uses no symbols.
""" """
if self.symbols:
return ct.LazyArrayRangeFlow( return ct.LazyArrayRangeFlow(
start=sp.S(self.min_value) * self.unit, start=self.raw_min_sp,
stop=sp.S(self.max_value) * self.unit, stop=self.raw_max_sp,
steps=self.steps,
scaling='lin',
unit=self.unit,
symbols=self.symbols,
)
MT_Z = spux.MathType.Integer
MT_Q = spux.MathType.Rational
MT_R = spux.MathType.Real
MT_C = spux.MathType.Complex
Z = sp.Integer
Q = sp.Rational
R = sp.RealNumber
min_bound, max_bound = {
MT_Z: lambda: [Z(bound) for bound in self.raw_range_int],
MT_Q: lambda: [Q(bound[0], bound[1]) for bound in self.raw_range_rat],
MT_R: lambda: [R(bound) for bound in self.raw_range_float],
MT_C: lambda: [
bound[0] + sp.I * bound[1] for bound in self.raw_range_complex
],
}[self.mathtype]()
return ct.LazyArrayRangeFlow(
start=min_bound,
stop=max_bound,
steps=self.steps, steps=self.steps,
scaling='lin', scaling='lin',
unit=self.unit, unit=self.unit,
@ -356,25 +403,74 @@ class ExprBLSocket(base.MaxwellSimSocket):
Notes: Notes:
Called to compute the internal `FlowKind.LazyArrayRange` of this socket. Called to compute the internal `FlowKind.LazyArrayRange` of this socket.
""" """
self.min_value = spux.sympy_to_python(
spux.scale_to_unit(value.start * value.unit, self.unit)
)
self.max_value = spux.sympy_to_python(
spux.scale_to_unit(value.stop * value.unit, self.unit)
)
self.steps = value.steps self.steps = value.steps
self.unit = value.unit
if self.symbols:
self.raw_min_spstr = sp.sstr(value.start)
self.raw_max_spstr = sp.sstr(value.stop)
else:
MT_Z = spux.MathType.Integer
MT_Q = spux.MathType.Rational
MT_R = spux.MathType.Real
MT_C = spux.MathType.Complex
if value.mathtype == MT_Z:
self.raw_range_int = [
self._to_raw_value(bound) for bound in [value.start, value.stop]
]
elif value.mathtype == MT_Q:
self.raw_range_rat = [
self._to_raw_value(bound) for bound in [value.start, value.stop]
]
elif value.mathtype == MT_R:
self.raw_range_float = [
self._to_raw_value(bound) for bound in [value.start, value.stop]
]
elif value.mathtype == MT_C:
self.raw_range_complex = [
self._to_raw_value(bound) for bound in [value.start, value.stop]
]
#################### ####################
# - FlowKind: LazyValueFunc # - FlowKind: LazyValueFunc (w/Params if Constant)
#################### ####################
@property @property
def lazy_value_func(self) -> ct.LazyValueFuncFlow: def lazy_value_func(self) -> ct.LazyValueFuncFlow:
# Lazy Value: Arbitrary Expression
if self.symbols or self.shape not in [None, (2,), (3,)]:
return ct.LazyValueFuncFlow( return ct.LazyValueFuncFlow(
func=sp.lambdify(self.symbols, self.value, 'jax'), func=sp.lambdify(self.symbols, self.value, 'jax'),
func_args=[spux.sympy_to_python_type(sym) for sym in self.symbols], func_args=[spux.MathType.from_expr(sym) for sym in self.symbols],
supports_jax=True, supports_jax=True,
) )
# Lazy Value: Constant
## -> A very simple function, which takes a single argument.
## -> What will be passed is a unit-scaled/stripped, pytype-converted Expr:Value.
## -> Until then, the user can utilize this LVF in a function composition chain.
return ct.LazyValueFuncFlow(
func=lambda v: v,
func_args=[
self.physical_type if self.physical_type is not None else self.mathtype
],
supports_jax=True,
)
@property
def params(self) -> ct.ParamsFlow:
# Params Value: Symbolic
## -> The Expr socket does not declare actual values for the symbols.
## -> Those values must come from elsewhere.
## -> If someone tries to load them anyway, tell them 'NoFlow'.
if self.symbols or self.shape not in [None, (2,), (3,)]:
return ct.FlowSignal.NoFlow
# Params Value: Constant
## -> Simply pass the Expr:Value as parameter.
return ct.ParamsFlow(func_args=[self.value])
#################### ####################
# - FlowKind: Array # - FlowKind: Array
#################### ####################
@ -396,7 +492,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
def info(self) -> ct.ArrayFlow: def info(self) -> ct.ArrayFlow:
return ct.InfoFlow( return ct.InfoFlow(
output_name='_', ## TODO: Something else output_name='_', ## TODO: Something else
output_shape=(self.size,) if self.size is not None else None, output_shape=self.shape,
output_mathtype=self.mathtype, output_mathtype=self.mathtype,
output_unit=self.unit, output_unit=self.unit,
) )
@ -416,6 +512,16 @@ class ExprBLSocket(base.MaxwellSimSocket):
#################### ####################
# - UI # - UI
#################### ####################
def draw_label_row(self, row: bpy.types.UILayout, text) -> None:
if self.active_unit is not None:
split = row.split(factor=0.6, align=True)
_row = split.row(align=True)
_row.label(text=text)
_col = split.column(align=True)
_col.prop(self, 'active_unit', text='')
def draw_value(self, col: bpy.types.UILayout) -> None: def draw_value(self, col: bpy.types.UILayout) -> None:
# Property Interface # Property Interface
if self.symbols: if self.symbols:
@ -426,7 +532,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
MT_Q = spux.MathType.Rational MT_Q = spux.MathType.Rational
MT_R = spux.MathType.Real MT_R = spux.MathType.Real
MT_C = spux.MathType.Complex MT_C = spux.MathType.Complex
if self.size is None: if self.shape is None:
if self.mathtype == MT_Z: if self.mathtype == MT_Z:
col.prop(self, self.blfields['raw_value_int'], text='') col.prop(self, self.blfields['raw_value_int'], text='')
elif self.mathtype == MT_Q: elif self.mathtype == MT_Q:
@ -435,7 +541,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
col.prop(self, self.blfields['raw_value_float'], text='') col.prop(self, self.blfields['raw_value_float'], text='')
elif self.mathtype == MT_C: elif self.mathtype == MT_C:
col.prop(self, self.blfields['raw_value_complex'], text='') col.prop(self, self.blfields['raw_value_complex'], text='')
elif self.size == 2: elif self.shape == (2,):
if self.mathtype == MT_Z: if self.mathtype == MT_Z:
col.prop(self, self.blfields['raw_value_int2'], text='') col.prop(self, self.blfields['raw_value_int2'], text='')
elif self.mathtype == MT_Q: elif self.mathtype == MT_Q:
@ -444,7 +550,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
col.prop(self, self.blfields['raw_value_float2'], text='') col.prop(self, self.blfields['raw_value_float2'], text='')
elif self.mathtype == MT_C: elif self.mathtype == MT_C:
col.prop(self, self.blfields['raw_value_complex2'], text='') col.prop(self, self.blfields['raw_value_complex2'], text='')
elif self.size == 3: elif self.shape == (3,):
if self.mathtype == MT_Z: if self.mathtype == MT_Z:
col.prop(self, self.blfields['raw_value_int3'], text='') col.prop(self, self.blfields['raw_value_int3'], text='')
elif self.mathtype == MT_Q: elif self.mathtype == MT_Q:
@ -579,34 +685,57 @@ class ExprSocketDef(base.SocketDef):
ct.FlowKind.Value ct.FlowKind.Value
) )
# Properties # Socket Interface
size: typ.Literal[None, 2, 3] = None ## TODO: __hash__ like socket method based on these?
shape: tuple[int, ...] | None = None
mathtype: spux.MathType = spux.MathType.Real mathtype: spux.MathType = spux.MathType.Real
physical_type: spux.PhysicalType | None = None
symbols: frozenset[spux.Symbol] = frozenset() symbols: frozenset[spux.Symbol] = frozenset()
## Units
unit_dim: spux.UnitDimension | None = None # Socket Units
## Info Display default_unit: spux.Unit | None = None
show_info_columns: bool = False
# FlowKind: Value
default_value: spux.SympyExpr = sp.S(0)
# FlowKind: LazyArrayRange
default_min: spux.SympyExpr = sp.S(0)
default_max: spux.SympyExpr = sp.S(1)
default_steps: int = 2
## TODO: Configure lin/log/... scaling (w/enumprop in UI)
## TODO: Buncha validation :) ## TODO: Buncha validation :)
# Defaults # UI
default_unit: spux.Unit | None = None show_info_columns: bool = False
default_value: spux.SympyExpr = sp.S(1)
default_min: spux.SympyExpr = sp.S(0)
default_max: spux.SympyExpr = sp.S(1)
default_steps: spux.SympyExpr = sp.S(2)
def init(self, bl_socket: ExprBLSocket) -> None: def init(self, bl_socket: ExprBLSocket) -> None:
bl_socket.active_kind = self.active_kind bl_socket.active_kind = self.active_kind
bl_socket.size = self.size
bl_socket.mathtype = self.size
bl_socket.symbols = self.symbols
bl_socket.unit_dim = self.size
bl_socket.unit = self.symbols
bl_socket.show_info_columns = self.show_info_columns
bl_socket.value = self.default # Socket Interface
bl_socket.shape = self.shape
bl_socket.mathtype = self.mathtype
bl_socket.physical_type = self.physical_type
bl_socket.symbols = self.symbols
# Socket Units
if self.default_unit is not None:
bl_socket.unit = self.default_unit
# FlowKind: Value
bl_socket.value = self.default_value
# FlowKind: LazyArrayRange
bl_socket.lazy_array_range = ct.LazyArrayRangeFlow(
start=self.default_min,
stop=self.default_max,
steps=self.default_steps,
scaling='lin',
unit=self.default_unit,
)
# UI
bl_socket.show_info_columns = self.show_info_columns
#################### ####################

View File

@ -1,49 +1,51 @@
from . import bound_cond, bound_conds from . import (
bound_cond,
bound_conds,
fdtd_sim,
fdtd_sim_data,
medium,
medium_non_linearity,
monitor,
monitor_data,
sim_domain,
sim_grid,
sim_grid_axis,
source,
structure,
temporal_shape,
)
MaxwellBoundCondSocketDef = bound_cond.MaxwellBoundCondSocketDef MaxwellBoundCondSocketDef = bound_cond.MaxwellBoundCondSocketDef
MaxwellBoundCondsSocketDef = bound_conds.MaxwellBoundCondsSocketDef MaxwellBoundCondsSocketDef = bound_conds.MaxwellBoundCondsSocketDef
MaxwellFDTDSimSocketDef = fdtd_sim.MaxwellFDTDSimSocketDef
from . import medium, medium_non_linearity MaxwellFDTDSimDataSocketDef = fdtd_sim_data.MaxwellFDTDSimDataSocketDef
MaxwellMediumSocketDef = medium.MaxwellMediumSocketDef MaxwellMediumSocketDef = medium.MaxwellMediumSocketDef
MaxwellMediumNonLinearitySocketDef = ( MaxwellMediumNonLinearitySocketDef = (
medium_non_linearity.MaxwellMediumNonLinearitySocketDef medium_non_linearity.MaxwellMediumNonLinearitySocketDef
) )
from . import source, temporal_shape
MaxwellSourceSocketDef = source.MaxwellSourceSocketDef
MaxwellTemporalShapeSocketDef = temporal_shape.MaxwellTemporalShapeSocketDef
from . import structure
MaxwellStructureSocketDef = structure.MaxwellStructureSocketDef
from . import monitor
MaxwellMonitorSocketDef = monitor.MaxwellMonitorSocketDef MaxwellMonitorSocketDef = monitor.MaxwellMonitorSocketDef
MaxwellMonitorDataSocketDef = monitor_data.MaxwellMonitorDataSocketDef
from . import fdtd_sim, fdtd_sim_data, sim_domain, sim_grid, sim_grid_axis MaxwellSimDomainSocketDef = sim_domain.MaxwellSimDomainSocketDef
MaxwellFDTDSimSocketDef = fdtd_sim.MaxwellFDTDSimSocketDef
MaxwellFDTDSimDataSocketDef = fdtd_sim_data.MaxwellFDTDSimDataSocketDef
MaxwellSimGridSocketDef = sim_grid.MaxwellSimGridSocketDef MaxwellSimGridSocketDef = sim_grid.MaxwellSimGridSocketDef
MaxwellSimGridAxisSocketDef = sim_grid_axis.MaxwellSimGridAxisSocketDef MaxwellSimGridAxisSocketDef = sim_grid_axis.MaxwellSimGridAxisSocketDef
MaxwellSimDomainSocketDef = sim_domain.MaxwellSimDomainSocketDef MaxwellSourceSocketDef = source.MaxwellSourceSocketDef
MaxwellStructureSocketDef = structure.MaxwellStructureSocketDef
MaxwellTemporalShapeSocketDef = temporal_shape.MaxwellTemporalShapeSocketDef
BL_REGISTER = [ BL_REGISTER = [
*bound_cond.BL_REGISTER, *bound_cond.BL_REGISTER,
*bound_conds.BL_REGISTER, *bound_conds.BL_REGISTER,
*medium.BL_REGISTER,
*medium_non_linearity.BL_REGISTER,
*source.BL_REGISTER,
*temporal_shape.BL_REGISTER,
*structure.BL_REGISTER,
*monitor.BL_REGISTER,
*fdtd_sim.BL_REGISTER, *fdtd_sim.BL_REGISTER,
*fdtd_sim_data.BL_REGISTER, *fdtd_sim_data.BL_REGISTER,
*medium.BL_REGISTER,
*medium_non_linearity.BL_REGISTER,
*monitor.BL_REGISTER,
*monitor_data.BL_REGISTER,
*sim_domain.BL_REGISTER,
*sim_grid.BL_REGISTER, *sim_grid.BL_REGISTER,
*sim_grid_axis.BL_REGISTER, *sim_grid_axis.BL_REGISTER,
*sim_domain.BL_REGISTER, *source.BL_REGISTER,
*structure.BL_REGISTER,
*temporal_shape.BL_REGISTER,
] ]

View File

@ -0,0 +1,25 @@
from ... import contracts as ct
from .. import base
class MaxwellMonitorDataBLSocket(base.MaxwellSimSocket):
socket_type = ct.SocketType.MaxwellMonitorData
bl_label = 'Maxwell Monitor Data'
####################
# - Socket Configuration
####################
class MaxwellMonitorDataSocketDef(base.SocketDef):
socket_type: ct.SocketType = ct.SocketType.MaxwellMonitorData
def init(self, bl_socket: MaxwellMonitorDataBLSocket) -> None:
pass
####################
# - Blender Registration
####################
BL_REGISTER = [
MaxwellMonitorDataBLSocket,
]

View File

@ -1,9 +1,9 @@
from . import pol, unit_system from . import pol # , unit_system
PhysicalPolSocketDef = pol.PhysicalPolSocketDef PhysicalPolSocketDef = pol.PhysicalPolSocketDef
BL_REGISTER = [ BL_REGISTER = [
*unit_system.BL_REGISTER, # *unit_system.BL_REGISTER,
*pol.BL_REGISTER, *pol.BL_REGISTER,
] ]

View File

@ -3,12 +3,12 @@ import sympy as sp
import sympy.physics.optics.polarization as spo_pol import sympy.physics.optics.polarization as spo_pol
import sympy.physics.units as spu import sympy.physics.units as spu
from blender_maxwell.utils.pydantic_sympy import SympyExpr from blender_maxwell.utils import extra_sympy_units as spux
from ... import contracts as ct from ... import contracts as ct
from .. import base from .. import base
StokesVector = SympyExpr StokesVector = spux.SympyExpr
class PhysicalPolBLSocket(base.MaxwellSimSocket): class PhysicalPolBLSocket(base.MaxwellSimSocket):

View File

@ -1,6 +1,6 @@
import bpy import bpy
from blender_maxwell.utils.pydantic_sympy import SympyExpr #from blender_maxwell.utils.pydantic_sympy import SympyExpr
from ... import contracts as ct from ... import contracts as ct
from .. import base from .. import base

View File

@ -1,16 +1,22 @@
from ..nodeps.utils import blender_type_enum, pydeps from ..nodeps.utils import blender_type_enum, pydeps
from . import ( from . import (
analyze_geonodes, bl_cache,
extra_sympy_units, extra_sympy_units,
image_ops,
logger, logger,
pydantic_sympy, sci_constants,
serialize,
staticproperty,
) )
__all__ = [ __all__ = [
'pydeps',
'analyze_geonodes',
'blender_type_enum', 'blender_type_enum',
'pydeps',
'bl_cache',
'extra_sympy_units', 'extra_sympy_units',
'image_ops',
'logger', 'logger',
'pydantic_sympy', 'sci_constants',
'serialize',
'staticproperty',
] ]

View File

@ -1,30 +0,0 @@
import typing as typ
import bpy
INVALID_BL_SOCKET_TYPES = {
'NodeSocketGeometry',
}
def interface(
geonodes: bpy.types.GeometryNodeTree, ## TODO: bpy type
direc: typ.Literal['INPUT', 'OUTPUT'],
):
"""Returns 'valid' GeoNodes interface sockets.
- The Blender socket type is not something invalid (ex. "Geometry").
- The socket has a default value.
- The socket's direction (input/output) matches the requested direction.
"""
return {
interface_item_name: bl_interface_socket
for interface_item_name, bl_interface_socket in (
geonodes.interface.items_tree.items()
)
if (
bl_interface_socket.socket_type not in INVALID_BL_SOCKET_TYPES
and hasattr(bl_interface_socket, 'default_value')
and bl_interface_socket.in_out == direc
)
}

View File

@ -10,6 +10,7 @@ import uuid
from pathlib import Path from pathlib import Path
import bpy import bpy
import numpy as np
from blender_maxwell import contracts as ct from blender_maxwell import contracts as ct
from blender_maxwell.utils import logger, serialize from blender_maxwell.utils import logger, serialize
@ -515,6 +516,7 @@ class BLField:
use_prop_update: Configures the BLField to run `bl_instance.on_prop_changed(attr_name)` whenever value is set. use_prop_update: Configures the BLField to run `bl_instance.on_prop_changed(attr_name)` whenever value is set.
This is done by setting the `update` method. This is done by setting the `update` method.
enum_cb: Method used to generate new enum elements whenever `Signal.ResetEnum` is presented. enum_cb: Method used to generate new enum elements whenever `Signal.ResetEnum` is presented.
matrix_rowmajor: Blender's UI stores matrices flattened,
""" """
log.debug( log.debug(
@ -528,8 +530,8 @@ class BLField:
## Static ## Static
self._prop_ui = prop_ui self._prop_ui = prop_ui
self._prop_flags = prop_flags self._prop_flags = prop_flags
self._min = abs_min self._abs_min = abs_min
self._max = abs_max self._abs_max = abs_max
self._soft_min = soft_min self._soft_min = soft_min
self._soft_max = soft_max self._soft_max = soft_max
self._float_step = float_step self._float_step = float_step
@ -545,6 +547,12 @@ class BLField:
self._str_cb = str_cb self._str_cb = str_cb
self._enum_cb = enum_cb self._enum_cb = enum_cb
## Vector/Matrix Identity
## -> Matrix Shape assists in the workaround for Matrix Display Bug
self._is_vector = False
self._is_matrix = False
self._matrix_shape = None
## HUGE TODO: Persist these ## HUGE TODO: Persist these
self._str_cb_cache = {} self._str_cb_cache = {}
self._enum_cb_cache = {} self._enum_cb_cache = {}
@ -637,6 +645,7 @@ class BLField:
## Reusable Snippets ## Reusable Snippets
def _add_min_max_kwargs(): def _add_min_max_kwargs():
nonlocal kwargs_prop ## I've heard legends of needing this!
kwargs_prop |= {'min': self._abs_min} if self._abs_min is not None else {} kwargs_prop |= {'min': self._abs_min} if self._abs_min is not None else {}
kwargs_prop |= {'max': self._abs_max} if self._abs_max is not None else {} kwargs_prop |= {'max': self._abs_max} if self._abs_max is not None else {}
kwargs_prop |= ( kwargs_prop |= (
@ -647,6 +656,7 @@ class BLField:
) )
def _add_float_kwargs(): def _add_float_kwargs():
nonlocal kwargs_prop
kwargs_prop |= ( kwargs_prop |= (
{'step': self._float_step} if self._float_step is not None else {} {'step': self._float_step} if self._float_step is not None else {}
) )
@ -684,6 +694,7 @@ class BLField:
default_value = self._default_value default_value = self._default_value
BLProp = bpy.props.BoolVectorProperty BLProp = bpy.props.BoolVectorProperty
kwargs_prop |= {'size': len(typ.get_args(AttrType))} kwargs_prop |= {'size': len(typ.get_args(AttrType))}
self._is_vector = True
## Vector Int ## Vector Int
elif typ.get_origin(AttrType) is tuple and all( elif typ.get_origin(AttrType) is tuple and all(
@ -693,6 +704,7 @@ class BLField:
BLProp = bpy.props.IntVectorProperty BLProp = bpy.props.IntVectorProperty
_add_min_max_kwargs() _add_min_max_kwargs()
kwargs_prop |= {'size': len(typ.get_args(AttrType))} kwargs_prop |= {'size': len(typ.get_args(AttrType))}
self._is_vector = True
## Vector Float ## Vector Float
elif typ.get_origin(AttrType) is tuple and all( elif typ.get_origin(AttrType) is tuple and all(
@ -703,6 +715,59 @@ class BLField:
_add_min_max_kwargs() _add_min_max_kwargs()
_add_float_kwargs() _add_float_kwargs()
kwargs_prop |= {'size': len(typ.get_args(AttrType))} kwargs_prop |= {'size': len(typ.get_args(AttrType))}
self._is_vector = True
## Matrix Bool
elif typ.get_origin(AttrType) is tuple and all(
all(V is bool for V in typ.get_args(T)) for T in typ.get_args(AttrType)
):
# Workaround for Matrix Display Bug
## - Also requires __get__ support to read consistently.
rows = len(typ.get_args(AttrType))
cols = len(typ.get_args(typ.get_args(AttrType)[0]))
default_value = (
np.array(self._default_value, dtype=bool)
.flatten()
.reshape([cols, rows])
).tolist()
BLProp = bpy.props.BoolVectorProperty
kwargs_prop |= {'size': (cols, rows), 'subtype': 'MATRIX'}
## 'size' has column-major ordering (Matrix Display Bug).
self._is_matrix = True
self._matrix_shape = (rows, cols)
## Matrix Int
elif typ.get_origin(AttrType) is tuple and all(
all(V is int for V in typ.get_args(T)) for T in typ.get_args(AttrType)
):
_add_min_max_kwargs()
rows = len(typ.get_args(AttrType))
cols = len(typ.get_args(typ.get_args(AttrType)[0]))
default_value = (
np.array(self._default_value, dtype=int).flatten().reshape([cols, rows])
).tolist()
BLProp = bpy.props.IntVectorProperty
kwargs_prop |= {'size': (cols, rows), 'subtype': 'MATRIX'}
self._is_matrix = True
self._matrix_shape = (rows, cols)
## Matrix Float
elif typ.get_origin(AttrType) is tuple and all(
all(V is float for V in typ.get_args(T)) for T in typ.get_args(AttrType)
):
_add_min_max_kwargs()
_add_float_kwargs()
rows = len(typ.get_args(AttrType))
cols = len(typ.get_args(typ.get_args(AttrType)[0]))
default_value = (
np.array(self._default_value, dtype=float)
.flatten()
.reshape([cols, rows])
).tolist()
BLProp = bpy.props.FloatVectorProperty
kwargs_prop |= {'size': (cols, rows), 'subtype': 'MATRIX'}
self._is_matrix = True
self._matrix_shape = (rows, cols)
## Generic String ## Generic String
elif AttrType is str: elif AttrType is str:
@ -732,7 +797,7 @@ class BLField:
} }
## StrEnum ## StrEnum
elif issubclass(AttrType, enum.StrEnum): elif inspect.isclass(AttrType) and issubclass(AttrType, enum.StrEnum):
default_value = self._default_value default_value = self._default_value
BLProp = bpy.props.EnumProperty BLProp = bpy.props.EnumProperty
kwargs_prop |= { kwargs_prop |= {
@ -792,6 +857,7 @@ class BLField:
) ## TODO: Mine description from owner class __doc__ ) ## TODO: Mine description from owner class __doc__
# Define Property Getter # Define Property Getter
## Serialized properties need to deserialize in the getter.
if prop_is_serialized: if prop_is_serialized:
def getter(_self: BLInstance) -> AttrType: def getter(_self: BLInstance) -> AttrType:
@ -802,6 +868,7 @@ class BLField:
return getattr(_self, bl_attr_name) return getattr(_self, bl_attr_name)
# Define Property Setter # Define Property Setter
## Serialized properties need to serialize in the setter.
if prop_is_serialized: if prop_is_serialized:
def setter(_self: BLInstance, value: AttrType) -> None: def setter(_self: BLInstance, value: AttrType) -> None:
@ -821,7 +888,40 @@ class BLField:
def __get__( def __get__(
self, bl_instance: BLInstance | None, owner: type[BLInstance] self, bl_instance: BLInstance | None, owner: type[BLInstance]
) -> typ.Any: ) -> typ.Any:
return self._cached_bl_property.__get__(bl_instance, owner) value = self._cached_bl_property.__get__(bl_instance, owner)
# enum.Enum: Cast Auto-Injected Dynamic Enum 'NONE' -> None
## As far a Blender is concerned, dynamic enum props can't be empty.
## -> Well, they can... But bad things happen. So they can't.
## So in the interest of the user's sanity, we always ensure one entry.
## -> This one entry always has the one, same, id: 'NONE'.
## Of course, we often want to check for this "there was nothing" case.
## -> Aka, we want to do a `None` check, semantically speaking.
## -> ...But because it's a special thingy, we must check 'NONE'?
## Nonsense. Let the user just check `None`, as Guido intended.
if self._enum_cb is not None and value == 'NONE':
## TODO: Perhaps check if the unsafe callback was actually [].
## -> In case the user themselves want to return 'NONE'.
## -> Why would they do this? Because they are users!
return None
# Sized Vectors/Matrices
## Why not just yeet back a np.array?
## -> Type-annotating a shaped numpy array is... "rough".
## -> Type-annotation tuple[] of known shape is super easy.
## -> Even list[] won't do; its size varies, after all!
## -> Reject modernity. Return to tuple[].
if self._is_vector:
## -> tuple()ify the np.array to respect tuple[] type annotation.
return tuple(np.array(value))
if self._is_matrix:
# Matrix Display Bug: Correctly Read Row-Major Values w/Reshape
return tuple(
map(tuple, np.array(value).flatten().reshape(self._matrix_shape))
)
return value
def __set__(self, bl_instance: BLInstance | None, value: typ.Any) -> None: def __set__(self, bl_instance: BLInstance | None, value: typ.Any) -> None:
if value == Signal.ResetEnumItems: if value == Signal.ResetEnumItems:

File diff suppressed because it is too large Load Diff

View File

@ -1,61 +0,0 @@
import dataclasses
import typing as typ
from types import MappingProxyType
import jax
import jax.numpy as jnp
# import jaxtyping as jtyp
import sympy.physics.units as spu
import xarray
from . import logger
log = logger.get(__name__)
DimName: typ.TypeAlias = str
Number: typ.TypeAlias = int | float | complex
NumberRange: typ.TypeAlias = jax.Array
@dataclasses.dataclass(kw_only=True)
class JArray:
"""Very simple wrapper for JAX arrays, which includes information about the dimension names and bounds."""
array: jax.Array
dims: dict[DimName, NumberRange]
dim_units: dict[DimName, spu.Quantity]
####################
# - Constructor
####################
@classmethod
def from_xarray(
cls,
xarr: xarray.DataArray,
dim_units: dict[DimName, spu.Quantity] = MappingProxyType({}),
sort_axis: int = -1,
) -> typ.Self:
return cls(
array=jnp.sort(jnp.array(xarr.data), axis=sort_axis),
dims={
dim_name: jnp.array(xarr.get_index(dim_name).values)
for dim_name in xarr.dims
},
dim_units={dim_name: dim_units.get(dim_name) for dim_name in xarr.dims},
)
def idx(self, dim_name: DimName, dim_value: Number) -> int:
found_idx = jnp.searchsorted(self.dims[dim_name], dim_value)
if found_idx == 0:
return found_idx
if found_idx == len(self.dims[dim_name]):
return found_idx - 1
left = self.dims[dim_name][found_idx - 1]
right = self.dims[dim_name][found_idx - 1]
return found_idx - 1 if (dim_value - left) <= (right - dim_value) else found_idx
@property
def dtype(self) -> jnp.dtype:
return self.array.dtype

View File

@ -1,164 +0,0 @@
import typing as typ
import pydantic as pyd
import sympy as sp
import sympy.physics.units as spu
import typing_extensions as typx
from pydantic_core import core_schema as pyd_core_schema
from . import extra_sympy_units as spux
####################
# - Missing Basics
####################
AllowedSympyExprs = sp.Expr | sp.MatrixBase | sp.MutableDenseMatrix
Complex = typx.Annotated[
complex,
pyd.GetPydanticSchema(
lambda tp, handler: pyd_core_schema.no_info_after_validator_function(
lambda x: x, handler(tp)
)
),
]
####################
# - Custom Pydantic Type for sp.Expr
####################
class _SympyExpr:
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: AllowedSympyExprs,
_handler: pyd.GetCoreSchemaHandler,
) -> pyd_core_schema.CoreSchema:
def validate_from_str(value: str) -> AllowedSympyExprs:
if not isinstance(value, str):
return value
try:
expr = sp.sympify(value)
except ValueError as ex:
msg = f'Value {value} is not a `sympify`able string'
raise ValueError(msg) from ex
return expr.subs(spux.ALL_UNIT_SYMBOLS)
def validate_from_expr(value: AllowedSympyExprs) -> AllowedSympyExprs:
if not (isinstance(value, sp.Expr | sp.MatrixBase)):
msg = f'Value {value} is not a `sympy` expression'
raise ValueError(msg)
return value
sympy_expr_schema = pyd_core_schema.chain_schema(
[
pyd_core_schema.no_info_plain_validator_function(validate_from_str),
pyd_core_schema.no_info_plain_validator_function(validate_from_expr),
pyd_core_schema.is_instance_schema(AllowedSympyExprs),
]
)
return pyd_core_schema.json_or_python_schema(
json_schema=sympy_expr_schema,
python_schema=sympy_expr_schema,
serialization=pyd_core_schema.plain_serializer_function_ser_schema(
lambda instance: sp.srepr(instance)
),
)
####################
# - Configurable Expression Validation
####################
SympyExpr = typx.Annotated[
AllowedSympyExprs,
_SympyExpr,
]
def ConstrSympyExpr(
# Feature Class
allow_variables: bool = True,
allow_units: bool = True,
# Structure Class
allowed_sets: set[typ.Literal['integer', 'rational', 'real', 'complex']]
| None = None,
allowed_structures: set[typ.Literal['scalar', 'matrix']] | None = None,
# Element Class
allowed_symbols: set[sp.Symbol] | None = None,
allowed_units: set[spu.Quantity] | None = None,
# Shape Class
allowed_matrix_shapes: set[tuple[int, int]] | None = None,
):
## See `sympy` predicates:
## - <https://docs.sympy.org/latest/guides/assumptions.html#predicates>
def validate_expr(expr: AllowedSympyExprs):
if not (isinstance(expr, sp.Expr | sp.MatrixBase),):
## NOTE: Must match AllowedSympyExprs union elements.
msg = f"expr '{expr}' is not an allowed Sympy expression ({AllowedSympyExprs})"
raise ValueError(msg)
msgs = set()
# Validate Feature Class
if (not allow_variables) and (len(expr.free_symbols) > 0):
msgs.add(
f'allow_variables={allow_variables} does not match expression {expr}.'
)
if (not allow_units) and spux.uses_units(expr):
msgs.add(f'allow_units={allow_units} does not match expression {expr}.')
# Validate Structure Class
if (
allowed_sets
and isinstance(expr, sp.Expr)
and not any(
{
'integer': expr.is_integer,
'rational': expr.is_rational,
'real': expr.is_real,
'complex': expr.is_complex,
}[allowed_set]
for allowed_set in allowed_sets
)
):
msgs.add(
f"allowed_sets={allowed_sets} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))"
)
if allowed_structures and not any(
{
'matrix': isinstance(expr, sp.MatrixBase),
}[allowed_set]
for allowed_set in allowed_structures
if allowed_structures != 'scalar'
):
msgs.add(
f"allowed_structures={allowed_structures} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))"
)
# Validate Element Class
if allowed_symbols and expr.free_symbols.issubset(allowed_symbols):
msgs.add(
f'allowed_symbols={allowed_symbols} does not match expression {expr}'
)
if allowed_units and spux.get_units(expr).issubset(allowed_units):
msgs.add(f'allowed_units={allowed_units} does not match expression {expr}')
# Validate Shape Class
if (
allowed_matrix_shapes and isinstance(expr, sp.MatrixBase)
) and expr.shape not in allowed_matrix_shapes:
msgs.add(
f'allowed_matrix_shapes={allowed_matrix_shapes} does not match expression {expr} with shape {expr.shape}'
)
# Error or Return
if msgs:
raise ValueError(str(msgs))
return expr
return typx.Annotated[
AllowedSympyExprs,
_SympyExpr,
pyd.AfterValidator(validate_expr),
]

View File

@ -141,7 +141,7 @@ def _dec_hook(_type: type, obj: NaivelyEncodableType) -> typ.Any:
is_representation(obj) and obj[0] == TypeID.SympyType is_representation(obj) and obj[0] == TypeID.SympyType
): ):
obj_value = obj[2] obj_value = obj[2]
return sp.sympify(obj_value).subs(spux.ALL_UNIT_SYMBOLS) return sp.sympify(obj_value).subs(spux.UNIT_BY_SYMBOL)
if hasattr(_type, 'parse_as_msgspec') and ( if hasattr(_type, 'parse_as_msgspec') and (
is_representation(obj) and obj[0] in [TypeID.SocketDef, TypeID.ManagedObj] is_representation(obj) and obj[0] in [TypeID.SocketDef, TypeID.ManagedObj]