diff --git a/TODO.md b/TODO.md index fe5b41c..f5c7319 100644 --- a/TODO.md +++ b/TODO.md @@ -527,9 +527,17 @@ Reported: - (SOLVED) Unreported: +- Units are unruly, and are entirely useless when it comes to going small like this. - The `__mp_main__` bug. - Animated properties within custom node trees don't update with the frame. See: -- Can't update `items` using `id_propertie_ui` of `EnumProperty` +- Can't update `items` using `id_properties_ui` of `EnumProperty`. Maybe less a bug than an annoyance. +- **Matrix Display Bug**: The data given to matrix properties is entirely ignored in the UI; the data is flattened, then left-to-right, up-to-down, the data is inserted. It's neither row-major nor column-major - it's completely flat. + - Though, if one wanted row-major (**as is consistent with `mathutils.Matrix`**), one would be disappointed - the UI prints the matrix property column-major + - Trying to set the matrix property with a `mathutils.Matrix` is even stranger - firstly, the size of the `mathutils.Matrix` must be transposed with respect to the property size (again the col/row major mismatch). But secondly, even when accounting for the col/row major mismatch, the values of a ex. 2x3 (row-major) matrix (written to with a 3x2 matrix with same flattened sequence) is written in a very strange order: + - Write `mathutils.Matrix` `[[0,1], [2,3], [4,10]]`: Results in (UI displayed row-major) `[[0,3], [4,1], [3,5]]` + - **Workaround (write)**: Simply flatten the 2D array, re-shape by `[cols,rows]`. The UI will display as the original array. `myarray.flatten().reshape([cols,rows])`. + - **Workaround (read)**: `np.array([[el1 for el1 in el0] for el0 in BLENDER_OBJ.matrix_prop]).flatten().reshape([rows,cols])`. Simply flatten the property read 2D array and re-shape by `[rows,cols]`. Mind that data type out is equal to data type in. + - Also, for bool matrices, `toggle=True` has no effect. `alignment='CENTER'` also doesn't align the checkboxes in their cells. ## Tidy3D bugs Unreported: diff --git a/doc/_quarto.yml b/doc/_quarto.yml index b776d24..b363b52 100644 --- a/doc/_quarto.yml +++ b/doc/_quarto.yml @@ -118,11 +118,9 @@ quartodoc: - subtitle: "`bl_maxwell.utils`" desc: Utilities wo/shared global state. contents: - - utils.analyze_geonodes - utils.blender_type_enum - utils.extra_sympy_units - utils.logger - - utils.pydantic_sympy - subtitle: "`bl_maxwell.services`" desc: Utilities w/shared global state. @@ -172,7 +170,6 @@ quartodoc: - socket_colors - bl_socket_types - bl_socket_desc_map - - socket_units - unit_systems diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/bl_socket_map.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/bl_socket_map.py index 524004b..ef0669b 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/bl_socket_map.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/bl_socket_map.py @@ -5,11 +5,9 @@ Attributes: BL_SOCKET_4D_TYPE_PREFIXES: Blender socket prefixes which indicate that the Blender socket has four values. """ -import functools import typing as typ import bpy -import sympy as sp from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import logger as _logger @@ -19,238 +17,54 @@ from . import sockets log = _logger.get(__name__) -BLSocketType: typ.TypeAlias = str ## A Blender-Defined Socket Type -BLSocketValue: typ.TypeAlias = typ.Any ## A Blender Socket Value -BLSocketSize: typ.TypeAlias = int -DescType: typ.TypeAlias = str -Unit: typ.TypeAlias = typ.Any ## Type of a valid unit -## TODO: Move this kind of thing to contracts - #################### -# - BL Socket Size Parser +# - Blender -> Socket Def(s) #################### -BL_SOCKET_3D_TYPE_PREFIXES = { - 'NodeSocketVector', - 'NodeSocketRotation', -} -BL_SOCKET_4D_TYPE_PREFIXES = { - 'NodeSocketColor', -} +def socket_def_from_bl_isocket( + bl_isocket: bpy.types.NodeTreeInterfaceSocket, +) -> sockets.base.SocketDef | None: + """Deduces and constructs an appropriate SocketDef to match the given `bl_interface_socket`.""" + blsck_info = ct.BLSocketType.info_from_bl_isocket(bl_isocket) + if blsck_info.has_support and not blsck_info.is_preview: + # Map Expr Socket + ## -> Accounts for any combo of shape/MathType/PhysicalType. + if blsck_info.socket_type == ct.SocketType.Expr: + return sockets.ExprSocketDef( + shape=blsck_info.size.shape, + mathtype=blsck_info.mathtype, + physical_type=blsck_info.physical_type, + default_unit=ct.UNITS_BLENDER[blsck_info.physical_type], + default_value=blsck_info.default_value, + ) + + ## TODO: Explicitly map default to other supported SocketDef constructors + + return sockets.SOCKET_DEFS[blsck_info.socket_type]() + return None -@functools.lru_cache(maxsize=4096) -def _size_from_bl_socket( - description: str, - bl_socket_type: BLSocketType, -): - """Parses the number of elements contained in a Blender interface socket. - - Since there are no 2D sockets in Blender, the user can specify "2D" in the Blender socket's description to "promise" that only the first two values will be used. - When this is done, the third value is just never altered by the addon. - - A hard-coded set of NodeSocket prefixes are used to determine which interface sockets are, in fact, 3D. - - For 3D sockets, a hard-coded list of Blender node socket types is used. - - Else, it is a 1D socket type. - """ - if description.startswith('2D'): - return 2 - if any( - bl_socket_type.startswith(bl_socket_3d_type_prefix) - for bl_socket_3d_type_prefix in BL_SOCKET_3D_TYPE_PREFIXES - ): - return 3 - if any( - bl_socket_type.startswith(bl_socket_4d_type_prefix) - for bl_socket_4d_type_prefix in BL_SOCKET_4D_TYPE_PREFIXES - ): - return 4 - - return 1 +def sockets_from_geonodes( + geonodes: bpy.types.GeometryNodeTree, +) -> dict[ct.SocketName, sockets.base.SocketDef]: + """Deduces and constructs appropriate SocketDefs to match all input sockets to the given GeoNodes tree.""" + raw_socket_defs = { + socket_name: socket_def_from_bl_isocket(bl_isocket) + for socket_name, bl_isocket in geonodes.interface.items_tree.items() + } + return { + socket_name: socket_def + for socket_name, socket_def in raw_socket_defs.items() + if socket_def is not None + } -#################### -# - BL Socket Type / Unit Parser -#################### -@functools.lru_cache(maxsize=4096) -def _socket_type_from_bl_socket( - description: str, - bl_socket_type: BLSocketType, -) -> ct.SocketType: - """Parse a Blender socket for a matching BLMaxwell socket type, relying on both the Blender socket type and user-generated hints in the description. - - Arguments: - description: The description from Blender socket, aka. `bl_socket.description`. - bl_socket_type: The Blender socket type, aka. `bl_socket.socket_type`. - - Returns: - The type of a MaxwellSimSocket that corresponds to the Blender socket. - """ - size = _size_from_bl_socket(description, bl_socket_type) - - # Determine Socket Type Directly - ## The naive mapping from BL socket -> Maxwell socket may be good enough. - if ( - direct_socket_type := ct.BL_SOCKET_DIRECT_TYPE_MAP.get((bl_socket_type, size)) - ) is None: - msg = "Blender interface socket has no mapping among 'MaxwellSimSocket's." - raise ValueError(msg) - - # (No Description) Return Direct Socket Type - if ct.BL_SOCKET_DESCR_ANNOT_STRING not in description: - return direct_socket_type - - # Parse Description for Socket Type - ## The "2D" token is special; don't include it if it's there. - descr_params = description.split(ct.BL_SOCKET_DESCR_ANNOT_STRING)[0] - directive = ( - _tokens[0] if (_tokens := descr_params.split(' '))[0] != '2D' else _tokens[1] - ) - if directive == 'Preview': - return direct_socket_type ## TODO: Preview element handling - - if ( - socket_type := ct.BL_SOCKET_DESCR_TYPE_MAP.get( - (directive, bl_socket_type, size) - ) - ) is None: - msg = f'Socket description "{(directive, bl_socket_type, size)}" doesn\'t map to a socket type + unit' - raise ValueError(msg) - - return socket_type - - -#################### -# - BL Socket Interface Definition -#################### -@functools.lru_cache(maxsize=4096) -def _socket_def_from_bl_socket( - description: str, - bl_socket_type: BLSocketType, -) -> ct.SocketType: - return sockets.SOCKET_DEFS[_socket_type_from_bl_socket(description, bl_socket_type)] - - -def socket_def_from_bl_socket( - bl_interface_socket: bpy.types.NodeTreeInterfaceSocket, -) -> sockets.base.SocketDef: - """Computes an appropriate (no-arg) SocketDef from the given `bl_interface_socket`, by parsing it.""" - return _socket_def_from_bl_socket( - bl_interface_socket.description, bl_interface_socket.bl_socket_idname - ) - - -#################### -# - Extract Default Interface Socket Value -#################### -def _read_bl_socket_default_value( - description: str, - bl_socket_type: BLSocketType, - bl_socket_value: BLSocketValue, - unit_system: dict | None = None, - allow_unit_not_in_unit_system: bool = False, -) -> typ.Any: - # Parse the BL Socket Type and Value - ## The 'lambda' delays construction until size is determined. - socket_type = _socket_type_from_bl_socket(description, bl_socket_type) - parsed_socket_value = { - 1: lambda: bl_socket_value, - 2: lambda: sp.Matrix(tuple(bl_socket_value)[:2]), - 3: lambda: sp.Matrix(tuple(bl_socket_value)), - 4: lambda: sp.Matrix(tuple(bl_socket_value)), - }[_size_from_bl_socket(description, bl_socket_type)]() - - # Add Unit-System Unit to Parsed - ## Use the matching socket type to lookup the unit in the unit system. - if unit_system is not None: - if (unit := unit_system.get(socket_type)) is None: - if allow_unit_not_in_unit_system: - return parsed_socket_value - - msg = f'Unit system does not provide a unit for {socket_type}' - raise RuntimeError(msg) - - return parsed_socket_value * unit - return parsed_socket_value - - -def read_bl_socket_default_value( - bl_interface_socket: bpy.types.NodeTreeInterfaceSocket, - unit_system: dict | None = None, - allow_unit_not_in_unit_system: bool = False, -) -> typ.Any: - """Reads the `default_value` of a Blender socket, guaranteeing a well-formed value consistent with the passed unit system. - - Arguments: - bl_interface_socket: The Blender interface socket to analyze for description, socket type, and default value. - unit_system: The mapping from BLMaxwell SocketType to corresponding unit, used to apply the appropriate unit to the output. - - Returns: - The parsed, well-formed version of `bl_socket.default_value`, of the appropriate form and unit. - - """ - return _read_bl_socket_default_value( - bl_interface_socket.description, - bl_interface_socket.bl_socket_idname, - bl_interface_socket.default_value, - unit_system=unit_system, - allow_unit_not_in_unit_system=allow_unit_not_in_unit_system, - ) - - -def _writable_bl_socket_value( - description: str, - bl_socket_type: BLSocketType, - value: typ.Any, - unit_system: dict | None = None, - allow_unit_not_in_unit_system: bool = False, -) -> typ.Any: - socket_type = _socket_type_from_bl_socket(description, bl_socket_type) - - # Retrieve Unit-System Unit - if unit_system is not None: - if (unit := unit_system.get(socket_type)) is None: - if allow_unit_not_in_unit_system: - _bl_socket_value = value - else: - msg = f'Unit system does not provide a unit for {socket_type}' - raise RuntimeError(msg) - else: - _bl_socket_value = spux.scale_to_unit(value, unit) - else: - _bl_socket_value = value - - # Compute Blender Socket Value - if isinstance(_bl_socket_value, sp.Basic | sp.MatrixBase): - bl_socket_value = spux.sympy_to_python(_bl_socket_value) - else: - bl_socket_value = _bl_socket_value - - if _size_from_bl_socket(description, bl_socket_type) == 2: # noqa: PLR2004 - bl_socket_value = bl_socket_value[:2] - return bl_socket_value - - -def writable_bl_socket_value( - bl_interface_socket: bpy.types.NodeTreeInterfaceSocket, - value: typ.Any, - unit_system: dict | None = None, - allow_unit_not_in_unit_system: bool = False, -) -> typ.Any: - """Processes a value to be ready-to-write to a Blender socket. - - Arguments: - bl_interface_socket: The Blender interface socket to analyze - value: The value to prepare for writing to the given Blender socket. - unit_system: The mapping from BLMaxwell SocketType to corresponding unit, used to scale the value to the the appropriate unit. - - Returns: - A value corresponding to the input, which is guaranteed to be compatible with the Blender socket (incl. via a GeoNodes modifier), as well as correctly scaled with respect to the given unit system. - """ - return _writable_bl_socket_value( - bl_interface_socket.description, - bl_interface_socket.bl_socket_idname, - value, - unit_system=unit_system, - allow_unit_not_in_unit_system=allow_unit_not_in_unit_system, - ) +## TODO: Make it fast, it's in a hot loop... +def info_from_geonodes( + geonodes: bpy.types.GeometryNodeTree, +) -> dict[ct.SocketName, ct.BLSocketInfo]: + """Deduces and constructs appropriate SocketDefs to match all input sockets to the given GeoNodes tree.""" + return { + socket_name: ct.BLSocketType.info_from_bl_isocket(bl_isocket) + for socket_name, bl_isocket in geonodes.interface.items_tree.items() + } diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/__init__.py index eb53078..4121f6f 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/__init__.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/__init__.py @@ -1,41 +1,40 @@ from blender_maxwell.contracts import ( - BLClass, - BLColorRGBA, - BLEnumElement, - BLEnumID, - BLIcon, - BLIconSet, - BLIDStruct, - BLKeymapItem, - BLModifierType, - BLNodeTreeInterfaceID, - BLOperatorStatus, - BLPropFlag, - BLRegionType, - BLSpaceType, - KeymapItemDef, - ManagedObjName, - OperatorType, - PanelType, - PresetName, - SocketName, - addon, + BLClass, + BLColorRGBA, + BLEnumElement, + BLEnumID, + BLIcon, + BLIconSet, + BLIDStruct, + BLKeymapItem, + BLModifierType, + BLNodeTreeInterfaceID, + BLOperatorStatus, + BLPropFlag, + BLRegionType, + BLSpaceType, + KeymapItemDef, + ManagedObjName, + OperatorType, + PanelType, + PresetName, + SocketName, + addon, ) -from .bl_socket_desc_map import BL_SOCKET_DESCR_ANNOT_STRING, BL_SOCKET_DESCR_TYPE_MAP -from .bl_socket_types import BL_SOCKET_DIRECT_TYPE_MAP +from .bl_socket_types import BLSocketInfo, BLSocketType from .category_labels import NODE_CAT_LABELS from .category_types import NodeCategory from .flow_events import FlowEvent from .flow_kinds import ( - ArrayFlow, - CapabilitiesFlow, - FlowKind, - InfoFlow, - LazyArrayRangeFlow, - LazyValueFuncFlow, - ParamsFlow, - ValueFlow, + ArrayFlow, + CapabilitiesFlow, + FlowKind, + InfoFlow, + LazyArrayRangeFlow, + LazyValueFuncFlow, + ParamsFlow, + ValueFlow, ) from .flow_signals import FlowSignal from .icons import Icon @@ -43,7 +42,6 @@ from .mobj_types import ManagedObjType from .node_types import NodeType from .socket_colors import SOCKET_COLORS from .socket_types import SocketType -from .socket_units import SOCKET_UNITS, unit_to_socket_type from .tree_types import TreeType from .unit_systems import UNITS_BLENDER, UNITS_TIDY3D @@ -72,15 +70,12 @@ __all__ = [ 'Icon', 'TreeType', 'SocketType', - 'SOCKET_UNITS', - 'unit_to_socket_type', 'SOCKET_COLORS', 'SOCKET_SHAPES', 'UNITS_BLENDER', 'UNITS_TIDY3D', - 'BL_SOCKET_DESCR_TYPE_MAP', - 'BL_SOCKET_DIRECT_TYPE_MAP', - 'BL_SOCKET_DESCR_ANNOT_STRING', + 'BLSocketInfo', + 'BLSocketType', 'NodeType', 'NodeCategory', 'NODE_CAT_LABELS', diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_desc_map.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_desc_map.py deleted file mode 100644 index 07a72ec..0000000 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_desc_map.py +++ /dev/null @@ -1,62 +0,0 @@ -from .socket_types import SocketType as ST - -BL_SOCKET_DESCR_ANNOT_STRING = ':: ' -BL_SOCKET_DESCR_TYPE_MAP = { - ('Time', 'NodeSocketFloat', 1): ST.PhysicalTime, - ('Angle', 'NodeSocketFloat', 1): ST.PhysicalAngle, - ('SolidAngle', 'NodeSocketFloat', 1): ST.PhysicalSolidAngle, - ('Rotation', 'NodeSocketVector', 2): ST.PhysicalRot2D, - ('Rotation', 'NodeSocketVector', 3): ST.PhysicalRot3D, - ('Freq', 'NodeSocketFloat', 1): ST.PhysicalFreq, - ('AngFreq', 'NodeSocketFloat', 1): ST.PhysicalAngFreq, - ## Cartesian - ('Length', 'NodeSocketFloat', 1): ST.PhysicalLength, - ('Area', 'NodeSocketFloat', 1): ST.PhysicalArea, - ('Volume', 'NodeSocketFloat', 1): ST.PhysicalVolume, - ('Disp', 'NodeSocketVector', 2): ST.PhysicalDisp2D, - ('Disp', 'NodeSocketVector', 3): ST.PhysicalDisp3D, - ('Point', 'NodeSocketFloat', 1): ST.PhysicalPoint1D, - ('Point', 'NodeSocketVector', 2): ST.PhysicalPoint2D, - ('Point', 'NodeSocketVector', 3): ST.PhysicalPoint3D, - ('Size', 'NodeSocketVector', 2): ST.PhysicalSize2D, - ('Size', 'NodeSocketVector', 3): ST.PhysicalSize3D, - ## Mechanical - ('Mass', 'NodeSocketFloat', 1): ST.PhysicalMass, - ('Speed', 'NodeSocketFloat', 1): ST.PhysicalSpeed, - ('Vel', 'NodeSocketVector', 2): ST.PhysicalVel2D, - ('Vel', 'NodeSocketVector', 3): ST.PhysicalVel3D, - ('Accel', 'NodeSocketFloat', 1): ST.PhysicalAccelScalar, - ('Accel', 'NodeSocketVector', 2): ST.PhysicalAccel2D, - ('Accel', 'NodeSocketVector', 3): ST.PhysicalAccel3D, - ('Force', 'NodeSocketFloat', 1): ST.PhysicalForceScalar, - ('Force', 'NodeSocketVector', 2): ST.PhysicalForce2D, - ('Force', 'NodeSocketVector', 3): ST.PhysicalForce3D, - ('Pressure', 'NodeSocketFloat', 1): ST.PhysicalPressure, - ## Energetic - ('Energy', 'NodeSocketFloat', 1): ST.PhysicalEnergy, - ('Power', 'NodeSocketFloat', 1): ST.PhysicalPower, - ('Temp', 'NodeSocketFloat', 1): ST.PhysicalTemp, - ## ELectrodynamical - ('Curr', 'NodeSocketFloat', 1): ST.PhysicalCurr, - ('CurrDens', 'NodeSocketVector', 2): ST.PhysicalCurrDens2D, - ('CurrDens', 'NodeSocketVector', 3): ST.PhysicalCurrDens3D, - ('Charge', 'NodeSocketFloat', 1): ST.PhysicalCharge, - ('Voltage', 'NodeSocketFloat', 1): ST.PhysicalVoltage, - ('Capacitance', 'NodeSocketFloat', 1): ST.PhysicalCapacitance, - ('Resistance', 'NodeSocketFloat', 1): ST.PhysicalResistance, - ('Conductance', 'NodeSocketFloat', 1): ST.PhysicalConductance, - ('MagFlux', 'NodeSocketFloat', 1): ST.PhysicalMagFlux, - ('MagFluxDens', 'NodeSocketFloat', 1): ST.PhysicalMagFluxDens, - ('Inductance', 'NodeSocketFloat', 1): ST.PhysicalInductance, - ('EField', 'NodeSocketFloat', 2): ST.PhysicalEField3D, - ('EField', 'NodeSocketFloat', 3): ST.PhysicalEField2D, - ('HField', 'NodeSocketFloat', 2): ST.PhysicalHField3D, - ('HField', 'NodeSocketFloat', 3): ST.PhysicalHField2D, - ## Luminal - ('LumIntensity', 'NodeSocketFloat', 1): ST.PhysicalLumIntensity, - ('LumFlux', 'NodeSocketFloat', 1): ST.PhysicalLumFlux, - ('Illuminance', 'NodeSocketFloat', 1): ST.PhysicalIlluminance, - ## Optical - ('PolJones', 'NodeSocketFloat', 2): ST.PhysicalPolJones, - ('Pol', 'NodeSocketFloat', 4): ST.PhysicalPol, -} diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py index dfd2557..1ad0e29 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py @@ -1,39 +1,312 @@ -from .socket_types import SocketType as ST +import dataclasses +import enum +import typing as typ -BL_SOCKET_DIRECT_TYPE_MAP = { +import bpy +import sympy as sp + +from blender_maxwell.utils import blender_type_enum +from blender_maxwell.utils import extra_sympy_units as spux + +from .socket_types import SocketType + +BL_SOCKET_DESCR_ANNOT_STRING = ':: ' + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class BLSocketInfo: + has_support: bool + is_preview: bool + socket_type: SocketType | None + size: spux.NumberSize1D | None + physical_type: spux.PhysicalType | None + default_value: spux.ScalarUnitlessRealExpr + + bl_isocket_identifier: spux.ScalarUnitlessRealExpr + + +@blender_type_enum.prefix_values_with('NodeSocket') +class BLSocketType(enum.StrEnum): + Virtual = 'Virtual' # Blender - ('NodeSocketCollection', 1): ST.BlenderCollection, - ('NodeSocketImage', 1): ST.BlenderImage, - ('NodeSocketObject', 1): ST.BlenderObject, - ('NodeSocketMaterial', 1): ST.BlenderMaterial, + Image = 'Image' + Shader = 'Shader' + Material = 'Material' + Geometry = 'Material' + Object = 'Object' + Collection = 'Collection' # Basic - ('NodeSocketString', 1): ST.String, - ('NodeSocketBool', 1): ST.Bool, + Bool = 'Bool' + String = 'String' + Menu = 'Menu' # Float - ('NodeSocketFloat', 1): ST.RealNumber, - # ("NodeSocketFloatAngle", 1): ST.PhysicalAngle, - # ("NodeSocketFloatDistance", 1): ST.PhysicalLength, - ('NodeSocketFloatFactor', 1): ST.RealNumber, - ('NodeSocketFloatPercentage', 1): ST.RealNumber, - # ("NodeSocketFloatTime", 1): ST.PhysicalTime, - # ("NodeSocketFloatTimeAbsolute", 1): ST.PhysicalTime, + Float = 'Float' + FloatUnsigned = 'FloatUnsigned' + FloatAngle = 'FloatAngle' + FloatDistance = 'FloatDistance' + FloatFactor = 'FloatFactor' + FloatPercentage = 'FloatPercentage' + FloatTime = 'FloatTime' + FloatTimeAbsolute = 'FloatTimeAbsolute' # Int - ('NodeSocketInt', 1): ST.IntegerNumber, - ('NodeSocketIntFactor', 1): ST.IntegerNumber, - ('NodeSocketIntPercentage', 1): ST.IntegerNumber, - ('NodeSocketIntUnsigned', 1): ST.IntegerNumber, - # Array-Like - ('NodeSocketColor', 3): ST.Color, - ('NodeSocketRotation', 2): ST.PhysicalRot2D, - ('NodeSocketVector', 2): ST.Real2DVector, - ('NodeSocketVector', 3): ST.Real3DVector, - # ("NodeSocketVectorAcceleration", 2): ST.PhysicalAccel2D, - # ("NodeSocketVectorAcceleration", 3): ST.PhysicalAccel3D, - # ("NodeSocketVectorDirection", 2): ST.Real2DVectorDir, - # ("NodeSocketVectorDirection", 3): ST.Real3DVectorDir, - ('NodeSocketVectorEuler', 2): ST.PhysicalRot2D, - ('NodeSocketVectorEuler', 3): ST.PhysicalRot3D, - # ("NodeSocketVectorTranslation", 3): ST.PhysicalDisp3D, - # ("NodeSocketVectorVelocity", 3): ST.PhysicalVel3D, - # ("NodeSocketVectorXYZ", 3): ST.PhysicalPoint3D, -} + Int = 'Int' + IntFactor = 'IntFactor' + IntPercentage = 'IntPercentage' + IntUnsigned = 'IntUnsigned' + # Vector + Color = 'Color' + Rotation = 'Rotation' + Vector = 'Vector' + VectorAcceleration = 'Acceleration' + VectorDirection = 'Direction' + VectorEuler = 'Euler' + VectorTranslation = 'Translation' + VectorVelocity = 'Velocity' + VectorXYZ = 'XYZ' + + @staticmethod + def from_bl_isocket( + bl_isocket: bpy.types.NodeTreeInterfaceSocket, + ) -> typ.Self: + return BLSocketType[bl_isocket.bl_socket_idname] + + @staticmethod + def info_from_bl_isocket( + bl_isocket: bpy.types.NodeTreeInterfaceSocket, + ) -> typ.Self: + return BLSocketType.from_bl_isocket(bl_isocket).parse( + bl_isocket.default_value, bl_isocket.description, bl_isocket.identifier + ) + + #################### + # - Direct Properties + #################### + @property + def has_support(self) -> bool: + BLST = BLSocketType + return { + BLST.Virtual: False, + BLST.Geometry: False, + BLST.Shader: False, + BLST.FloatUnsigned: False, + BLST.IntUnsigned: False, + }.get(self, True) + + @property + def socket_type(self) -> SocketType | None: + """Deduce `SocketType` corresponding to the Blender socket type. + + **The socket type alone is not enough** to actually create the socket. + To declare a socket in the addon, an appropriate `socket.SocketDef` must be constructed in a manner that respects contextual information. + + Returns: + The corresponding socket type, if the addon has support for mapping this Blender socket. + For sockets with support, the fallback is always `SocketType.Expr`. + + Support is determined using `self.has_support` + """ + if not self.has_support: + return None + + BLST = BLSocketType + ST = SocketType + return { + # Blender + # Basic + BLST.Bool: ST.String, + # Float + # Array-Like + BLST.Color: ST.Color, + }.get(self, ST.Expr) + + @property + def mathtype(self) -> spux.MathType | None: + """Deduce `spux.MathType` corresponding to the Blender socket type. + + **The socket type alone is not enough** to actually create the socket. + To declare a socket in the addon, an appropriate `socket.SocketDef` must be constructed in a manner that respects contextual information. + + Returns: + The corresponding socket type, if the addon has support for mapping this Blender socket. + For sockets with support, the fallback is always `SocketType.Expr`. + + Support is determined using `self.has_support` + """ + if not self.has_support: + return None + + BLST = BLSocketType + MT = spux.MathType + return { + # Blender + # Basic + BLST.Bool: MT.Bool, + # Float + BLST.Float: MT.Real, + BLST.FloatAngle: MT.Real, + BLST.FloatDistance: MT.Real, + BLST.FloatFactor: MT.Real, + BLST.FloatPercentage: MT.Real, + BLST.FloatTime: MT.Real, + BLST.FloatTimeAbsolute: MT.Real, + # Int + BLST.Int: MT.Integer, + BLST.IntFactor: MT.Integer, + BLST.IntPercentage: MT.Integer, + # Vector + BLST.Color: MT.Real, + BLST.Rotation: MT.Real, + BLST.Vector: MT.Real, + BLST.VectorAcceleration: MT.Real, + BLST.VectorDirection: MT.Real, + BLST.VectorEuler: MT.Real, + BLST.VectorTranslation: MT.Real, + BLST.VectorVelocity: MT.Real, + BLST.VectorXYZ: MT.Real, + }.get(self) + + @property + def size( + self, + ) -> ( + typ.Literal[ + spux.NumberSize1D.Scalar, spux.NumberSize1D.Vec3, spux.NumberSize1D.Vec4 + ] + | None + ): + """Deduce the internal size of the Blender socket's data. + + Returns: + A `spux.NumberSize1D` reflecting the internal data representation. + Always falls back to `{spux.NumberSize1D.Scalar}`. + """ + if not self.has_support: + return None + + S = spux.NumberSize1D + BLST = BLSocketType + return { + BLST.Color: S.Vec4, + BLST.Rotation: S.Vec3, + BLST.VectorAcceleration: S.Vec3, + BLST.VectorDirection: S.Vec3, + BLST.VectorEuler: S.Vec3, + BLST.VectorTranslation: S.Vec3, + BLST.VectorVelocity: S.Vec3, + BLST.VectorXYZ: S.Vec3, + }.get(self, {S.Scalar}) + + @property + def unambiguous_physical_type(self) -> spux.PhysicalType | None: + """Deduce an **unambiguous** physical type from the Blender socket, if any. + + Blender does have its own unit systems, which leads to some Blender socket subtypes having an obvious choice of physical unit dimension (ex. `BLSocketType.FloatTime`). + In such cases, the `spux.PhysicalType` that matches the Blender socket can be uniquely determined. + + When a phsyical type cannot be immediately determined in this way, other mechanisms must be used to deduce what to do. + + Returns: + A physical type corresponding to the Blender socket, **if exactly one** can be determined with no ambiguity - else `None`. + + If more than one physical type might apply, `None`. + """ + if not self.has_support: + return None + + P = spux.PhysicalType + BLST = BLSocketType + { + BLST.FloatAngle: P.Angle, + BLST.FloatDistance: P.Length, + BLST.FloatTime: P.Time, + BLST.FloatTimeAbsolute: P.Time, ## What's the difference? + BLST.VectorAcceleration: P.Accel, + ## BLST.VectorDirection: Directions are unitless (within cartesian) + BLST.VectorEuler: P.Angle, + BLST.VectorTranslation: P.Length, + BLST.VectorVelocity: P.Vel, + BLST.VectorXYZ: P.Length, + }.get(self) + + @property + def valid_sizes(self) -> set[spux.NumberSize1D] | None: + """Deduce which sizes it would be valid to interpret a Blender socket as having. + + This property's purpose is merely to present a set of options that _are valid_. + Whether an a size _is truly usable_, can only be determined with contextual information, wherein certain decisions can be made: + + - **2D vs. 3D**: In general, Blender's vector socket types are **only** 3D, and we cannot ever _really_ have a 2D vector. + If one chooses interpret ex. `BLSocketType.Vector` as 2D, one might do so by pretending the third coordinate doesn't exist. + But **this is a subjective decision**, which always has to align with the logic on the other side of the Blender socket. + - **Colors**: Generally, `BLSocketType.Color` is always 4D, representing `RGBA` (with alpha channel). + However, we often don't care about alpha; therefore, we might choose to "just" push a 3D `RGB` vector. + Again, **this is a subjective decision** which requires one to make a decision about alpha, for example "alpha is always 1". + - **Scalars**: We can generally always interpret a scalar as a vector, using well-defined "broadcasting". + + Returns: + The set of `spux.NumberSize1D`s, which it would be safe to interpret the Blender socket as having. + + Always falls back to `{spux.NumberSize1D.Scalar}`. + """ + if not self.has_support: + return None + + S = spux.NumberSize1D + BLST = BLSocketType + return { + BLST.Color: {S.Scalar, S.Vec3, S.Vec4}, + BLST.Rotation: {S.Vec2, S.Vec3}, + BLST.VectorAcceleration: {S.Scalar, S.Vec2, S.Vec3}, + BLST.VectorDirection: {S.Scalar, S.Vec2, S.Vec3}, + BLST.VectorEuler: {S.Vec2, S.Vec3}, + BLST.VectorTranslation: {S.Scalar, S.Vec2, S.Vec3}, + BLST.VectorVelocity: {S.Scalar, S.Vec2, S.Vec3}, + BLST.VectorXYZ: {S.Scalar, S.Vec2, S.Vec3}, + }.get(self, {S.Scalar}) + + #################### + # - Parsed Properties + #################### + def parse( + self, bl_default_value: typ.Any, description: str, bl_isocket_identifier: str + ) -> BLSocketInfo: + # Parse the Description + ## TODO: Some kind of error on invalid parse if there is also no unambiguous physical type + descr_params = description.split(BL_SOCKET_DESCR_ANNOT_STRING)[0] + directive = ( + _tokens[0] + if (_tokens := descr_params.split(' '))[0] != '2D' + else _tokens[1] + ) + + ## Interpret the Description Parse + parsed_physical_type = getattr(spux.PhysicalType, directive, None) + physical_type = ( + self.unambiguous_physical_type + if self.unambiguous_physical_type is not None + else parsed_physical_type + ) + + # Parse the Default Value + if self.mathtype is not None: + if self.size == spux.NumberSize1D.Scalar: + default_value = self.mathtype.pytype(bl_default_value) + elif description.startswith('2D'): + default_value = sp.Matrix(tuple(bl_default_value)[:2]) + else: + default_value = sp.Matrix(tuple(bl_default_value)) + else: + default_value = bl_default_value + + # Return Parsed Socket Information + ## -> Combining directly known and parsed knowledge. + ## -> Should contain everything needed to match the Blender socket. + return BLSocketInfo( + has_support=self.has_support, + is_preview=(directive == 'Preview'), + socket_type=self.socket_type, + size=self.size, + physical_type=physical_type, + default_value=default_value, + bl_isocket_identifier=bl_isocket_identifier, + ) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py index af58248..e990458 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py @@ -316,8 +316,12 @@ class LazyValueFuncFlow: """ func: LazyFunction - func_args: list[type] = dataclasses.field(default_factory=list) - func_kwargs: dict[str, type] = dataclasses.field(default_factory=dict) + func_args: list[spux.MathType | spux.PhysicalType] = dataclasses.field( + default_factory=list + ) + func_kwargs: dict[str, spux.MathType | spux.PhysicalType] = dataclasses.field( + default_factory=dict + ) supports_jax: bool = False supports_numba: bool = False @@ -432,9 +436,7 @@ class LazyArrayRangeFlow: unit: The unit of the generated array values - int_symbols: Set of integer-valued variables from which `start` and/or `stop` are determined. - real_symbols: Set of real-valued variables from which `start` and/or `stop` are determined. - complex_symbols: Set of complex-valued variables from which `start` and/or `stop` are determined. + symbols: Set of variables from which `start` and/or `stop` are determined. """ start: spux.ScalarUnitlessComplexExpr @@ -444,12 +446,10 @@ class LazyArrayRangeFlow: unit: spux.Unit | None = None - int_symbols: set[spux.IntSymbol] = frozenset() - real_symbols: set[spux.RealSymbol] = frozenset() - complex_symbols: set[spux.ComplexSymbol] = frozenset() + symbols: frozenset[spux.IntSymbol] = frozenset() @functools.cached_property - def symbols(self) -> list[sp.Symbol]: + def sorted_symbols(self) -> list[sp.Symbol]: """Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name. The order is guaranteed to be **deterministic**. @@ -457,10 +457,7 @@ class LazyArrayRangeFlow: Returns: All symbols valid for use in the expression. """ - return sorted( - self.int_symbols | self.real_symbols | self.complex_symbols, - key=lambda sym: sym.name, - ) + return sorted(self.symbols, key=lambda sym: sym.name) @functools.cached_property def mathtype(self) -> spux.MathType: @@ -508,9 +505,7 @@ class LazyArrayRangeFlow: steps=self.steps, scaling=self.scaling, unit=corrected_unit, - int_symbols=self.int_symbols, - real_symbols=self.real_symbols, - complex_symbols=self.complex_symbols, + symbols=self.symbols, ) msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"' @@ -530,15 +525,12 @@ class LazyArrayRangeFlow: """ if self.unit is not None: return LazyArrayRangeFlow( - start=spu.convert_to(self.start, unit), - stop=spu.convert_to(self.stop, unit), + start=spu.scale_to_unit(self.start * self.unit, unit), + stop=spu.scale_to_unit(self.stop * self.unit, unit), steps=self.steps, scaling=self.scaling, unit=unit, symbols=self.symbols, - int_symbols=self.int_symbols, - real_symbols=self.real_symbols, - complex_symbols=self.complex_symbols, ) msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}' @@ -549,7 +541,7 @@ class LazyArrayRangeFlow: #################### def rescale_bounds( self, - scaler: typ.Callable[ + rescale_func: typ.Callable[ [spux.ScalarUnitlessComplexExpr], spux.ScalarUnitlessComplexExpr ], reverse: bool = False, @@ -570,18 +562,12 @@ class LazyArrayRangeFlow: A rescaled `LazyArrayRangeFlow`. """ return LazyArrayRangeFlow( - start=spu.convert_to( - scaler(self.start if not reverse else self.stop), self.unit - ), - stop=spu.convert_to( - scaler(self.stop if not reverse else self.start), self.unit - ), + start=rescale_func(self.start if not reverse else self.stop), + stop=rescale_func(self.stop if not reverse else self.start), steps=self.steps, scaling=self.scaling, unit=self.unit, - int_symbols=self.int_symbols, - real_symbols=self.real_symbols, - complex_symbols=self.complex_symbols, + symbols=self.symbols, ) #################### @@ -650,9 +636,7 @@ class LazyArrayRangeFlow: """ return LazyValueFuncFlow( func=self.as_func, - func_args=[ - (sym.name, spux.sympy_to_python_type(sym)) for sym in self.symbols - ], + func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols], supports_jax=True, ) @@ -709,13 +693,38 @@ class LazyArrayRangeFlow: #################### @dataclasses.dataclass(frozen=True, kw_only=True) class ParamsFlow: - func_args: list[typ.Any] = dataclasses.field(default_factory=list) - func_kwargs: dict[str, typ.Any] = dataclasses.field(default_factory=dict) + func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list) + func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict) + #################### + # - Scaled Func Args + #################### + def scaled_func_args(self, unit_system: spux.UnitSystem): + """Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments.""" + return [ + spux.convert_to_unit_system(func_arg, unit_system, use_jax_array=True) + for func_arg in self.func_args + ] + + def scaled_func_kwargs(self, unit_system: spux.UnitSystem): + """Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments.""" + return { + arg_name: spux.convert_to_unit_system(arg, unit_system, use_jax_array=True) + for arg_name, arg in self.func_args + } + + #################### + # - Operations + #################### def __or__( self, other: typ.Self, ): + """Combine two function parameter lists, such that the LHS will be concatenated with the RHS. + + Just like its neighbor in `LazyValueFunc`, this effectively combines two functions with unique parameters. + The next composed function will receive a tuple of two arrays, instead of just one, allowing binary operations to occur. + """ return ParamsFlow( func_args=self.func_args + other.func_args, func_kwargs=self.func_kwargs | other.func_kwargs, @@ -723,10 +732,8 @@ class ParamsFlow: def compose_within( self, - enclosing_func_args: list[tuple[type]] = (), - enclosing_func_kwargs: dict[str, type] = MappingProxyType({}), - enclosing_func_arg_units: dict[str, type] = MappingProxyType({}), - enclosing_func_kwarg_units: dict[str, type] = MappingProxyType({}), + enclosing_func_args: list[spux.SympyExpr] = (), + enclosing_func_kwargs: dict[str, spux.SympyExpr] = MappingProxyType({}), ) -> typ.Self: return ParamsFlow( func_args=self.func_args + list(enclosing_func_args), @@ -745,6 +752,8 @@ class InfoFlow: default_factory=dict ) ## TODO: Rename to dim_idxs + ## TODO: Add PhysicalType + @functools.cached_property def dim_lens(self) -> dict[str, int]: return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()} @@ -769,12 +778,14 @@ class InfoFlow: ] # Output Information + ## TODO: Add PhysicalType output_name: str = dataclasses.field(default_factory=list) output_shape: tuple[int, ...] | None = dataclasses.field(default=None) output_mathtype: spux.MathType = dataclasses.field() output_unit: spux.Unit | None = dataclasses.field() # Pinned Dimension Information + ## TODO: Add PhysicalType pinned_dim_names: list[str] = dataclasses.field(default_factory=list) pinned_dim_values: dict[str, float | complex] = dataclasses.field( default_factory=dict diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_colors.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_colors.py index 9d2f0bb..5fda2c9 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_colors.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_colors.py @@ -4,42 +4,13 @@ from .socket_types import SocketType as ST SOCKET_COLORS = { # Basic ST.Any: (0.9, 0.9, 0.9, 1.0), # Light Grey - ST.Data: (0.8, 0.8, 0.8, 1.0), # Light Grey ST.Bool: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey ST.String: (0.7, 0.7, 0.7, 1.0), # Medium Light Grey ST.FilePath: (0.6, 0.6, 0.6, 1.0), # Medium Grey ST.Expr: (0.5, 0.5, 0.5, 1.0), # Medium Grey - # Number - ST.IntegerNumber: (0.5, 0.5, 1.0, 1.0), # Light Blue - ST.RationalNumber: (0.4, 0.4, 0.9, 1.0), # Medium Light Blue - ST.RealNumber: (0.3, 0.3, 0.8, 1.0), # Medium Blue - ST.ComplexNumber: (0.2, 0.2, 0.7, 1.0), # Dark Blue - # Vector - ST.Integer2DVector: (0.5, 1.0, 0.5, 1.0), # Light Green - ST.Real2DVector: (0.5, 1.0, 0.5, 1.0), # Light Green - ST.Complex2DVector: (0.4, 0.9, 0.4, 1.0), # Medium Light Green - ST.Integer3DVector: (0.3, 0.8, 0.3, 1.0), # Medium Green - ST.Real3DVector: (0.3, 0.8, 0.3, 1.0), # Medium Green - ST.Complex3DVector: (0.2, 0.7, 0.2, 1.0), # Dark Green # Physical ST.PhysicalUnitSystem: (1.0, 0.5, 0.5, 1.0), # Light Red - ST.PhysicalTime: (1.0, 0.5, 0.5, 1.0), # Light Red - ST.PhysicalAngle: (0.9, 0.45, 0.45, 1.0), # Medium Light Red - ST.PhysicalLength: (0.8, 0.4, 0.4, 1.0), # Medium Red - ST.PhysicalArea: (0.7, 0.35, 0.35, 1.0), # Medium Dark Red - ST.PhysicalVolume: (0.6, 0.3, 0.3, 1.0), # Dark Red - ST.PhysicalPoint2D: (0.7, 0.35, 0.35, 1.0), # Medium Dark Red - ST.PhysicalPoint3D: (0.6, 0.3, 0.3, 1.0), # Dark Red - ST.PhysicalSize2D: (0.7, 0.35, 0.35, 1.0), # Medium Dark Red - ST.PhysicalSize3D: (0.6, 0.3, 0.3, 1.0), # Dark Red - ST.PhysicalMass: (0.9, 0.6, 0.4, 1.0), # Light Orange - ST.PhysicalSpeed: (0.8, 0.55, 0.35, 1.0), # Medium Light Orange - ST.PhysicalAccelScalar: (0.7, 0.5, 0.3, 1.0), # Medium Orange - ST.PhysicalForceScalar: (0.6, 0.45, 0.25, 1.0), # Medium Dark Orange - ST.PhysicalAccel3D: (0.7, 0.5, 0.3, 1.0), # Medium Orange - ST.PhysicalForce3D: (0.6, 0.45, 0.25, 1.0), # Medium Dark Orange ST.PhysicalPol: (0.5, 0.4, 0.2, 1.0), # Dark Orange - ST.PhysicalFreq: (1.0, 0.7, 0.5, 1.0), # Light Peach # Blender ST.BlenderMaterial: (0.8, 0.6, 1.0, 1.0), # Lighter Purple ST.BlenderObject: (0.7, 0.5, 1.0, 1.0), # Light Purple @@ -56,6 +27,7 @@ SOCKET_COLORS = { ST.MaxwellBoundConds: (0.9, 0.8, 0.5, 1.0), # Light Gold ST.MaxwellBoundCond: (0.8, 0.7, 0.45, 1.0), # Medium Light Gold ST.MaxwellMonitor: (0.7, 0.6, 0.4, 1.0), # Medium Gold + ST.MaxwellMonitorData: (0.7, 0.6, 0.4, 1.0), # Medium Gold ST.MaxwellFDTDSim: (0.6, 0.5, 0.35, 1.0), # Medium Dark Gold ST.MaxwellFDTDSimData: (0.6, 0.5, 0.35, 1.0), # Medium Dark Gold ST.MaxwellSimGrid: (0.5, 0.4, 0.3, 1.0), # Dark Gold diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_types.py index 46d42fa..58841f2 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_types.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_types.py @@ -5,31 +5,14 @@ from blender_maxwell.utils import blender_type_enum @blender_type_enum.append_cls_name_to_values class SocketType(blender_type_enum.BlenderTypeEnum): + Expr = enum.auto() + # Base Any = enum.auto() - Data = enum.auto() Bool = enum.auto() String = enum.auto() FilePath = enum.auto() Color = enum.auto() - Expr = enum.auto() - - # Number - IntegerNumber = enum.auto() - RationalNumber = enum.auto() - RealNumber = enum.auto() - ComplexNumber = enum.auto() - - # Vector - Integer2DVector = enum.auto() - Real2DVector = enum.auto() - Real2DVectorDir = enum.auto() - Complex2DVector = enum.auto() - - Integer3DVector = enum.auto() - Real3DVector = enum.auto() - Real3DVectorDir = enum.auto() - Complex3DVector = enum.auto() # Blender BlenderMaterial = enum.auto() @@ -53,6 +36,7 @@ class SocketType(blender_type_enum.BlenderTypeEnum): MaxwellStructure = enum.auto() MaxwellMonitor = enum.auto() + MaxwellMonitorData = enum.auto() MaxwellFDTDSim = enum.auto() MaxwellFDTDSimData = enum.auto() @@ -66,76 +50,5 @@ class SocketType(blender_type_enum.BlenderTypeEnum): # Physical PhysicalUnitSystem = enum.auto() - PhysicalTime = enum.auto() - - PhysicalAngle = enum.auto() - PhysicalSolidAngle = enum.auto() - - PhysicalRot2D = enum.auto() - PhysicalRot3D = enum.auto() - - PhysicalFreq = enum.auto() - PhysicalAngFreq = enum.auto() - - ## Cartesian - PhysicalLength = enum.auto() - PhysicalArea = enum.auto() - PhysicalVolume = enum.auto() - - PhysicalDisp2D = enum.auto() - PhysicalDisp3D = enum.auto() - - PhysicalPoint1D = enum.auto() - PhysicalPoint2D = enum.auto() - PhysicalPoint3D = enum.auto() - - PhysicalSize2D = enum.auto() - PhysicalSize3D = enum.auto() - - ## Mechanical - PhysicalMass = enum.auto() - - PhysicalSpeed = enum.auto() - PhysicalVel2D = enum.auto() - PhysicalVel3D = enum.auto() - PhysicalAccelScalar = enum.auto() - PhysicalAccel2D = enum.auto() - PhysicalAccel3D = enum.auto() - PhysicalForceScalar = enum.auto() - PhysicalForce2D = enum.auto() - PhysicalForce3D = enum.auto() - PhysicalPressure = enum.auto() - - ## Energetic - PhysicalEnergy = enum.auto() - PhysicalPower = enum.auto() - PhysicalTemp = enum.auto() - - ## Electrodynamical - PhysicalCurr = enum.auto() - PhysicalCurrDens2D = enum.auto() - PhysicalCurrDens3D = enum.auto() - - PhysicalCharge = enum.auto() - PhysicalVoltage = enum.auto() - PhysicalCapacitance = enum.auto() - PhysicalResistance = enum.auto() - PhysicalConductance = enum.auto() - - PhysicalMagFlux = enum.auto() - PhysicalMagFluxDens = enum.auto() - PhysicalInductance = enum.auto() - - PhysicalEField2D = enum.auto() - PhysicalEField3D = enum.auto() - PhysicalHField2D = enum.auto() - PhysicalHField3D = enum.auto() - - ## Luminal - PhysicalLumIntensity = enum.auto() - PhysicalLumFlux = enum.auto() - PhysicalIlluminance = enum.auto() - ## Optical - PhysicalPolJones = enum.auto() PhysicalPol = enum.auto() diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_units.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_units.py deleted file mode 100644 index 0b7c16c..0000000 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/socket_units.py +++ /dev/null @@ -1,287 +0,0 @@ -import sympy.physics.units as spu - -from blender_maxwell.utils import extra_sympy_units as spux - -from .socket_types import SocketType as ST # noqa: N817 - -SOCKET_UNITS = { - ST.PhysicalTime: { - 'default': 'PS', - 'values': { - 'FS': spux.femtosecond, - 'PS': spu.picosecond, - 'NS': spu.nanosecond, - 'MS': spu.microsecond, - 'MLSEC': spu.millisecond, - 'SEC': spu.second, - 'MIN': spu.minute, - 'HOUR': spu.hour, - 'DAY': spu.day, - }, - }, - ST.PhysicalAngle: { - 'default': 'RADIAN', - 'values': { - 'RADIAN': spu.radian, - 'DEGREE': spu.degree, - 'STERAD': spu.steradian, - 'ANGMIL': spu.angular_mil, - }, - }, - ST.PhysicalLength: { - 'default': 'UM', - 'values': { - 'PM': spu.picometer, - 'A': spu.angstrom, - 'NM': spu.nanometer, - 'UM': spu.micrometer, - 'MM': spu.millimeter, - 'CM': spu.centimeter, - 'M': spu.meter, - 'INCH': spu.inch, - 'FOOT': spu.foot, - 'YARD': spu.yard, - 'MILE': spu.mile, - }, - }, - ST.PhysicalArea: { - 'default': 'UM_SQ', - 'values': { - 'PM_SQ': spu.picometer**2, - 'A_SQ': spu.angstrom**2, - 'NM_SQ': spu.nanometer**2, - 'UM_SQ': spu.micrometer**2, - 'MM_SQ': spu.millimeter**2, - 'CM_SQ': spu.centimeter**2, - 'M_SQ': spu.meter**2, - 'INCH_SQ': spu.inch**2, - 'FOOT_SQ': spu.foot**2, - 'YARD_SQ': spu.yard**2, - 'MILE_SQ': spu.mile**2, - }, - }, - ST.PhysicalVolume: { - 'default': 'UM_CB', - 'values': { - 'PM_CB': spu.picometer**3, - 'A_CB': spu.angstrom**3, - 'NM_CB': spu.nanometer**3, - 'UM_CB': spu.micrometer**3, - 'MM_CB': spu.millimeter**3, - 'CM_CB': spu.centimeter**3, - 'M_CB': spu.meter**3, - 'ML': spu.milliliter, - 'L': spu.liter, - 'INCH_CB': spu.inch**3, - 'FOOT_CB': spu.foot**3, - 'YARD_CB': spu.yard**3, - 'MILE_CB': spu.mile**3, - }, - }, - ST.PhysicalPoint2D: { - 'default': 'UM', - 'values': { - 'PM': spu.picometer, - 'A': spu.angstrom, - 'NM': spu.nanometer, - 'UM': spu.micrometer, - 'MM': spu.millimeter, - 'CM': spu.centimeter, - 'M': spu.meter, - 'INCH': spu.inch, - 'FOOT': spu.foot, - 'YARD': spu.yard, - 'MILE': spu.mile, - }, - }, - ST.PhysicalPoint3D: { - 'default': 'UM', - 'values': { - 'PM': spu.picometer, - 'A': spu.angstrom, - 'NM': spu.nanometer, - 'UM': spu.micrometer, - 'MM': spu.millimeter, - 'CM': spu.centimeter, - 'M': spu.meter, - 'INCH': spu.inch, - 'FOOT': spu.foot, - 'YARD': spu.yard, - 'MILE': spu.mile, - }, - }, - ST.PhysicalSize2D: { - 'default': 'UM', - 'values': { - 'PM': spu.picometer, - 'A': spu.angstrom, - 'NM': spu.nanometer, - 'UM': spu.micrometer, - 'MM': spu.millimeter, - 'CM': spu.centimeter, - 'M': spu.meter, - 'INCH': spu.inch, - 'FOOT': spu.foot, - 'YARD': spu.yard, - 'MILE': spu.mile, - }, - }, - ST.PhysicalSize3D: { - 'default': 'UM', - 'values': { - 'PM': spu.picometer, - 'A': spu.angstrom, - 'NM': spu.nanometer, - 'UM': spu.micrometer, - 'MM': spu.millimeter, - 'CM': spu.centimeter, - 'M': spu.meter, - 'INCH': spu.inch, - 'FOOT': spu.foot, - 'YARD': spu.yard, - 'MILE': spu.mile, - }, - }, - ST.PhysicalMass: { - 'default': 'UG', - 'values': { - 'E_REST': spu.electron_rest_mass, - 'DAL': spu.dalton, - 'UG': spu.microgram, - 'MG': spu.milligram, - 'G': spu.gram, - 'KG': spu.kilogram, - 'TON': spu.metric_ton, - }, - }, - ST.PhysicalSpeed: { - 'default': 'UM_S', - 'values': { - 'PM_S': spu.picometer / spu.second, - 'NM_S': spu.nanometer / spu.second, - 'UM_S': spu.micrometer / spu.second, - 'MM_S': spu.millimeter / spu.second, - 'M_S': spu.meter / spu.second, - 'KM_S': spu.kilometer / spu.second, - 'KM_H': spu.kilometer / spu.hour, - 'FT_S': spu.feet / spu.second, - 'MI_H': spu.mile / spu.hour, - }, - }, - ST.PhysicalAccelScalar: { - 'default': 'UM_S_SQ', - 'values': { - 'PM_S_SQ': spu.picometer / spu.second**2, - 'NM_S_SQ': spu.nanometer / spu.second**2, - 'UM_S_SQ': spu.micrometer / spu.second**2, - 'MM_S_SQ': spu.millimeter / spu.second**2, - 'M_S_SQ': spu.meter / spu.second**2, - 'KM_S_SQ': spu.kilometer / spu.second**2, - 'FT_S_SQ': spu.feet / spu.second**2, - }, - }, - ST.PhysicalForceScalar: { - 'default': 'UNEWT', - 'values': { - 'KG_M_S_SQ': spu.kg * spu.m / spu.second**2, - 'NNEWT': spux.nanonewton, - 'UNEWT': spux.micronewton, - 'MNEWT': spux.millinewton, - 'NEWT': spu.newton, - }, - }, - ST.PhysicalAccel3D: { - 'default': 'UM_S_SQ', - 'values': { - 'PM_S_SQ': spu.picometer / spu.second**2, - 'NM_S_SQ': spu.nanometer / spu.second**2, - 'UM_S_SQ': spu.micrometer / spu.second**2, - 'MM_S_SQ': spu.millimeter / spu.second**2, - 'M_S_SQ': spu.meter / spu.second**2, - 'KM_S_SQ': spu.kilometer / spu.second**2, - 'FT_S_SQ': spu.feet / spu.second**2, - }, - }, - ST.PhysicalForce3D: { - 'default': 'UNEWT', - 'values': { - 'KG_M_S_SQ': spu.kg * spu.m / spu.second**2, - 'NNEWT': spux.nanonewton, - 'UNEWT': spux.micronewton, - 'MNEWT': spux.millinewton, - 'NEWT': spu.newton, - }, - }, - ST.PhysicalFreq: { - 'default': 'THZ', - 'values': { - 'HZ': spu.hertz, - 'KHZ': spux.kilohertz, - 'MHZ': spux.megahertz, - 'GHZ': spux.gigahertz, - 'THZ': spux.terahertz, - 'PHZ': spux.petahertz, - 'EHZ': spux.exahertz, - }, - }, - ST.PhysicalPol: { - 'default': 'RADIAN', - 'values': { - 'RADIAN': spu.radian, - 'DEGREE': spu.degree, - 'STERAD': spu.steradian, - 'ANGMIL': spu.angular_mil, - }, - }, - ST.MaxwellMedium: { - 'default': 'NM', - 'values': { - 'PM': spu.picometer, ## c(vac) = wl*freq - 'A': spu.angstrom, - 'NM': spu.nanometer, - 'UM': spu.micrometer, - 'MM': spu.millimeter, - 'CM': spu.centimeter, - 'M': spu.meter, - }, - }, - ST.MaxwellMonitor: { - 'default': 'NM', - 'values': { - 'PM': spu.picometer, ## c(vac) = wl*freq - 'A': spu.angstrom, - 'NM': spu.nanometer, - 'UM': spu.micrometer, - 'MM': spu.millimeter, - 'CM': spu.centimeter, - 'M': spu.meter, - }, - }, -} - - -def unit_to_socket_type( - unit: spux.Unit | None, fallback_mathtype: spux.MathType | None = None -) -> ST: - """Returns a SocketType that accepts the given unit. - - Only the unit-compatibility is taken into account; in the case of overlap, several the ordering of `SOCKET_UNITS` determines which is returned. - This isn't super clean, but it's good enough for our needs right now. - - Returns: - **The first `SocketType` in `SOCKET_UNITS`, which contains the given unit as a valid possibility. - """ - if unit is None and fallback_mathtype is not None: - return { - spux.MathType.Integer: ST.IntegerNumber, - spux.MathType.Rational: ST.RationalNumber, - spux.MathType.Real: ST.RealNumber, - spux.MathType.Complex: ST.ComplexNumber, - }[fallback_mathtype] - - for socket_type, _units in SOCKET_UNITS.items(): - if unit in _units['values'].values(): - return socket_type - - msg = f"Unit {unit} doesn't have an obvious SocketType." - raise ValueError(msg) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py index 633f574..78f9aa2 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py @@ -1,59 +1,69 @@ +"""Specifies unit systems for use in the node tree. + +Attributes: + UNITS_BLENDER: A unit system that serves as a reasonable default for the a 3D workspace that interprets the results of electromagnetic simulations. + **NOTE**: The addon _specifically_ neglects to respect Blender's builtin units. + In testing, Blender's system was found to be extremely brittle when "going small" like this; in particular, `picosecond`-order time units were impossible to specify functionally. + UNITS_TIDY3D: A unit system that aligns with Tidy3D's simulator. + See +""" + import typing as typ import sympy.physics.units as spu from blender_maxwell.utils import extra_sympy_units as spux -from blender_maxwell.utils.pydantic_sympy import SympyExpr -from .socket_types import SocketType as ST # noqa: N817 -from .socket_units import SOCKET_UNITS - - -def _socket_units(socket_type): - return SOCKET_UNITS[socket_type]['values'] - - -UnitSystem: typ.TypeAlias = dict[ST, SympyExpr] #################### # - Unit Systems #################### -UNITS_BLENDER: UnitSystem = { - ST.PhysicalTime: spu.picosecond, - ST.PhysicalAngle: spu.radian, - ST.PhysicalLength: spu.micrometer, - ST.PhysicalArea: spu.micrometer**2, - ST.PhysicalVolume: spu.micrometer**3, - ST.PhysicalPoint2D: spu.micrometer, - ST.PhysicalPoint3D: spu.micrometer, - ST.PhysicalSize2D: spu.micrometer, - ST.PhysicalSize3D: spu.micrometer, - ST.PhysicalMass: spu.microgram, - ST.PhysicalSpeed: spu.um / spu.second, - ST.PhysicalAccelScalar: spu.um / spu.second**2, - ST.PhysicalForceScalar: spux.micronewton, - ST.PhysicalAccel3D: spu.um / spu.second**2, - ST.PhysicalForce3D: spux.micronewton, - ST.PhysicalFreq: spux.terahertz, - ST.PhysicalPol: spu.radian, +_PT: typ.TypeAlias = spux.PhysicalType +UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | { + # Global + _PT.Time: spu.picosecond, + _PT.Freq: spux.terahertz, + _PT.AngFreq: spu.radian * spux.terahertz, + # Cartesian + _PT.Length: spu.micrometer, + _PT.Area: spu.micrometer**2, + _PT.Volume: spu.micrometer**3, + # Energy + _PT.PowerFlux: spu.watt / spu.um**2, + # Electrodynamics + _PT.CurrentDensity: spu.ampere / spu.um**2, + _PT.Conductivity: spu.siemens / spu.um, + _PT.PoyntingVector: spu.watt / spu.um**2, + _PT.EField: spu.volt / spu.um, + _PT.HField: spu.ampere / spu.um, + # Mechanical + _PT.Vel: spu.um / spu.second, + _PT.Accel: spu.um / spu.second, + _PT.Mass: spu.microgram, + _PT.Force: spux.micronewton, + # Luminal + # Optics + _PT.PoyntingVector: spu.watt / spu.um**2, } ## TODO: Load (dynamically?) from addon preferences -UNITS_TIDY3D: UnitSystem = { - ## https://docs.flexcompute.com/projects/tidy3d/en/latest/faq/docs/faq/What-are-the-units-used-in-the-simulation.html - ST.PhysicalTime: spu.second, - ST.PhysicalAngle: spu.radian, - ST.PhysicalLength: spu.micrometer, - ST.PhysicalArea: spu.micrometer**2, - ST.PhysicalVolume: spu.micrometer**3, - ST.PhysicalPoint2D: spu.micrometer, - ST.PhysicalPoint3D: spu.micrometer, - ST.PhysicalSize2D: spu.micrometer, - ST.PhysicalSize3D: spu.micrometer, - ST.PhysicalMass: spu.microgram, - ST.PhysicalSpeed: spu.um / spu.second, - ST.PhysicalAccelScalar: spu.um / spu.second**2, - ST.PhysicalForceScalar: spux.micronewton, - ST.PhysicalAccel3D: spu.um / spu.second**2, - ST.PhysicalForce3D: spux.micronewton, - ST.PhysicalFreq: spu.hertz, - ST.PhysicalPol: spu.radian, +UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | { + # Global + # Cartesian + _PT.Length: spu.um, + _PT.Area: spu.um**2, + _PT.Volume: spu.um**3, + # Mechanical + _PT.Vel: spu.um / spu.second, + _PT.Accel: spu.um / spu.second, + # Energy + _PT.PowerFlux: spu.watt / spu.um**2, + # Electrodynamics + _PT.CurrentDensity: spu.ampere / spu.um**2, + _PT.Conductivity: spu.siemens / spu.um, + _PT.PoyntingVector: spu.watt / spu.um**2, + _PT.EField: spu.volt / spu.um, + _PT.HField: spu.ampere / spu.um, + # Luminal + # Optics + _PT.PoyntingVector: spu.watt / spu.um**2, + ## NOTE: w/o source normalization, EField/HField/Modal amps are * 1/Hz } diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_mesh.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_mesh.py index 4ef00eb..d561553 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_mesh.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_mesh.py @@ -101,6 +101,10 @@ class ManagedBLMesh(base.ManagedObj): #################### # - Methods #################### + @property + def exists(self) -> bool: + return bpy.data.objects.get(self.name) is not None + def show_preview(self) -> None: """Moves the managed Blender object to the preview collection. diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py index 52c0332..816d5d9 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py @@ -4,7 +4,8 @@ import typing as typ import bpy -from blender_maxwell.utils import analyze_geonodes, logger +from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import logger from .. import bl_socket_map from .. import contracts as ct @@ -25,13 +26,12 @@ class ModifierAttrsNODES(typ.TypedDict): node_group: The GeoNodes group to use in the modifier. unit_system: The unit system used by the GeoNodes output. Generally, `ct.UNITS_BLENDER` is a good choice. - inputs: Values to associate with each GeoNodes interface socket. - Use `analyze_geonodes.interface(..., direc='INPUT')` to determine acceptable values. + inputs: Values to associate with each GeoNodes interface socket name. """ node_group: bpy.types.GeometryNodeTree unit_system: UnitSystem - inputs: dict[ct.BLNodeTreeInterfaceID, typ.Any] + inputs: dict[ct.SocketName, typ.Any] class ModifierAttrsARRAY(typ.TypedDict): @@ -47,7 +47,7 @@ MODIFIER_NAMES = { #################### -# - Read Modifier Information +# - Read Modifier #################### def read_modifier(bl_modifier: bpy.types.Modifier) -> ModifierAttrs: if bl_modifier.type == 'NODES': @@ -62,7 +62,7 @@ def read_modifier(bl_modifier: bpy.types.Modifier) -> ModifierAttrs: #################### -# - Write Modifier Information +# - Write Modifier: GeoNodes #################### def write_modifier_geonodes( bl_modifier: bpy.types.NodesModifier, @@ -78,6 +78,7 @@ def write_modifier_geonodes( True if the modifier was altered. """ modifier_altered = False + # Alter GeoNodes Group if bl_modifier.node_group != modifier_attrs['node_group']: log.info( @@ -89,53 +90,22 @@ def write_modifier_geonodes( modifier_altered = True # Alter GeoNodes Modifier Inputs - ## First we retrieve the interface items by-Socket Name - geonodes_interface = analyze_geonodes.interface( - bl_modifier.node_group, direc='INPUT' - ) - for ( - socket_name, - value, - ) in modifier_attrs['inputs'].items(): - # Compute Writable BL Socket Value - ## Analyzes the socket and unitsys to prep a ready-to-write value. - ## Write directly to the modifier dict. - bl_socket_value = bl_socket_map.writable_bl_socket_value( - geonodes_interface[socket_name], - value, - unit_system=modifier_attrs['unit_system'], - allow_unit_not_in_unit_system=True, + socket_infos = bl_socket_map.info_from_geonodes(bl_modifier.node_group) + + for socket_name in modifier_attrs['inputs']: + iface_id = socket_infos[socket_name].bl_isocket_identifier + bl_modifier[iface_id] = spux.scale_to_unit_system( + modifier_attrs['inputs'][socket_name], modifier_attrs['unit_system'] ) + modifier_altered = True + ## TODO: More fine-grained alterations - # Compute Interface ID from Socket Name - ## We can't index the modifier by socket name; only by Interface ID. - ## Still, we require that socket names are unique. - iface_id = geonodes_interface[socket_name].identifier - - # IF List-Like: Alter Differing Elements - if isinstance(bl_socket_value, tuple): - for i, bl_socket_subvalue in enumerate(bl_socket_value): - if bl_modifier[iface_id][i] != bl_socket_subvalue: - bl_modifier[iface_id][i] = bl_socket_subvalue - modifier_altered = True - - # IF int/float Mismatch: Assign Float-Cast of Integer - ## Blender is strict; only floats can set float vals. - ## We are less strict; if the user passes an int, that's okay. - elif isinstance(bl_socket_value, int) and isinstance( - bl_modifier[iface_id], - float, - ): - bl_modifier[iface_id] = float(bl_socket_value) - modifier_altered = True - else: - ## TODO: Whitelist what can be here. I'm done with the TypeErrors. - bl_modifier[iface_id] = bl_socket_value - modifier_altered = True - - return modifier_altered + return modifier_altered # noqa: RET504 +#################### +# - Write Modifier +#################### def write_modifier( bl_modifier: bpy.types.Modifier, modifier_attrs: ModifierAttrs, @@ -184,8 +154,11 @@ class ManagedBLModifier(base.ManagedObj): def __init__(self, name: str): self.name = name - def bl_select(self) -> None: pass - def hide_preview(self) -> None: pass + def bl_select(self) -> None: + pass + + def hide_preview(self) -> None: + pass #################### # - Deallocation @@ -255,7 +228,8 @@ class ManagedBLModifier(base.ManagedObj): type=modifier_type, ) - if modifier_altered := write_modifier(bl_modifier, modifier_attrs): + modifier_altered = write_modifier(bl_modifier, modifier_attrs) + if modifier_altered: bl_object.data.update() return bl_modifier diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py index e60f48d..b596953 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py @@ -41,18 +41,17 @@ class ExtractDataNode(base.MaxwellSimNode): input_socket_sets: typ.ClassVar = { 'Sim Data': {'Sim Data': sockets.MaxwellFDTDSimDataSocketDef()}, - 'Monitor Data': {'Monitor Data': sockets.DataSocketDef(format='monitor_data')}, + 'Monitor Data': {'Monitor Data': sockets.MaxwellMonitorDataSocketDef()}, } output_socket_sets: typ.ClassVar = { - 'Sim Data': {'Monitor Data': sockets.DataSocketDef(format='monitor_data')}, - 'Monitor Data': {'Data': sockets.DataSocketDef(format='jax')}, + 'Sim Data': {'Monitor Data': sockets.MaxwellMonitorDataSocketDef()}, + 'Monitor Data': {'Expr': sockets.ExprSocketDef()}, } #################### # - Properties #################### extract_filter: enum.Enum = bl_cache.BLField( - None, prop_ui=True, enum_cb=lambda self, _: self.search_extract_filters(), ) @@ -62,7 +61,7 @@ class ExtractDataNode(base.MaxwellSimNode): #################### @property def sim_data(self) -> td.SimulationData | None: - """Computes the (cached) simulation data from the input socket. + """Extracts the simulation data from the input socket. Return: Either the simulation data, if available, or None. @@ -70,14 +69,15 @@ class ExtractDataNode(base.MaxwellSimNode): sim_data = self._compute_input( 'Sim Data', kind=ct.FlowKind.Value, optional=True ) - if not ct.FlowSignal.check(sim_data): + has_sim_data = not ct.FlowSignal.check(sim_data) + if has_sim_data: return sim_data return None @bl_cache.cached_bl_property() def sim_data_monitor_nametype(self) -> dict[str, str] | None: - """For simulation data, computes and and caches a map from name to "type". + """For simulation data, deduces a map from the monitor name to the monitor "type". Return: The name to type of monitors in the simulation data. @@ -95,7 +95,7 @@ class ExtractDataNode(base.MaxwellSimNode): #################### @property def monitor_data(self) -> TDMonitorData | None: - """Computes the (cached) monitor data from the input socket. + """Extracts the monitor data from the input socket. Return: Either the monitor data, if available, or None. @@ -103,17 +103,26 @@ class ExtractDataNode(base.MaxwellSimNode): monitor_data = self._compute_input( 'Monitor Data', kind=ct.FlowKind.Value, optional=True ) - if not ct.FlowSignal.check(monitor_data): + has_monitor_data = not ct.FlowSignal.check(monitor_data) + if has_monitor_data: return monitor_data return None @bl_cache.cached_bl_property() def monitor_data_type(self) -> str | None: - """For monitor data, computes and caches the monitor "type". + r"""For monitor data, deduces the monitor "type". + + - **Field(Time)**: A monitor storing values/pixels/voxels with electromagnetic field values, on the time or frequency domain. + - **Permittivity**: A monitor storing values/pixels/voxels containing the diagonal of the relative permittivity tensor. + - **Flux(Time)**: A monitor storing the directional flux on the time or frequency domain. + For planes, an explicit direction is defined. + For volumes, the the integral of all outgoing energy is stored. + - **FieldProjection(...)**: A monitor storing the spherical-coordinate electromagnetic field components of a near-to-far-field projection. + - **Diffraction**: A monitor storing a near-to-far-field projection by diffraction order. Notes: - Should be invalidated with (before) `self.monitor_data_components`. + Should be invalidated with (before) `self.monitor_data_attrs`. Return: The "type" of the monitor, if available, else None. @@ -124,10 +133,10 @@ class ExtractDataNode(base.MaxwellSimNode): return None @bl_cache.cached_bl_property() - def monitor_data_components(self) -> list[str] | None: - r"""For monitor data, computes and caches the component sof the monitor. + def monitor_data_attrs(self) -> list[str] | None: + r"""For monitor data, deduces the valid data-containing attributes. - The output depends entirely on the output of `self.monitor_data`. + The output depends entirely on the output of `self.monitor_data_type`, since the valid attributes of each monitor type is well-defined without needing to perform dynamic lookups. - **Field(Time)**: Whichever `[E|H][x|y|z]` are not `None` on the monitor. - **Permittivity**: Specifically `['xx', 'yy', 'zz']`. @@ -183,7 +192,7 @@ class ExtractDataNode(base.MaxwellSimNode): """Compute valid values for `self.extract_filter`, for a dynamic `EnumProperty`. Notes: - Should be reset (via `self.extract_filter`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_components`, and (implicitly) `self.monitor_type`. + Should be reset (via `self.extract_filter`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`. See `bl_cache.BLField` for more on dynamic `EnumProperty`. @@ -198,26 +207,71 @@ class ExtractDataNode(base.MaxwellSimNode): ) ] - if self.monitor_data_components is not None: - return [ - ( - component_name, - component_name, - f'ℂ {component_name[1]}-polarization of the {"electric" if component_name[0] == "E" else "magnetic"} field', - '', - i, - ) - for i, component_name in enumerate(self.monitor_data_components) - ] + if self.monitor_data_attrs is not None: + # Field/FieldTime + if self.monitor_data_type in ['Field', 'FieldTime']: + return [ + ( + monitor_attr, + monitor_attr, + f'ℂ {monitor_attr[1]}-polarization of the {"electric" if monitor_attr[0] == "E" else "magnetic"} field', + '', + i, + ) + for i, monitor_attr in enumerate(self.monitor_data_attrs) + ] + + # Permittivity + if self.monitor_data_type == 'Permittivity': + return [ + (monitor_attr, monitor_attr, f'ℂ ε_{monitor_attr}', '', i) + for i, monitor_attr in enumerate(self.monitor_data_attrs) + ] + + # Flux/FluxTime + if self.monitor_data_type in ['Flux', 'FluxTime']: + return [ + ( + monitor_attr, + monitor_attr, + 'Power flux integral through the plane / out of the volume', + '', + i, + ) + for i, monitor_attr in enumerate(self.monitor_data_attrs) + ] + + # FieldProjection(Angle/Cartesian/KSpace)/Diffraction + if self.monitor_data_type in [ + 'FieldProjectionAngle', + 'FieldProjectionCartesian', + 'FieldProjectionKSpace', + 'Diffraction', + ]: + return [ + ( + monitor_attr, + monitor_attr, + f'ℂ {monitor_attr[1]}-component of the spherical {"electric" if monitor_attr[0] == "E" else "magnetic"} field', + '', + i, + ) + for i, monitor_attr in enumerate(self.monitor_data_attrs) + ] return [] #################### # - UI #################### - def draw_label(self): + def draw_label(self) -> None: + """Show the extracted data (if any) in the node's header label. + + Notes: + Called by Blender to determine the text to place in the node's header. + """ has_sim_data = self.sim_data_monitor_nametype is not None - has_monitor_data = self.monitor_data_components is not None + has_monitor_data = self.monitor_data_attrs is not None if has_sim_data or has_monitor_data: return f'Extract: {self.extract_filter}' @@ -245,11 +299,11 @@ class ExtractDataNode(base.MaxwellSimNode): """Invalidate the cached properties for sim data / monitor data, and reset the extraction filter.""" self.sim_data_monitor_nametype = bl_cache.Signal.InvalidateCache self.monitor_data_type = bl_cache.Signal.InvalidateCache - self.monitor_data_components = bl_cache.Signal.InvalidateCache + self.monitor_data_attrs = bl_cache.Signal.InvalidateCache self.extract_filter = bl_cache.Signal.ResetEnumItems #################### - # - Output: Sim Data -> Monitor Data + # - Output (Value): Sim Data -> Monitor Data #################### @events.computes_output_socket( # Trigger @@ -262,96 +316,84 @@ class ExtractDataNode(base.MaxwellSimNode): def compute_monitor_data( self, props: dict, input_sockets: dict ) -> TDMonitorData | ct.FlowSignal: - """Compute `Monitor Data` by querying an attribute of `Sim Data`. - - Notes: - The attribute to query is read directly from `self.extract_filter`. - This is also the mechanism that protects from trying to reference an invalid attribute. + """Compute `Monitor Data` by querying the attribute of `Sim Data` referenced by the property `self.extract_filter`. Returns: Monitor data, if available, else `ct.FlowSignal.FlowPending`. """ + extract_filter = props['extract_filter'] sim_data = input_sockets['Sim Data'] has_sim_data = not ct.FlowSignal.check(sim_data) - if has_sim_data and props['extract_filter'] != 'NONE': - return input_sockets['Sim Data'].monitor_data[props['extract_filter']] - - # Propagate NoFlow - if ct.FlowSignal.check_single(sim_data, ct.FlowSignal.NoFlow): - return ct.FlowSignal.NoFlow + if has_sim_data and extract_filter is not None: + return sim_data.monitor_data[extract_filter] return ct.FlowSignal.FlowPending #################### - # - Output: Monitor Data -> Data + # - Output (Array): Monitor Data -> Expr #################### @events.computes_output_socket( # Trigger - 'Data', + 'Expr', kind=ct.FlowKind.Array, # Loaded props={'extract_filter'}, input_sockets={'Monitor Data'}, input_socket_kinds={'Monitor Data': ct.FlowKind.Value}, ) - def compute_data( + def compute_expr( self, props: dict, input_sockets: dict ) -> jax.Array | ct.FlowSignal: - """Compute `Data:Array` by querying an array-like attribute of `Monitor Data`, then constructing an `ct.ArrayFlow`. + """Compute `Expr:Array` by querying an array-like attribute of `Monitor Data`, then constructing an `ct.ArrayFlow` around it. Uses the internal `xarray` data returned by Tidy3D. By using `np.array` on the `.data` attribute of the `xarray`, instead of the usual JAX array constructor, we should save a (possibly very big) copy. - Notes: - The attribute to query is read directly from `self.extract_filter`. - This is also the mechanism that protects from trying to reference an invalid attribute. - - Used as the first part of the `LazyFuncValue` chain used for further array manipulations with Math nodes. - Returns: The data array, if available, else `ct.FlowSignal.FlowPending`. """ - has_monitor_data = not ct.FlowSignal.check(input_sockets['Monitor Data']) + extract_filter = props['extract_filter'] + monitor_data = input_sockets['Monitor Data'] + has_monitor_data = not ct.FlowSignal.check(monitor_data) - if has_monitor_data and props['extract_filter'] != 'NONE': - xarray_data = getattr( - input_sockets['Monitor Data'], props['extract_filter'] - ) + if has_monitor_data and extract_filter is not None: + xarray_data = getattr(monitor_data, extract_filter) return ct.ArrayFlow(values=np.array(xarray_data.data), unit=None) return ct.FlowSignal.FlowPending @events.computes_output_socket( # Trigger - 'Data', + 'Expr', kind=ct.FlowKind.LazyValueFunc, # Loaded - output_sockets={'Data'}, - output_socket_kinds={'Data': ct.FlowKind.Array}, + output_sockets={'Expr'}, + output_socket_kinds={'Expr': ct.FlowKind.Array}, ) def compute_extracted_data_lazy( self, output_sockets: dict ) -> ct.LazyValueFuncFlow | None: - """Declare `Data:LazyValueFunc` by creating a simple function that directly wraps `Data:Array`. + """Declare `Expr:LazyValueFunc` by creating a simple function that directly wraps `Expr:Array`. Returns: The composable function array, if available, else `ct.FlowSignal.FlowPending`. """ - has_output_data = not ct.FlowSignal.check(output_sockets['Data']) + output_expr = output_sockets['Expr'] + has_output_expr = not ct.FlowSignal.check(output_expr) - if has_output_data: + if has_output_expr: return ct.LazyValueFuncFlow( - func=lambda: output_sockets['Data'].values, supports_jax=True + func=lambda: output_expr.values, supports_jax=True ) return ct.FlowSignal.FlowPending #################### - # - Auxiliary: Monitor Data -> Data + # - Auxiliary (Params): Monitor Data -> Expr #################### @events.computes_output_socket( - 'Data', + 'Expr', kind=ct.FlowKind.Params, ) def compute_data_params(self) -> ct.ParamsFlow: @@ -362,6 +404,9 @@ class ExtractDataNode(base.MaxwellSimNode): """ return ct.ParamsFlow() + #################### + # - Auxiliary (Info): Monitor Data -> Expr + #################### @events.computes_output_socket( # Trigger 'Data', @@ -380,20 +425,21 @@ class ExtractDataNode(base.MaxwellSimNode): Returns: Information describing the `Data:LazyValueFunc`, if available, else `ct.FlowSignal.FlowPending`. """ - has_monitor_data = not ct.FlowSignal.check(input_sockets['Monitor Data']) + monitor_data = input_sockets['Monitor Data'] + monitor_data_type = props['monitor_data_type'] + extract_filter = props['extract_filter'] + + has_monitor_data = not ct.FlowSignal.check(monitor_data) # Retrieve XArray - if has_monitor_data and props['extract_filter'] != 'NONE': - xarr = getattr(input_sockets['Monitor Data'], props['extract_filter']) + if has_monitor_data and extract_filter is not None: + xarr = getattr(monitor_data, extract_filter) else: return ct.FlowSignal.FlowPending - info_output_name = props['extract_filter'] - info_output_shape = None - # Compute InfoFlow from XArray ## XYZF: Field / Permittivity / FieldProjectionCartesian - if props['monitor_data_type'] in { + if monitor_data_type in { 'Field', 'Permittivity', #'FieldProjectionCartesian', @@ -413,18 +459,16 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, - output_name=props['extract_filter'], + output_name=extract_filter, output_shape=None, output_mathtype=spux.MathType.Complex, output_unit=( - spu.volt / spu.micrometer - if props['monitor_data_type'] == 'Field' - else None + spu.volt / spu.micrometer if monitor_data_type == 'Field' else None ), ) ## XYZT: FieldTime - if props['monitor_data_type'] == 'FieldTime': + if monitor_data_type == 'FieldTime': return ct.InfoFlow( dim_names=['x', 'y', 'z', 't'], dim_idx={ @@ -440,18 +484,16 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, - output_name=props['extract_filter'], + output_name=extract_filter, output_shape=None, output_mathtype=spux.MathType.Complex, output_unit=( - spu.volt / spu.micrometer - if props['monitor_data_type'] == 'Field' - else None + spu.volt / spu.micrometer if monitor_data_type == 'Field' else None ), ) ## F: Flux - if props['monitor_data_type'] == 'Flux': + if monitor_data_type == 'Flux': return ct.InfoFlow( dim_names=['f'], dim_idx={ @@ -461,14 +503,14 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, - output_name=props['extract_filter'], + output_name=extract_filter, output_shape=None, output_mathtype=spux.MathType.Real, output_unit=spu.watt, ) ## T: FluxTime - if props['monitor_data_type'] == 'FluxTime': + if monitor_data_type == 'FluxTime': return ct.InfoFlow( dim_names=['t'], dim_idx={ @@ -478,14 +520,14 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, - output_name=props['extract_filter'], + output_name=extract_filter, output_shape=None, output_mathtype=spux.MathType.Real, output_unit=spu.watt, ) ## RThetaPhiF: FieldProjectionAngle - if props['monitor_data_type'] == 'FieldProjectionAngle': + if monitor_data_type == 'FieldProjectionAngle': return ct.InfoFlow( dim_names=['r', 'theta', 'phi', 'f'], dim_idx={ @@ -508,18 +550,18 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, - output_name=props['extract_filter'], + output_name=extract_filter, output_shape=None, output_mathtype=spux.MathType.Real, output_unit=( spu.volt / spu.micrometer - if props['extract_filter'].startswith('E') + if extract_filter.startswith('E') else spu.ampere / spu.micrometer ), ) ## UxUyRF: FieldProjectionKSpace - if props['monitor_data_type'] == 'FieldProjectionKSpace': + if monitor_data_type == 'FieldProjectionKSpace': return ct.InfoFlow( dim_names=['ux', 'uy', 'r', 'f'], dim_idx={ @@ -540,18 +582,18 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, - output_name=props['extract_filter'], + output_name=extract_filter, output_shape=None, output_mathtype=spux.MathType.Real, output_unit=( spu.volt / spu.micrometer - if props['extract_filter'].startswith('E') + if extract_filter.startswith('E') else spu.ampere / spu.micrometer ), ) ## OrderxOrderyF: Diffraction - if props['monitor_data_type'] == 'Diffraction': + if monitor_data_type == 'Diffraction': return ct.InfoFlow( dim_names=['orders_x', 'orders_y', 'f'], dim_idx={ @@ -569,17 +611,17 @@ class ExtractDataNode(base.MaxwellSimNode): is_sorted=True, ), }, - output_name=props['extract_filter'], + output_name=extract_filter, output_shape=None, output_mathtype=spux.MathType.Real, output_unit=( spu.volt / spu.micrometer - if props['extract_filter'].startswith('E') + if extract_filter.startswith('E') else spu.ampere / spu.micrometer ), ) - msg = f'Unsupported Monitor Data Type {props["monitor_data_type"]} in "FlowKind.Info" of "{self.bl_label}"' + msg = f'Unsupported Monitor Data Type {monitor_data_type} in "FlowKind.Info" of "{self.bl_label}"' raise RuntimeError(msg) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py index 0517859..b7b222f 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py @@ -1,16 +1,16 @@ -from . import filter_math, map_math, operate_math, reduce_math, transform_math +from . import filter_math, map_math, operate_math # , #reduce_math, transform_math BL_REGISTER = [ + *operate_math.BL_REGISTER, *map_math.BL_REGISTER, *filter_math.BL_REGISTER, - *reduce_math.BL_REGISTER, - *operate_math.BL_REGISTER, - *transform_math.BL_REGISTER, + # *reduce_math.BL_REGISTER, + # *transform_math.BL_REGISTER, ] BL_NODES = { + **operate_math.BL_NODES, **map_math.BL_NODES, **filter_math.BL_NODES, - **reduce_math.BL_NODES, - **operate_math.BL_NODES, - **transform_math.BL_NODES, + # **reduce_math.BL_NODES, + # **transform_math.BL_NODES, } diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py index 4407ba4..2df631f 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py @@ -4,10 +4,10 @@ import enum import typing as typ import bpy -import jax import jax.numpy as jnp from blender_maxwell.utils import bl_cache, logger +from blender_maxwell.utils import extra_sympy_units as spux from .... import contracts as ct from .... import sockets @@ -16,6 +16,78 @@ from ... import base, events log = logger.get(__name__) +class FilterOperation(enum.StrEnum): + """Valid operations for the `FilterMathNode`. + + Attributes: + DimToVec: Shift last dimension to output. + DimsToMat: Shift last 2 dimensions to output. + PinLen1: Remove a len(1) dimension. + Pin: Remove a len(n) dimension by selecting a particular index. + Swap: Swap the positions of two dimensions. + """ + + # Dimensions + PinLen1 = enum.auto() + Pin = enum.auto() + Swap = enum.auto() + + # Interpret + DimToVec = enum.auto() + DimsToMat = enum.auto() + + @staticmethod + def to_name(value: typ.Self) -> str: + FO = FilterOperation + return { + # Dimensions + FO.PinLen1: 'pinₐ =1', + FO.Pin: 'pinₐ ≈v', + FO.Swap: 'a₁ ↔ a₂', + # Interpret + FO.DimToVec: '→ Vector', + FO.DimsToMat: '→ Matrix', + }[value] + + @staticmethod + def to_icon(value: typ.Self) -> str: + return '' + + def are_dims_valid(self, dim_0: int | None, dim_1: int | None): + return not ( + ( + dim_0 is None + and self + in [FilterOperation.PinLen1, FilterOperation.Pin, FilterOperation.Swap] + ) + or (dim_1 is None and self == FilterOperation.Swap) + ) + + def jax_func(self, axis_0: int | None, axis_1: int | None): + return { + # Interpret + FilterOperation.DimToVec: lambda data: data, + FilterOperation.DimsToMat: lambda data: data, + # Dimensions + FilterOperation.PinLen1: lambda data: jnp.squeeze(data, axis_0), + FilterOperation.Pin: lambda data, fixed_axis_idx: jnp.take( + data, fixed_axis_idx, axis=axis_0 + ), + FilterOperation.Swap: lambda data: jnp.swapaxes(data, axis_0, axis_1), + }[self] + + def transform_info(self, info: ct.InfoFlow, dim_0: str, dim_1: str): + return { + # Interpret + FilterOperation.DimToVec: lambda: info.shift_last_input, + FilterOperation.DimsToMat: lambda: info.shift_last_input.shift_last_input, + # Dimensions + FilterOperation.PinLen1: lambda: info.delete_dimension(dim_0), + FilterOperation.Pin: lambda: info.delete_dimension(dim_0), + FilterOperation.Swap: lambda: info.swap_dimensions(dim_0, dim_1), + }[self]() + + class FilterMathNode(base.MaxwellSimNode): r"""Applies a function that operates on the shape of the array. @@ -38,21 +110,18 @@ class FilterMathNode(base.MaxwellSimNode): bl_label = 'Filter Math' input_sockets: typ.ClassVar = { - 'Data': sockets.DataSocketDef(format='jax'), - } - input_socket_sets: typ.ClassVar = { - 'Interpret': {}, - 'Dimensions': {}, + 'Expr': sockets.ExprSocketDef(), } output_sockets: typ.ClassVar = { - 'Data': sockets.DataSocketDef(format='jax'), + 'Expr': sockets.ExprSocketDef(), } #################### # - Properties #################### - operation: enum.Enum = bl_cache.BLField( - prop_ui=True, enum_cb=lambda self, _: self.search_operations() + operation: FilterOperation = bl_cache.BLField( + FilterOperation.PinLen1, + prop_ui=True, ) # Dimension Selection @@ -68,49 +137,26 @@ class FilterMathNode(base.MaxwellSimNode): #################### @property def data_info(self) -> ct.InfoFlow | None: - info = self._compute_input('Data', kind=ct.FlowKind.Info) + info = self._compute_input('Expr', kind=ct.FlowKind.Info) if not ct.FlowSignal.check(info): return info return None #################### - # - Operation Search - #################### - def search_operations(self) -> list[tuple[str, str, str]]: - items = [] - if self.active_socket_set == 'Interpret': - items += [ - ('DIM_TO_VEC', '→ Vector', 'Shift last dimension to output.'), - ('DIMS_TO_MAT', '→ Matrix', 'Shift last 2 dimensions to output.'), - ] - elif self.active_socket_set == 'Dimensions': - items += [ - ('PIN_LEN_ONE', 'pinₐ =1', 'Remove a len(1) dimension'), - ( - 'PIN', - 'pinₐ ≈v', - 'Remove a len(n) dimension by selecting an index', - ), - ('SWAP', 'a₁ ↔ a₂', 'Swap the position of two dimensions'), - ] - - return [(*item, '', i) for i, item in enumerate(items)] - - #################### - # - Dimensions Search + # - Search Dimensions #################### def search_dims(self) -> list[ct.BLEnumElement]: if self.data_info is None: return [] - if self.operation == 'PIN_LEN_ONE': + if self.operation == FilterOperation.PinLen1: dims = [ (dim_name, dim_name, f'Dimension "{dim_name}" of length 1') for dim_name in self.data_info.dim_names if self.data_info.dim_lens[dim_name] == 1 ] - elif self.operation in ['PIN', 'SWAP']: + elif self.operation in [FilterOperation.Pin, FilterOperation.Swap]: dims = [ (dim_name, dim_name, f'Dimension "{dim_name}"') for dim_name in self.data_info.dim_names @@ -124,12 +170,13 @@ class FilterMathNode(base.MaxwellSimNode): # - UI #################### def draw_label(self): + FO = FilterOperation labels = { - 'PIN_LEN_ONE': lambda: f'Filter: Pin {self.dim_0} (len=1)', - 'PIN': lambda: f'Filter: Pin {self.dim_0}', - 'SWAP': lambda: f'Filter: Swap {self.dim_0}|{self.dim_1}', - 'DIM_TO_VEC': lambda: 'Filter: -> Vector', - 'DIMS_TO_MAT': lambda: 'Filter: -> Matrix', + FO.PinLen1: lambda: f'Filter: Pin {self.dim_0} (len=1)', + FO.Pin: lambda: f'Filter: Pin {self.dim_0}', + FO.Swap: lambda: f'Filter: Swap {self.dim_0}|{self.dim_1}', + FO.DimToVec: lambda: 'Filter: -> Vector', + FO.DimsToMat: lambda: 'Filter: -> Matrix', } if (label := labels.get(self.operation)) is not None: @@ -141,10 +188,10 @@ class FilterMathNode(base.MaxwellSimNode): layout.prop(self, self.blfields['operation'], text='') if self.active_socket_set == 'Dimensions': - if self.operation in ['PIN_LEN_ONE', 'PIN']: + if self.operation in [FilterOperation.PinLen1, FilterOperation.Pin]: layout.prop(self, self.blfields['dim_0'], text='') - if self.operation == 'SWAP': + if self.operation == FilterOperation.Swap: row = layout.row(align=True) row.prop(self, self.blfields['dim_0'], text='') row.prop(self, self.blfields['dim_1'], text='') @@ -152,215 +199,199 @@ class FilterMathNode(base.MaxwellSimNode): #################### # - Events #################### - @events.on_value_changed( - prop_name='active_socket_set', - run_on_init=True, - ) - def on_socket_set_changed(self): - self.operation = bl_cache.Signal.ResetEnumItems - @events.on_value_changed( # Trigger - socket_name='Data', - prop_name={'active_socket_set', 'operation'}, + socket_name='Expr', + prop_name={'operation'}, run_on_init=True, - # Loaded - props={'operation'}, ) - def on_any_change(self, props: dict) -> None: + def on_input_changed(self) -> None: self.dim_0 = bl_cache.Signal.ResetEnumItems self.dim_1 = bl_cache.Signal.ResetEnumItems @events.on_value_changed( - socket_name='Data', + # Trigger + socket_name='Expr', prop_name={'dim_0', 'dim_1', 'operation'}, - ## run_on_init: Implicitly triggered. + run_on_init=True, + # Loaded props={'operation', 'dim_0', 'dim_1'}, - input_sockets={'Data'}, - input_socket_kinds={'Data': ct.FlowKind.Info}, + input_sockets={'Expr'}, + input_socket_kinds={'Expr': ct.FlowKind.Info}, ) - def on_dim_change(self, props: dict, input_sockets: dict): - has_data = not ct.FlowSignal.check(input_sockets['Data']) - if not has_data: + def on_pin_changed(self, props: dict, input_sockets: dict): + info = input_sockets['Expr'] + has_info = not ct.FlowSignal.check(info) + if not has_info: return # "Dimensions"|"PIN": Add/Remove Input Socket - if props['operation'] == 'PIN' and props['dim_0'] != 'NONE': + if props['operation'] == FilterOperation.Pin and props['dim_0'] is not None: + pinned_unit = info.dim_units[props['dim_0']] + pinned_mathtype = info.dim_mathtypes[props['dim_0']] + pinned_physical_type = spux.PhysicalType.from_unit(pinned_unit) + + wanted_mathtype = ( + spux.MathType.Complex + if pinned_mathtype == spux.MathType.Complex + and spux.MathType.Complex in pinned_physical_type.valid_mathtypes + else spux.MathType.Real + ) + # Get Current and Wanted Socket Defs current_bl_socket = self.loose_input_sockets.get('Value') - wanted_socket_def = sockets.SOCKET_DEFS[ - ct.unit_to_socket_type( - input_sockets['Data'].dim_idx[props['dim_0']].unit - ) - ] # Determine Whether to Declare New Loose Input SOcket if ( current_bl_socket is None - or sockets.SOCKET_DEFS[current_bl_socket.socket_type] - != wanted_socket_def + or current_bl_socket.shape is not None + or current_bl_socket.physical_type != pinned_physical_type + or current_bl_socket.mathtype != wanted_mathtype ): self.loose_input_sockets = { - 'Value': wanted_socket_def(), + 'Value': sockets.ExprSocketDef( + active_kind=ct.FlowKind.Value, + shape=None, + physical_type=pinned_physical_type, + mathtype=wanted_mathtype, + default_unit=pinned_unit, + ), } elif self.loose_input_sockets: self.loose_input_sockets = {} #################### - # - Compute: LazyValueFunc / Array + # - Output #################### @events.computes_output_socket( - 'Data', + 'Expr', kind=ct.FlowKind.LazyValueFunc, props={'operation', 'dim_0', 'dim_1'}, - input_sockets={'Data'}, - input_socket_kinds={'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}}, + input_sockets={'Expr'}, + input_socket_kinds={'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}}, ) - def compute_data(self, props: dict, input_sockets: dict): - lazy_value_func = input_sockets['Data'][ct.FlowKind.LazyValueFunc] - info = input_sockets['Data'][ct.FlowKind.Info] - - # Check Flow - if any(ct.FlowSignal.check(inp) for inp in [info, lazy_value_func]): - return ct.FlowSignal.FlowPending - - # Compute Function Arguments + def compute_lazy_value_func(self, props: dict, input_sockets: dict): operation = props['operation'] - if operation == 'NONE': - return ct.FlowSignal.FlowPending + lazy_value_func = input_sockets['Expr'][ct.FlowKind.LazyValueFunc] + info = input_sockets['Expr'][ct.FlowKind.Info] - ## Dimension(s) + has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func) + has_info = not ct.FlowSignal.check(info) + + # Dimension(s) dim_0 = props['dim_0'] dim_1 = props['dim_1'] - if operation in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE': - return ct.FlowSignal.FlowPending - if operation == 'SWAP' and dim_1 == 'NONE': - return ct.FlowSignal.FlowPending + if ( + has_lazy_value_func + and has_info + and operation is not None + and operation.are_dims_valid(dim_0, dim_1) + ): + axis_0 = info.dim_names.index(dim_0) if dim_0 is not None else None + axis_1 = info.dim_names.index(dim_1) if dim_1 is not None else None - ## Axis/Axes - axis_0 = info.dim_names.index(dim_0) if dim_0 != 'NONE' else None - axis_1 = info.dim_names.index(dim_1) if dim_1 != 'NONE' else None - - # Compose Output Function - filter_func = { - # Dimensions - 'PIN_LEN_ONE': lambda data: jnp.squeeze(data, axis_0), - 'PIN': lambda data, fixed_axis_idx: jnp.take( - data, fixed_axis_idx, axis=axis_0 - ), - 'SWAP': lambda data: jnp.swapaxes(data, axis_0, axis_1), - # Interpret - 'DIM_TO_VEC': lambda data: data, - 'DIMS_TO_MAT': lambda data: data, - }[props['operation']] - - return lazy_value_func.compose_within( - filter_func, - enclosing_func_args=[int] if operation == 'PIN' else [], - supports_jax=True, - ) + return lazy_value_func.compose_within( + operation.jax_func(axis_0, axis_1), + enclosing_func_args=[int] if operation == 'PIN' else [], + supports_jax=True, + ) + return ct.FlowSignal.FlowPending @events.computes_output_socket( - 'Data', + 'Expr', kind=ct.FlowKind.Array, - output_sockets={'Data'}, + output_sockets={'Expr'}, output_socket_kinds={ - 'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params}, + 'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params}, }, + unit_systems={'BlenderUnits': ct.UNITS_BLENDER}, ) - def compute_array(self, output_sockets: dict) -> ct.ArrayFlow: - lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc] - params = output_sockets['Data'][ct.FlowKind.Params] + def compute_array(self, output_sockets, unit_systems) -> ct.ArrayFlow: + lazy_value_func = output_sockets['Expr'][ct.FlowKind.LazyValueFunc] + params = output_sockets['Expr'][ct.FlowKind.Params] - # Check Flow - if any(ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]): - return ct.FlowSignal.FlowPending + has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func) + has_params = not ct.FlowSignal.check(params) - return ct.ArrayFlow( - values=lazy_value_func.func_jax(*params.func_args, **params.func_kwargs), - unit=None, - ) + if has_lazy_value_func and has_params: + unit_system = unit_systems['BlenderUnits'] + return ct.ArrayFlow( + values=lazy_value_func.func_jax( + *params.scaled_func_args(unit_system), + **params.scaled_func_kwargs(unit_system), + ), + ) + return ct.FlowSignal.FlowPending #################### - # - Compute Auxiliary: Info + # - Auxiliary: Info #################### @events.computes_output_socket( - 'Data', + 'Expr', kind=ct.FlowKind.Info, props={'dim_0', 'dim_1', 'operation'}, - input_sockets={'Data'}, - input_socket_kinds={'Data': ct.FlowKind.Info}, + input_sockets={'Expr'}, + input_socket_kinds={'Expr': ct.FlowKind.Info}, ) - def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow: - info = input_sockets['Data'] + def compute_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow: + operation = props['operation'] + info = input_sockets['Expr'] - # Check Flow - if ct.FlowSignal.check(info): - return ct.FlowSignal.FlowPending + has_info = not ct.FlowSignal.check(info) - # Collect Information + # Dimension(s) dim_0 = props['dim_0'] dim_1 = props['dim_1'] + if ( + has_info + and operation is not None + and operation.are_dims_valid(dim_0, dim_1) + ): + return operation.transform_info(info, dim_0, dim_1) - if props['operation'] in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE': - return ct.FlowSignal.FlowPending - if props['operation'] == 'SWAP' and dim_1 == 'NONE': - return ct.FlowSignal.FlowPending - - return { - # Dimensions - 'PIN_LEN_ONE': lambda: info.delete_dimension(dim_0), - 'PIN': lambda: info.delete_dimension(dim_0), - 'SWAP': lambda: info.swap_dimensions(dim_0, dim_1), - # Interpret - 'DIM_TO_VEC': lambda: info.shift_last_input, - 'DIMS_TO_MAT': lambda: info.shift_last_input.shift_last_input, - }[props['operation']]() + return ct.FlowSignal.FlowPending #################### - # - Compute Auxiliary: Info + # - Auxiliary: Params #################### @events.computes_output_socket( - 'Data', + 'Expr', kind=ct.FlowKind.Params, props={'dim_0', 'dim_1', 'operation'}, - input_sockets={'Data', 'Value'}, - input_socket_kinds={'Data': {ct.FlowKind.Info, ct.FlowKind.Params}}, + input_sockets={'Expr', 'Value'}, + input_socket_kinds={'Expr': {ct.FlowKind.Info, ct.FlowKind.Params}}, input_sockets_optional={'Value': True}, ) - def compute_composed_params( - self, props: dict, input_sockets: dict - ) -> ct.ParamsFlow: - info = input_sockets['Data'][ct.FlowKind.Info] - params = input_sockets['Data'][ct.FlowKind.Params] + def compute_params(self, props: dict, input_sockets: dict) -> ct.ParamsFlow: + operation = props['operation'] + info = input_sockets['Expr'][ct.FlowKind.Info] + params = input_sockets['Expr'][ct.FlowKind.Params] - # Check Flow - if any(ct.FlowSignal.check(inp) for inp in [info, params]): - return ct.FlowSignal.FlowPending + has_info = not ct.FlowSignal.check(info) + has_params = not ct.FlowSignal.check(params) - # Collect Information - ## Dimensions + # Dimension(s) dim_0 = props['dim_0'] dim_1 = props['dim_1'] + if ( + has_info + and has_params + and operation is not None + and operation.are_dims_valid(dim_0, dim_1) + ): + ## Pinned Value + pinned_value = input_sockets['Value'] + has_pinned_value = not ct.FlowSignal.check(pinned_value) - if props['operation'] in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE': - return ct.FlowSignal.FlowPending - if props['operation'] == 'SWAP' and dim_1 == 'NONE': - return ct.FlowSignal.FlowPending + if props['operation'] == 'PIN' and has_pinned_value: + nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of( + pinned_value, require_sorted=True + ) - ## Pinned Value - pinned_value = input_sockets['Value'] - has_pinned_value = not ct.FlowSignal.check(pinned_value) + return params.compose_within(enclosing_func_args=[nearest_idx_to_value]) - if props['operation'] == 'PIN' and has_pinned_value: - # Compute IDX Corresponding to Dimension Index - nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of( - input_sockets['Value'], require_sorted=True - ) - - return params.compose_within(enclosing_func_args=[nearest_idx_to_value]) - - return params + return params + return ct.FlowSignal.FlowPending #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py index 9121cea..2f5285e 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py @@ -20,6 +20,248 @@ log = logger.get(__name__) X_COMPLEX = sp.Symbol('x', complex=True) +class MapOperation(enum.StrEnum): + """Valid operations for the `MapMathNode`. + + Attributes: + UserExpr: Use a user-provided mapping expression. + Real: Compute the real part of the input. + Imag: Compute the imaginary part of the input. + Abs: Compute the absolute value of the input. + Sq: Square the input. + Sqrt: Compute the (principal) square root of the input. + InvSqrt: Compute the inverse square root of the input. + Cos: Compute the cosine of the input. + Sin: Compute the sine of the input. + Tan: Compute the tangent of the input. + Acos: Compute the inverse cosine of the input. + Asin: Compute the inverse sine of the input. + Atan: Compute the inverse tangent of the input. + Norm2: Compute the 2-norm (aka. length) of the input vector. + Det: Compute the determinant of the input matrix. + Cond: Compute the condition number of the input matrix. + NormFro: Compute the frobenius norm of the input matrix. + Rank: Compute the rank of the input matrix. + Diag: Compute the diagonal vector of the input matrix. + EigVals: Compute the eigenvalues vector of the input matrix. + SvdVals: Compute the singular values vector of the input matrix. + Inv: Compute the inverse matrix of the input matrix. + Tra: Compute the transpose matrix of the input matrix. + Qr: Compute the QR-factorized matrices of the input matrix. + Chol: Compute the Cholesky-factorized matrices of the input matrix. + Svd: Compute the SVD-factorized matrices of the input matrix. + """ + + # By User Expression + UserExpr = enum.auto() + # By Number + Real = enum.auto() + Imag = enum.auto() + Abs = enum.auto() + Sq = enum.auto() + Sqrt = enum.auto() + InvSqrt = enum.auto() + Cos = enum.auto() + Sin = enum.auto() + Tan = enum.auto() + Acos = enum.auto() + Asin = enum.auto() + Atan = enum.auto() + Sinc = enum.auto() + # By Vector + Norm2 = enum.auto() + # By Matrix + Det = enum.auto() + Cond = enum.auto() + NormFro = enum.auto() + Rank = enum.auto() + Diag = enum.auto() + EigVals = enum.auto() + SvdVals = enum.auto() + Inv = enum.auto() + Tra = enum.auto() + Qr = enum.auto() + Chol = enum.auto() + Svd = enum.auto() + + @staticmethod + def to_name(value: typ.Self) -> str: + MO = MapOperation + return { + # By User Expression + MO.UserExpr: '*', + # By Number + MO.Real: 'ℝ(v)', + MO.Imag: 'Im(v)', + MO.Abs: '|v|', + MO.Sq: 'v²', + MO.Sqrt: '√v', + MO.InvSqrt: '1/√v', + MO.Cos: 'cos v', + MO.Sin: 'sin v', + MO.Tan: 'tan v', + MO.Acos: 'acos v', + MO.Asin: 'asin v', + MO.Atan: 'atan v', + MO.Sinc: 'sinc v', + # By Vector + MO.Norm2: '||v||₂', + # By Matrix + MO.Det: 'det V', + MO.Cond: 'κ(V)', + MO.NormFro: '||V||_F', + MO.Rank: 'rank V', + MO.Diag: 'diag V', + MO.EigVals: 'eigvals V', + MO.SvdVals: 'svdvals V', + MO.Inv: 'V⁻¹', + MO.Tra: 'Vt', + MO.Qr: 'qr V', + MO.Chol: 'chol V', + MO.Svd: 'svd V', + }[value] + + @staticmethod + def to_icon(value: typ.Self) -> str: + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + MO = MapOperation + return ( + str(self), + MO.to_name(self), + MO.to_name(self), + MO.to_icon(self), + i, + ) + + @staticmethod + def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]: + MO = MapOperation + # By Number + if shape is None: + return [ + MO.Real, + MO.Imag, + MO.Abs, + MO.Sq, + MO.Sqrt, + MO.InvSqrt, + MO.Cos, + MO.Sin, + MO.Tan, + MO.Acos, + MO.Asin, + MO.Atan, + MO.Sinc, + ] + + # By Vector + if len(shape) == 1: + return [ + MO.Norm2, + ] + # By Matrix + if len(shape) == 2: + return [ + MO.Det, + MO.Cond, + MO.NormFro, + MO.Rank, + MO.Diag, + MO.EigVals, + MO.SvdVals, + MO.Inv, + MO.Tra, + MO.Qr, + MO.Chol, + MO.Svd, + ] + + return [] + + def jax_func(self, user_expr_func: ct.LazyValueFuncFlow | None = None): + MO = MapOperation + if self == MO.UserExpr and user_expr_func is not None: + return lambda data: user_expr_func.func(data) + return { + # By Number + MO.Real: lambda data: jnp.real(data), + MO.Imag: lambda data: jnp.imag(data), + MO.Abs: lambda data: jnp.abs(data), + MO.Sq: lambda data: jnp.square(data), + MO.Sqrt: lambda data: jnp.sqrt(data), + MO.InvSqrt: lambda data: 1 / jnp.sqrt(data), + MO.Cos: lambda data: jnp.cos(data), + MO.Sin: lambda data: jnp.sin(data), + MO.Tan: lambda data: jnp.tan(data), + MO.Acos: lambda data: jnp.acos(data), + MO.Asin: lambda data: jnp.asin(data), + MO.Atan: lambda data: jnp.atan(data), + MO.Sinc: lambda data: jnp.sinc(data), + # By Vector + # Vector -> Number + MO.Norm2: lambda data: jnp.linalg.norm(data, ord=2, axis=-1), + # By Matrix + # Matrix -> Number + MO.Det: lambda data: jnp.linalg.det(data), + MO.Cond: lambda data: jnp.linalg.cond(data), + MO.NormFro: lambda data: jnp.linalg.matrix_norm(data, ord='fro'), + MO.Rank: lambda data: jnp.linalg.matrix_rank(data), + # Matrix -> Vec + MO.Diag: lambda data: jnp.diagonal(data, axis1=-2, axis2=-1), + MO.EigVals: lambda data: jnp.linalg.eigvals(data), + MO.SvdVals: lambda data: jnp.linalg.svdvals(data), + # Matrix -> Matrix + MO.Inv: lambda data: jnp.linalg.inv(data), + MO.Tra: lambda data: jnp.matrix_transpose(data), + # Matrix -> Matrices + MO.Qr: lambda data: jnp.linalg.qr(data), + MO.Chol: lambda data: jnp.linalg.cholesky(data), + MO.Svd: lambda data: jnp.linalg.svd(data), + }[self] + + def transform_info(self, info: ct.InfoFlow): + MO = MapOperation + return { + # By User Expression + MO.UserExpr: '*', + # By Number + MO.Real: lambda: info.set_output_mathtype(spux.MathType.Real), + MO.Imag: lambda: info.set_output_mathtype(spux.MathType.Real), + MO.Abs: lambda: info.set_output_mathtype(spux.MathType.Real), + # By Vector + MO.Norm2: lambda: info.collapse_output( + collapsed_name=MO.to_name(self).replace('v', info.output_name), + collapsed_mathtype=spux.MathType.Real, + collapsed_unit=info.output_unit, + ), + # By Matrix + MO.Det: lambda: info.collapse_output( + collapsed_name=MO.to_name(self).replace('V', info.output_name), + collapsed_mathtype=info.output_mathtype, + collapsed_unit=info.output_unit, + ), + MO.Cond: lambda: info.collapse_output( + collapsed_name=MO.to_name(self).replace('V', info.output_name), + collapsed_mathtype=spux.MathType.Real, + collapsed_unit=None, + ), + MO.NormFro: lambda: info.collapse_output( + collapsed_name=MO.to_name(self).replace('V', info.output_name), + collapsed_mathtype=spux.MathType.Real, + collapsed_unit=info.output_unit, + ), + MO.Rank: lambda: info.collapse_output( + collapsed_name=MO.to_name(self).replace('V', info.output_name), + collapsed_mathtype=spux.MathType.Integer, + collapsed_unit=None, + ), + ## TODO: Matrix -> Vec + ## TODO: Matrix -> Matrices + }.get(self, info) + + class MapMathNode(base.MaxwellSimNode): r"""Applies a function by-structure to the data. @@ -104,120 +346,46 @@ class MapMathNode(base.MaxwellSimNode): bl_label = 'Map Math' input_sockets: typ.ClassVar = { - 'Data': sockets.DataSocketDef(format='jax'), - } - input_socket_sets: typ.ClassVar = { - 'By Element': {}, - 'By Vector': {}, - 'By Matrix': {}, - 'Expr': { - 'Mapper': sockets.ExprSocketDef( - complex_symbols=[X_COMPLEX], - default_expr=X_COMPLEX, - ), - }, + 'Expr': sockets.ExprSocketDef(), } output_sockets: typ.ClassVar = { - 'Data': sockets.DataSocketDef(format='jax'), + 'Expr': sockets.ExprSocketDef(), } #################### # - Properties #################### - operation: enum.Enum = bl_cache.BLField( + operation: MapOperation = bl_cache.BLField( prop_ui=True, enum_cb=lambda self, _: self.search_operations() ) - def search_operations(self) -> list[ct.BLEnumElement]: - if self.active_socket_set == 'By Element': - items = [ - # General - ('REAL', 'ℝ(v)', 'real(v) (by el)'), - ('IMAG', 'Im(v)', 'imag(v) (by el)'), - ('ABS', '|v|', 'abs(v) (by el)'), - ('SQ', 'v²', 'v^2 (by el)'), - ('SQRT', '√v', 'sqrt(v) (by el)'), - ('INV_SQRT', '1/√v', '1/sqrt(v) (by el)'), - None, - # Trigonometry - ('COS', 'cos v', 'cos(v) (by el)'), - ('SIN', 'sin v', 'sin(v) (by el)'), - ('TAN', 'tan v', 'tan(v) (by el)'), - ('ACOS', 'acos v', 'acos(v) (by el)'), - ('ASIN', 'asin v', 'asin(v) (by el)'), - ('ATAN', 'atan v', 'atan(v) (by el)'), - ] - elif self.active_socket_set in 'By Vector': - items = [ - # Vector -> Number - ('NORM_2', '||v||₂', 'norm(v, 2) (by Vec)'), - ] - elif self.active_socket_set == 'By Matrix': - items = [ - # Matrix -> Number - ('DET', 'det V', 'det(V) (by Mat)'), - ('COND', 'κ(V)', 'cond(V) (by Mat)'), - ('NORM_FRO', '||V||_F', 'norm(V, frobenius) (by Mat)'), - ('RANK', 'rank V', 'rank(V) (by Mat)'), - None, - # Matrix -> Array - ('DIAG', 'diag V', 'diag(V) (by Mat)'), - ('EIG_VALS', 'eigvals V', 'eigvals(V) (by Mat)'), - ('SVD_VALS', 'svdvals V', 'diag(svd(V)) (by Mat)'), - None, - # Matrix -> Matrix - ('INV', 'V⁻¹', 'V^(-1) (by Mat)'), - ('TRA', 'Vt', 'V^T (by Mat)'), - None, - # Matrix -> Matrices - ('QR', 'qr V', 'qr(V) -> Q·R (by Mat)'), - ('CHOL', 'chol V', 'cholesky(V) -> V·V† (by Mat)'), - ('SVD', 'svd V', 'svd(V) -> U·Σ·V† (by Mat)'), - ] - elif self.active_socket_set == 'Expr': - items = [('EXPR_EL', 'By Element', 'Expression-defined (by el)')] - else: - msg = f'Active socket set {self.active_socket_set} is unknown' - raise RuntimeError(msg) + @property + def expr_output_shape(self) -> ct.InfoFlow | None: + info = self._compute_input('Expr', kind=ct.FlowKind.Info) + has_info = not ct.FlowSignal.check(info) + if has_info: + return info.output_shape - return [ - (*item, '', i) if item is not None else None for i, item in enumerate(items) - ] + return None + + output_shape: tuple[int, ...] | None = bl_cache.BLField(None) + + def search_operations(self) -> list[ct.BLEnumElement]: + if self.expr_output_shape is not None: + return [ + operation.bl_enum_element(i) + for i, operation in enumerate( + MapOperation.by_element_shape(self.expr_output_shape) + ) + ] + return [] #################### # - UI #################### def draw_label(self): - labels = { - 'REAL': 'ℝ(v)', - 'IMAG': 'Im(v)', - 'ABS': '|v|', - 'SQ': 'v²', - 'SQRT': '√v', - 'INV_SQRT': '1/√v', - 'COS': 'cos v', - 'SIN': 'sin v', - 'TAN': 'tan v', - 'ACOS': 'acos v', - 'ASIN': 'asin v', - 'ATAN': 'atan v', - 'NORM_2': '||v||₂', - 'DET': 'det V', - 'COND': 'κ(V)', - 'NORM_FRO': '||V||_F', - 'RANK': 'rank V', - 'DIAG': 'diag V', - 'EIG_VALS': 'eigvals V', - 'SVD_VALS': 'svdvals V', - 'INV': 'V⁻¹', - 'TRA': 'Vt', - 'QR': 'qr V', - 'CHOL': 'chol V', - 'SVD': 'svd V', - } - - if (label := labels.get(self.operation)) is not None: - return 'Map: ' + label + if self.operation is not None: + return 'Map: ' + MapOperation.to_name(self.operation) return self.bl_label @@ -228,106 +396,98 @@ class MapMathNode(base.MaxwellSimNode): # - Events #################### @events.on_value_changed( - prop_name='active_socket_set', + # Trigger + socket_name='Expr', run_on_init=True, ) - def on_socket_set_changed(self): - self.operation = bl_cache.Signal.ResetEnumItems + def on_input_changed(self): + if self.operation not in MapOperation.by_element_shape(self.expr_output_shape): + self.operation = bl_cache.Signal.ResetEnumItems + + @events.on_value_changed( + # Trigger + prop_name={'operation'}, + run_on_init=True, + # Loaded + props={'operation'}, + ) + def on_operation_changed(self, props: dict) -> None: + operation = props['operation'] + + # UserExpr: Add/Remove Input Socket + if operation == MapOperation.UserExpr: + current_bl_socket = self.loose_input_sockets.get('Mapper') + if current_bl_socket is None: + self.loose_input_sockets = { + 'Mapper': sockets.ExprSocketDef( + symbols={X_COMPLEX}, + default_value=X_COMPLEX, + mathtype=spux.MathType.Complex, + ), + } + + elif self.loose_input_sockets: + self.loose_input_sockets = {} #################### # - Compute: LazyValueFunc / Array #################### @events.computes_output_socket( - 'Data', + 'Expr', kind=ct.FlowKind.LazyValueFunc, - props={'active_socket_set', 'operation'}, - input_sockets={'Data', 'Mapper'}, + props={'operation'}, + input_sockets={'Expr', 'Mapper'}, input_socket_kinds={ - 'Data': ct.FlowKind.LazyValueFunc, + 'Expr': ct.FlowKind.LazyValueFunc, 'Mapper': ct.FlowKind.LazyValueFunc, }, input_sockets_optional={'Mapper': True}, ) def compute_data(self, props: dict, input_sockets: dict): - has_data = not ct.FlowSignal.check(input_sockets['Data']) - if ( - not has_data - or props['operation'] == 'NONE' - or ( - props['active_socket_set'] == 'Expr' - and ct.FlowSignal.check(input_sockets['Mapper']) - ) - ): - return ct.FlowSignal.FlowPending + operation = props['operation'] + expr = input_sockets['Expr'] + mapper = input_sockets['Mapper'] - mapping_func: typ.Callable[[jax.Array], jax.Array] = { - 'By Element': { - 'REAL': lambda data: jnp.real(data), - 'IMAG': lambda data: jnp.imag(data), - 'ABS': lambda data: jnp.abs(data), - 'SQ': lambda data: jnp.square(data), - 'SQRT': lambda data: jnp.sqrt(data), - 'INV_SQRT': lambda data: 1 / jnp.sqrt(data), - 'COS': lambda data: jnp.cos(data), - 'SIN': lambda data: jnp.sin(data), - 'TAN': lambda data: jnp.tan(data), - 'ACOS': lambda data: jnp.acos(data), - 'ASIN': lambda data: jnp.asin(data), - 'ATAN': lambda data: jnp.atan(data), - 'SINC': lambda data: jnp.sinc(data), - }, - 'By Vector': { - 'NORM_2': lambda data: jnp.linalg.norm(data, ord=2, axis=-1), - }, - 'By Matrix': { - # Matrix -> Number - 'DET': lambda data: jnp.linalg.det(data), - 'COND': lambda data: jnp.linalg.cond(data), - 'NORM_FRO': lambda data: jnp.linalg.matrix_norm(data, ord='fro'), - 'RANK': lambda data: jnp.linalg.matrix_rank(data), - # Matrix -> Vec - 'DIAG': lambda data: jnp.diagonal(data, axis1=-2, axis2=-1), - 'EIG_VALS': lambda data: jnp.linalg.eigvals(data), - 'SVD_VALS': lambda data: jnp.linalg.svdvals(data), - # Matrix -> Matrix - 'INV': lambda data: jnp.linalg.inv(data), - 'TRA': lambda data: jnp.matrix_transpose(data), - # Matrix -> Matrices - 'QR': lambda data: jnp.linalg.qr(data), - 'CHOL': lambda data: jnp.linalg.cholesky(data), - 'SVD': lambda data: jnp.linalg.svd(data), - }, - 'Expr': { - 'EXPR_EL': lambda data: input_sockets['Mapper'].func(data), - }, - }[props['active_socket_set']][props['operation']] + has_expr = not ct.FlowSignal.check(expr) + has_mapper = not ct.FlowSignal.check(expr) - # Compose w/Lazy Root Function Data - return input_sockets['Data'].compose_within( - mapping_func, - supports_jax=True, - ) + if has_expr and operation is not None: + if not has_mapper: + return expr.compose_within( + operation.jax_func(), + supports_jax=True, + ) + if operation == MapOperation.UserExpr and has_mapper: + return expr.compose_within( + operation.jax_func(user_expr_func=mapper), + supports_jax=True, + ) + return ct.FlowSignal.FlowPending @events.computes_output_socket( - 'Data', + 'Expr', kind=ct.FlowKind.Array, - output_sockets={'Data'}, + output_sockets={'Expr'}, output_socket_kinds={ - 'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params}, + 'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params}, }, + unit_systems={'BlenderUnits': ct.UNITS_BLENDER}, ) - def compute_array(self, output_sockets: dict) -> ct.ArrayFlow: - lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc] - params = output_sockets['Data'][ct.FlowKind.Params] + def compute_array(self, output_sockets, unit_systems) -> ct.ArrayFlow: + lazy_value_func = output_sockets['Expr'][ct.FlowKind.LazyValueFunc] + params = output_sockets['Expr'][ct.FlowKind.Params] - if all(not ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]): + has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func) + has_params = not ct.FlowSignal.check(params) + + if has_lazy_value_func and has_params: + unit_system = unit_systems['BlenderUnits'] return ct.ArrayFlow( values=lazy_value_func.func_jax( - *params.func_args, **params.func_kwargs + *params.scaled_func_args(unit_system), + **params.scaled_func_kwargs(unit_system), ), - unit=None, ) - return ct.FlowSignal.FlowPending #################### @@ -341,59 +501,15 @@ class MapMathNode(base.MaxwellSimNode): input_socket_kinds={'Data': ct.FlowKind.Info}, ) def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow: - info = input_sockets['Data'] - if ct.FlowSignal.check(info): - return ct.FlowSignal.FlowPending + operation = props['operation'] + info = input_sockets['Expr'] - # Complex -> Real - if props['active_socket_set'] == 'By Element' and props['operation'] in [ - 'REAL', - 'IMAG', - 'ABS', - ]: - return info.set_output_mathtype(spux.MathType.Real) + has_info = not ct.FlowSignal.check(info) - if props['active_socket_set'] == 'By Vector' and props['operation'] in [ - 'NORM_2' - ]: - return { - 'NORM_2': lambda: info.collapse_output( - collapsed_name=f'||{info.output_name}||₂', - collapsed_mathtype=spux.MathType.Real, - collapsed_unit=info.output_unit, - ) - }[props['operation']]() + if has_info and operation is not None: + return operation.transform_info(info) - if props['active_socket_set'] == 'By Matrix' and props['operation'] in [ - 'DET', - 'COND', - 'NORM_FRO', - 'RANK', - ]: - return { - 'DET': lambda: info.collapse_output( - collapsed_name=f'det {info.output_name}', - collapsed_mathtype=info.output_mathtype, - collapsed_unit=info.output_unit, - ), - 'COND': lambda: info.collapse_output( - collapsed_name=f'κ({info.output_name})', - collapsed_mathtype=spux.MathType.Real, - collapsed_unit=None, - ), - 'NORM_FRO': lambda: info.collapse_output( - collapsed_name=f'||({info.output_name}||_F', - collapsed_mathtype=spux.MathType.Real, - collapsed_unit=info.output_unit, - ), - 'RANK': lambda: info.collapse_output( - collapsed_name=f'rank {info.output_name}', - collapsed_mathtype=spux.MathType.Integer, - collapsed_unit=None, - ), - }[props['operation']]() - - return info + return ct.FlowSignal.FlowPending @events.computes_output_socket( 'Data', diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py index 87ef0ae..8b5794d 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py @@ -14,6 +14,29 @@ from ... import base, events log = logger.get(__name__) +FUNCS = { + 'ADD': lambda exprs: exprs[0] + exprs[1], + 'SUB': lambda exprs: exprs[0] - exprs[1], + 'MUL': lambda exprs: exprs[0] * exprs[1], + 'DIV': lambda exprs: exprs[0] / exprs[1], + 'POW': lambda exprs: exprs[0] ** exprs[1], +} + +SP_FUNCS = FUNCS +JAX_FUNCS = FUNCS | { + # Number | * + 'ATAN2': lambda exprs: jnp.atan2(exprs[1], exprs[0]), + # Vector | Vector + 'VEC_VEC_DOT': lambda exprs: jnp.matmul(exprs[0], exprs[1]), + 'CROSS': lambda exprs: jnp.cross(exprs[0], exprs[1]), + # Matrix | Vector + 'MAT_VEC_DOT': lambda exprs: jnp.matmul(exprs[0], exprs[1]), + 'LIN_SOLVE': lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]), + 'LSQ_SOLVE': lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]), + # Matrix | Matrix + 'MAT_MAT_DOT': lambda exprs: jnp.matmul(exprs[0], exprs[1]), +} + class OperateMathNode(base.MaxwellSimNode): r"""Applies a function that depends on two inputs. @@ -28,40 +51,12 @@ class OperateMathNode(base.MaxwellSimNode): node_type = ct.NodeType.OperateMath bl_label = 'Operate Math' - input_socket_sets: typ.ClassVar = { - 'Expr | Expr': { - 'Expr L': sockets.ExprSocketDef(), - 'Expr R': sockets.ExprSocketDef(), - }, - 'Data | Data': { - 'Data L': sockets.DataSocketDef( - format='jax', default_show_info_columns=False - ), - 'Data R': sockets.DataSocketDef( - format='jax', default_show_info_columns=False - ), - }, - 'Expr | Data': { - 'Expr L': sockets.ExprSocketDef(), - 'Data R': sockets.DataSocketDef( - format='jax', default_show_info_columns=False - ), - }, + input_sockets: typ.ClassVar = { + 'Expr L': sockets.ExprSocketDef(show_info_columns=False), + 'Expr R': sockets.ExprSocketDef(show_info_columns=False), } - output_socket_sets: typ.ClassVar = { - 'Expr | Expr': { - 'Expr': sockets.ExprSocketDef(), - }, - 'Data | Data': { - 'Data': sockets.DataSocketDef( - format='jax', default_show_info_columns=False - ), - }, - 'Expr | Data': { - 'Data': sockets.DataSocketDef( - format='jax', default_show_info_columns=False - ), - }, + output_sockets: typ.ClassVar = { + 'Expr': sockets.ExprSocketDef(), } #################### @@ -77,15 +72,15 @@ class OperateMathNode(base.MaxwellSimNode): def search_categories(self) -> list[ct.BLEnumElement]: """Deduce and return a list of valid categories for the current socket set and input data.""" - data_l_info = self._compute_input( - 'Data L', kind=ct.FlowKind.Info, optional=True + expr_l_info = self._compute_input( + 'Expr L', kind=ct.FlowKind.Info, optional=True ) - data_r_info = self._compute_input( - 'Data R', kind=ct.FlowKind.Info, optional=True + expr_r_info = self._compute_input( + 'Expr R', kind=ct.FlowKind.Info, optional=True ) - has_data_l_info = not ct.FlowSignal.check(data_l_info) - has_data_r_info = not ct.FlowSignal.check(data_r_info) + has_expr_l_info = not ct.FlowSignal.check(expr_l_info) + has_expr_r_info = not ct.FlowSignal.check(expr_r_info) # Categories by Socket Set NUMBER_NUMBER = ( @@ -120,64 +115,45 @@ class OperateMathNode(base.MaxwellSimNode): ) categories = [] - ## Expr | Expr - if self.active_socket_set == 'Expr | Expr': - return [NUMBER_NUMBER] - - ## Data | Data - if ( - self.active_socket_set == 'Data | Data' - and has_data_l_info - and has_data_r_info - ): + if has_expr_l_info and has_expr_r_info: # Check Valid Broadcasting ## Number | Number - if data_l_info.output_shape is None and data_r_info.output_shape is None: + if expr_l_info.output_shape is None and expr_r_info.output_shape is None: categories = [NUMBER_NUMBER] ## Number | Vector elif ( - data_l_info.output_shape is None and len(data_r_info.output_shape) == 1 + expr_l_info.output_shape is None and len(expr_r_info.output_shape) == 1 ): categories = [NUMBER_VECTOR] ## Number | Matrix elif ( - data_l_info.output_shape is None and len(data_r_info.output_shape) == 2 - ): # noqa: PLR2004 + expr_l_info.output_shape is None and len(expr_r_info.output_shape) == 2 + ): categories = [NUMBER_MATRIX] ## Vector | Vector elif ( - len(data_l_info.output_shape) == 1 - and len(data_r_info.output_shape) == 1 + len(expr_l_info.output_shape) == 1 + and len(expr_r_info.output_shape) == 1 ): categories = [VECTOR_VECTOR] ## Matrix | Vector elif ( - len(data_l_info.output_shape) == 2 # noqa: PLR2004 - and len(data_r_info.output_shape) == 1 + len(expr_l_info.output_shape) == 2 # noqa: PLR2004 + and len(expr_r_info.output_shape) == 1 ): categories = [MATRIX_VECTOR] ## Matrix | Matrix elif ( - len(data_l_info.output_shape) == 2 # noqa: PLR2004 - and len(data_r_info.output_shape) == 2 # noqa: PLR2004 + len(expr_l_info.output_shape) == 2 # noqa: PLR2004 + and len(expr_r_info.output_shape) == 2 # noqa: PLR2004 ): categories = [MATRIX_MATRIX] - ## Expr | Data - if self.active_socket_set == 'Expr | Data' and has_data_r_info: - if data_r_info.output_shape is None: - categories = [NUMBER_NUMBER] - else: - categories = { - 1: [NUMBER_NUMBER, NUMBER_VECTOR], - 2: [NUMBER_NUMBER, NUMBER_MATRIX], - }[len(data_r_info.output_shape)] - return [ (*category, '', i) if category is not None else None for i, category in enumerate(categories) @@ -248,11 +224,10 @@ class OperateMathNode(base.MaxwellSimNode): #################### @events.on_value_changed( # Trigger - socket_name={'Expr L', 'Expr R', 'Data L', 'Data R'}, - prop_name='active_socket_set', + socket_name={'Expr L', 'Expr R'}, run_on_init=True, ) - def on_socket_set_changed(self) -> None: + def on_socket_changed(self) -> None: # Recompute Valid Categories self.category = bl_cache.Signal.ResetEnumItems self.operation = bl_cache.Signal.ResetEnumItems @@ -272,224 +247,135 @@ class OperateMathNode(base.MaxwellSimNode): kind=ct.FlowKind.Value, props={'operation'}, input_sockets={'Expr L', 'Expr R'}, + input_socket_kinds={ + 'Expr L': ct.FlowKind.Value, + 'Expr R': ct.FlowKind.Value, + }, ) - def compute_expr(self, props: dict, input_sockets: dict): + def compute_value(self, props: dict, input_sockets: dict): + operation = props['operation'] expr_l = input_sockets['Expr L'] expr_r = input_sockets['Expr R'] - return { - 'ADD': lambda: expr_l + expr_r, - 'SUB': lambda: expr_l - expr_r, - 'MUL': lambda: expr_l * expr_r, - 'DIV': lambda: expr_l / expr_r, - 'POW': lambda: expr_l**expr_r, - 'ATAN2': lambda: sp.atan2(expr_r, expr_l), - }[props['operation']]() + has_expr_l_value = not ct.FlowSignal.check(expr_l) + has_expr_r_value = not ct.FlowSignal.check(expr_r) + + if has_expr_l_value and has_expr_r_value and operation is not None: + return SP_FUNCS[operation]([expr_l, expr_r]) + + return ct.Flowsignal.FlowPending @events.computes_output_socket( - 'Data', + 'Expr', kind=ct.FlowKind.LazyValueFunc, props={'operation'}, - input_sockets={'Data L', 'Data R'}, + input_sockets={'Expr L', 'Expr R'}, input_socket_kinds={ - 'Data L': ct.FlowKind.LazyValueFunc, - 'Data R': ct.FlowKind.LazyValueFunc, - }, - input_sockets_optional={ - 'Data L': True, - 'Data R': True, + 'Expr L': ct.FlowKind.LazyValueFunc, + 'Expr R': ct.FlowKind.LazyValueFunc, }, ) - def compute_data(self, props: dict, input_sockets: dict): - data_l = input_sockets['Data L'] - data_r = input_sockets['Data R'] - has_data_l = not ct.FlowSignal.check(data_l) + def compose_func(self, props: dict, input_sockets: dict): + operation = props['operation'] + if operation is None: + return ct.FlowSignal.FlowPending - mapping_func = { - # Number | * - 'ADD': lambda datas: datas[0] + datas[1], - 'SUB': lambda datas: datas[0] - datas[1], - 'MUL': lambda datas: datas[0] * datas[1], - 'DIV': lambda datas: datas[0] / datas[1], - 'POW': lambda datas: datas[0] ** datas[1], - 'ATAN2': lambda datas: jnp.atan2(datas[1], datas[0]), - # Vector | Vector - 'VEC_VEC_DOT': lambda datas: jnp.matmul(datas[0], datas[1]), - 'CROSS': lambda datas: jnp.cross(datas[0], datas[1]), - # Matrix | Vector - 'MAT_VEC_DOT': lambda datas: jnp.matmul(datas[0], datas[1]), - 'LIN_SOLVE': lambda datas: jnp.linalg.solve(datas[0], datas[1]), - 'LSQ_SOLVE': lambda datas: jnp.linalg.lstsq(datas[0], datas[1]), - # Matrix | Matrix - 'MAT_MAT_DOT': lambda datas: jnp.matmul(datas[0], datas[1]), - }[props['operation']] + expr_l = input_sockets['Expr L'] + expr_r = input_sockets['Expr R'] - # Compose by Socket Set - ## Data | Data - if has_data_l: - return (data_l | data_r).compose_within( - mapping_func, + has_expr_l = not ct.FlowSignal.check(expr_l) + has_expr_r = not ct.FlowSignal.check(expr_r) + + if has_expr_l and has_expr_r: + return (expr_l | expr_r).compose_within( + JAX_FUNCS[operation], supports_jax=True, ) - - ## Expr | Data - expr_l_lazy_value_func = ct.LazyValueFuncFlow( - func=lambda expr_l_value: expr_l_value, - func_args=[typ.Any], - supports_jax=True, - ) - return (expr_l_lazy_value_func | data_r).compose_within( - mapping_func, - supports_jax=True, - ) + return ct.FlowSignal.FlowPending @events.computes_output_socket( - 'Data', + 'Expr', kind=ct.FlowKind.Array, - output_sockets={'Data'}, + output_sockets={'Expr'}, output_socket_kinds={ - 'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params}, + 'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params}, }, + unit_systems={'BlenderUnits': ct.UNITS_BLENDER}, ) - def compute_array(self, output_sockets: dict) -> ct.ArrayFlow: - lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc] - params = output_sockets['Data'][ct.FlowKind.Params] + def compute_array(self, output_sockets, unit_systems) -> ct.ArrayFlow: + lazy_value_func = output_sockets['Expr'][ct.FlowKind.LazyValueFunc] + params = output_sockets['Expr'][ct.FlowKind.Params] has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func) has_params = not ct.FlowSignal.check(params) if has_lazy_value_func and has_params: + unit_system = unit_systems['BlenderUnits'] return ct.ArrayFlow( values=lazy_value_func.func_jax( - *params.func_args, **params.func_kwargs + *params.scaled_func_args(unit_system), + **params.scaled_func_kwargs(unit_system), ), unit=None, ) return ct.FlowSignal.FlowPending - #################### - # - Auxiliary: Params - #################### - @events.computes_output_socket( - 'Data', - kind=ct.FlowKind.Params, - props={'operation'}, - input_sockets={'Expr L', 'Data L', 'Data R'}, - input_socket_kinds={ - 'Expr L': ct.FlowKind.Value, - 'Data L': {ct.FlowKind.Info, ct.FlowKind.Params}, - 'Data R': {ct.FlowKind.Info, ct.FlowKind.Params}, - }, - input_sockets_optional={ - 'Expr L': True, - 'Data L': True, - 'Data R': True, - }, - ) - def compute_data_params( - self, props, input_sockets - ) -> ct.ParamsFlow | ct.FlowSignal: - expr_l = input_sockets['Expr L'] - data_l_info = input_sockets['Data L'][ct.FlowKind.Info] - data_l_params = input_sockets['Data L'][ct.FlowKind.Params] - data_r_info = input_sockets['Data R'][ct.FlowKind.Info] - data_r_params = input_sockets['Data R'][ct.FlowKind.Params] - - has_expr_l = not ct.FlowSignal.check(expr_l) - has_data_l_info = not ct.FlowSignal.check(data_l_info) - has_data_l_params = not ct.FlowSignal.check(data_l_params) - has_data_r_info = not ct.FlowSignal.check(data_r_info) - has_data_r_params = not ct.FlowSignal.check(data_r_params) - - # Compose by Socket Set - ## Data | Data - if ( - has_data_l_info - and has_data_l_params - and has_data_r_info - and has_data_r_params - ): - return data_l_params | data_r_params - - ## Expr | Data - if has_expr_l and has_data_r_info and has_data_r_params: - operation = props['operation'] - data_unit = data_r_info.output_unit - - # By Operation - ## Add/Sub: Scale to Output Unit - if operation in ['ADD', 'SUB', 'MUL', 'DIV']: - if not spux.uses_units(expr_l): - value = spux.sympy_to_python(expr_l) - else: - value = spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit)) - - return data_r_params.compose_within( - enclosing_func_args=[value], - ) - - ## Pow: Doesn't Exist (?) - ## -> See https://math.stackexchange.com/questions/4326081/units-of-the-exponential-function - if operation == 'POW': - return ct.FlowSignal.FlowPending - - ## atan2(): Only Length - ## -> Implicitly presume that Data L/R use length units. - if operation == 'ATAN2': - if not spux.uses_units(expr_l): - value = spux.sympy_to_python(expr_l) - else: - value = spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit)) - - return data_r_params.compose_within( - enclosing_func_args=[value], - ) - - return data_r_params.compose_within( - enclosing_func_args=[ - spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit)) - ] - ) - - return ct.FlowSignal.FlowPending - #################### # - Auxiliary: Info #################### @events.computes_output_socket( - 'Data', + 'Expr', kind=ct.FlowKind.Info, - input_sockets={'Expr L', 'Data L', 'Data R'}, + props={'operation'}, + input_sockets={'Expr L', 'Expr R'}, input_socket_kinds={ - 'Expr L': ct.FlowKind.Value, - 'Data L': ct.FlowKind.Info, - 'Data R': ct.FlowKind.Info, - }, - input_sockets_optional={ - 'Expr L': True, - 'Data L': True, - 'Data R': True, + 'Expr L': ct.FlowKind.Info, + 'Expr R': ct.FlowKind.Info, }, ) - def compute_data_info(self, input_sockets: dict) -> ct.InfoFlow: - expr_l = input_sockets['Expr L'] - data_l_info = input_sockets['Data L'] - data_r_info = input_sockets['Data R'] + def compute_info(self, props, input_sockets) -> ct.InfoFlow: + operation = props['operation'] + info_l = input_sockets['Expr L'] + info_r = input_sockets['Expr R'] - has_expr_l = not ct.FlowSignal.check(expr_l) - has_data_l_info = not ct.FlowSignal.check(data_l_info) - has_data_r_info = not ct.FlowSignal.check(data_r_info) + has_info_l = not ct.FlowSignal.check(info_l) + has_info_r = not ct.FlowSignal.check(info_r) - # Info by Socket Set - ## Data | Data - if has_data_l_info and has_data_r_info: - return data_r_info + # Return Info of RHS + ## -> Fundamentall, this is why 'category' only has the given options. + ## -> Via 'category', we enforce that the operated-on structure is always RHS. + ## -> That makes it super duper easy to track info changes. + if has_info_l and has_info_r and operation is not None: + return info_r - ## Expr | Data - if has_expr_l and has_data_r_info: - return data_r_info + return ct.FlowSignal.FlowPending + #################### + # - Auxiliary: Params + #################### + @events.computes_output_socket( + 'Expr', + kind=ct.FlowKind.Params, + props={'operation'}, + input_sockets={'Expr L', 'Expr R'}, + input_socket_kinds={ + 'Expr L': ct.FlowKind.Params, + 'Expr R': ct.FlowKind.Params, + }, + ) + def compute_params( + self, props, input_sockets + ) -> ct.ParamsFlow | ct.FlowSignal: + operation = props['operation'] + params_l = input_sockets['Expr L'] + params_r = input_sockets['Expr R'] + + has_params_l = not ct.FlowSignal.check(params_l) + has_params_r = not ct.FlowSignal.check(params_r) + + if has_params_l and has_params_r and operation is not None: + return params_l | params_r return ct.FlowSignal.FlowPending diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py index 76a191c..0e008c0 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py @@ -147,7 +147,7 @@ class VizTarget(enum.StrEnum): @staticmethod def valid_targets_for(viz_mode: VizMode) -> list[typ.Self] | None: return { - 'NONE': [], + None: [], VizMode.Hist1D: [VizTarget.Plot2D], VizMode.BoxPlot1D: [VizTarget.Plot2D], VizMode.Curve2D: [VizTarget.Plot2D], @@ -192,7 +192,7 @@ class VizNode(base.MaxwellSimNode): # - Sockets #################### input_sockets: typ.ClassVar = { - 'Data': sockets.DataSocketDef(format='jax'), + 'Expr': sockets.ExprSocketDef(), } output_sockets: typ.ClassVar = { 'Preview': sockets.AnySocketDef(), @@ -222,7 +222,7 @@ class VizNode(base.MaxwellSimNode): ##################### @property def data_info(self) -> ct.InfoFlow: - return self._compute_input('Data', kind=ct.FlowKind.Info) + return self._compute_input('Expr', kind=ct.FlowKind.Info) def search_modes(self) -> list[ct.BLEnumElement]: if not ct.FlowSignal.check(self.data_info): @@ -243,7 +243,7 @@ class VizNode(base.MaxwellSimNode): ## - Target Searcher ##################### def search_targets(self) -> list[ct.BLEnumElement]: - if self.viz_mode != 'NONE': + if self.viz_mode is not None: return [ ( viz_target, @@ -271,15 +271,15 @@ class VizNode(base.MaxwellSimNode): # - Events #################### @events.on_value_changed( - socket_name='Data', - input_sockets={'Data'}, + socket_name='Expr', + input_sockets={'Expr'}, run_on_init=True, - input_socket_kinds={'Data': ct.FlowKind.Info}, - input_sockets_optional={'Data': True}, + input_socket_kinds={'Expr': ct.FlowKind.Info}, + input_sockets_optional={'Expr': True}, ) def on_any_changed(self, input_sockets: dict): if not ct.FlowSignal.check_single( - input_sockets['Data'], ct.FlowSignal.FlowPending + input_sockets['Expr'], ct.FlowSignal.FlowPending ): self.viz_mode = bl_cache.Signal.ResetEnumItems self.viz_target = bl_cache.Signal.ResetEnumItems @@ -297,8 +297,8 @@ class VizNode(base.MaxwellSimNode): @events.on_show_plot( managed_objs={'plot'}, props={'viz_mode', 'viz_target', 'colormap'}, - input_sockets={'Data'}, - input_socket_kinds={'Data': {ct.FlowKind.Array, ct.FlowKind.Info}}, + input_sockets={'Expr'}, + input_socket_kinds={'Expr': {ct.FlowKind.Array, ct.FlowKind.Info}}, stop_propagation=True, ) def on_show_plot( @@ -308,14 +308,14 @@ class VizNode(base.MaxwellSimNode): props: dict, ): # Retrieve Inputs - array_flow = input_sockets['Data'][ct.FlowKind.Array] - info = input_sockets['Data'][ct.FlowKind.Info] + array_flow = input_sockets['Expr'][ct.FlowKind.Array] + info = input_sockets['Expr'][ct.FlowKind.Info] # Check Flow if ( any(ct.FlowSignal.check(inp) for inp in [array_flow, info]) - or props['viz_mode'] == 'NONE' - or props['viz_target'] == 'NONE' + or props['viz_mode'] is None + or props['viz_target'] is None ): return diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/__init__.py index 53c55f2..807b415 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/__init__.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/__init__.py @@ -1,7 +1,7 @@ from . import ( constants, file_importers, - unit_system, + #unit_system, wave_constant, web_importers, ) @@ -10,14 +10,14 @@ from . import ( BL_REGISTER = [ *wave_constant.BL_REGISTER, - *unit_system.BL_REGISTER, + #*unit_system.BL_REGISTER, *constants.BL_REGISTER, *web_importers.BL_REGISTER, *file_importers.BL_REGISTER, ] BL_NODES = { **wave_constant.BL_NODES, - **unit_system.BL_NODES, + #**unit_system.BL_NODES, **constants.BL_NODES, **web_importers.BL_NODES, **file_importers.BL_NODES, diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/number_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/number_constant.py index 8392a27..47b8fd8 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/number_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/number_constant.py @@ -1,32 +1,65 @@ +import enum import typing as typ +import bpy + +from blender_maxwell.utils import bl_cache +from blender_maxwell.utils import extra_sympy_units as spux + from .... import contracts as ct from .... import sockets from ... import base, events class NumberConstantNode(base.MaxwellSimNode): + """A unitless number of configurable math type ex. integer, real, etc. . + + Attributes: + mathtype: The math type to specify the number as. + """ + node_type = ct.NodeType.NumberConstant bl_label = 'Numerical Constant' - input_socket_sets: typ.ClassVar = { - 'Integer': { - 'Value': sockets.IntegerNumberSocketDef(), - }, - 'Rational': { - 'Value': sockets.RationalNumberSocketDef(), - }, - 'Real': { - 'Value': sockets.RealNumberSocketDef(), - }, - 'Complex': { - 'Value': sockets.ComplexNumberSocketDef(), - }, + input_sockets: typ.ClassVar = { + 'Value': sockets.ExprSocketDef(), + } + output_sockets: typ.ClassVar = { + 'Value': sockets.ExprSocketDef(), } - output_socket_sets = input_socket_sets #################### - # - Callbacks + # - Properties + #################### + mathtype: spux.MathType = bl_cache.BLField( + spux.MathType.Integer, + prop_ui=True, + ) + + size: spux.NumberSize1D = bl_cache.BLField( + spux.NumberSize1D.Scalar, + prop_ui=True, + ) + + #################### + # - UI + #################### + def draw_value(self, col: bpy.types.UILayout) -> None: + row = col.row(align=True) + row.prop(self, self.blfields['mathtype'], text='') + row.prop(self, self.blfields['size'], text='') + + #################### + # - Events + #################### + @events.on_value_changed(prop_name={'mathtype', 'size'}, props={'mathtype', 'size'}) + def on_mathtype_size_changed(self, props) -> None: + """Change the input/output expression sockets to match the mathtype declared in the node.""" + self.inputs['Value'].mathtype = props['mathtype'] + self.inputs['Value'].shape = props['mathtype'].shape + + #################### + # - FlowKind #################### @events.computes_output_socket('Value', input_sockets={'Value'}) def compute_value(self, input_sockets) -> typ.Any: diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/physical_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/physical_constant.py index 42404f3..bf62f66 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/physical_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/physical_constant.py @@ -1,54 +1,91 @@ +import enum import typing as typ import sympy as sp +from blender_maxwell.utils import bl_cache +from blender_maxwell.utils import extra_sympy_units as spux + from .... import contracts, sockets from ... import base, events class PhysicalConstantNode(base.MaxwellSimTreeNode): + """A number of configurable unit dimension, ex. time, length, etc. . + + Attributes: + physical_type: The physical type to specify. + size: The size of the physical type, if it can be a vector. + """ + node_type = contracts.NodeType.PhysicalConstant bl_label = 'Physical Constant' - input_socket_sets: typ.ClassVar = { - 'time': { - 'value': sockets.PhysicalTimeSocketDef( - label='Time', - ), - }, - 'angle': { - 'value': sockets.PhysicalAngleSocketDef( - label='Angle', - ), - }, - 'length': { - 'value': sockets.PhysicalLengthSocketDef( - label='Length', - ), - }, - 'area': { - 'value': sockets.PhysicalAreaSocketDef( - label='Area', - ), - }, - 'volume': { - 'value': sockets.PhysicalVolumeSocketDef( - label='Volume', - ), - }, - 'point_3d': { - 'value': sockets.PhysicalPoint3DSocketDef( - label='3D Point', - ), - }, - 'size_3d': { - 'value': sockets.PhysicalSize3DSocketDef( - label='3D Size', - ), - }, - ## I got bored so maybe the rest later + input_sockets: typ.ClassVar = { + 'Value': sockets.ExprSocketDef(), } - output_socket_sets: typ.ClassVar = input_socket_sets + output_sockets: typ.ClassVar = { + 'Value': sockets.ExprSocketDef(), + } + + #################### + # - Properties + #################### + physical_type: spux.PhysicalType = bl_cache.BLField( + spux.PhysicalType.Time, + prop_ui=True, + ) + + mathtype: enum.Enum = bl_cache.BLField( + enum_cb=lambda self, _: self.search_mathtypes(), + prop_ui=True, + ) + + size: enum.Enum = bl_cache.BLField( + enum_cb=lambda self, _: self.search_sizes(), + prop_ui=True, + ) + + #################### + # - Searchers + #################### + def search_mathtypes(self): + return [ + mathtype.bl_enum_element(i) + for i, mathtype in enumerate(self.physical_type.valid_mathtypes) + ] + + def search_sizes(self): + return [ + spux.NumberSize1D.from_shape(shape).bl_enum_element(i) + for i, shape in enumerate(self.physical_type.valid_shapes) + if spux.NumberSize1D.supports_shape(shape) + ] + + #################### + # - Events + #################### + @events.on_value_changed( + prop_name={'physical_type', 'mathtype', 'size'}, + props={'physical_type', 'mathtype', 'size'}, + ) + def on_mathtype_or_size_changed(self, props) -> None: + """Change the input/output expression sockets to match the mathtype and size declared in the node.""" + shape = spux.NumberSize1D(props['size']).shape + + # Set Input Socket Physical Type + if self.inputs['Value'].physical_type != props['physical_type']: + self.inputs['Value'].physical_type = props['physical_type'] + self.search_mathtypes = bl_cache.Signal.ResetEnumItems + self.search_sizes = bl_cache.Signal.ResetEnumItems + + # Set Input Socket Math Type + if self.inputs['Value'].mathtype != props['mathtype']: + self.inputs['Value'].mathtype = props['mathtype'] + + # Set Input Socket Shape + if self.inputs['Value'].shape != shape: + self.inputs['Value'].shape = shape #################### # - Callbacks diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py index 7520db8..749a2ee 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py @@ -33,8 +33,7 @@ class WaveConstantNode(base.MaxwellSimNode): input_socket_sets: typ.ClassVar = { 'Wavelength': { 'WL': sockets.ExprSocketDef( - active_kind=ct.FlowKind.Value, - unit_dimension=spux.unit_dims.length, + physical_type=spux.PhysicalType.Length, # Defaults default_unit=spu.nm, default_value=500, @@ -46,7 +45,7 @@ class WaveConstantNode(base.MaxwellSimNode): 'Frequency': { 'Freq': sockets.ExprSocketDef( active_kind=ct.FlowKind.Value, - unit_dimension=spux.unit_dims.frequency, + physical_type=spux.PhysicalType.Freq, # Defaults default_unit=spux.THz, default_value=1, @@ -59,11 +58,11 @@ class WaveConstantNode(base.MaxwellSimNode): output_sockets: typ.ClassVar = { 'WL': sockets.ExprSocketDef( active_kind=ct.FlowKind.Value, - unit_dimension=spux.unit_dims.length, + unit_dimension=spux.Dims.length, ), 'Freq': sockets.ExprSocketDef( active_kind=ct.FlowKind.Value, - unit_dimension=spux.unit_dims.frequency, + unit_dimension=spux.Dims.frequency, ), } diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py index 8356744..99029c9 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py @@ -1,6 +1,7 @@ import typing as typ import sympy as sp +import sympy.physics.units as spu import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes @@ -25,22 +26,44 @@ class EHFieldMonitorNode(base.MaxwellSimNode): # - Sockets #################### input_sockets: typ.ClassVar = { - 'Center': sockets.PhysicalPoint3DSocketDef(), - 'Size': sockets.PhysicalSize3DSocketDef(), - 'Samples/Space': sockets.Integer3DVectorSocketDef( - default_value=sp.Matrix([10, 10, 10]) + 'Center': sockets.ExprSocketDef( + shape=(3,), + physical_type=spux.PhysicalType.Length, ), + 'Size': sockets.ExprSocketDef( + shape=(3,), + physical_type=spux.PhysicalType.Length, + ), + 'Spatial Subdivs': sockets.ExprSocketDef( + shape=(3,), + mathtype=spux.MathType.Integer, + default_value=sp.Matrix([10, 10, 10]), + ), + ## TODO: Pass a grid instead of size and resolution + ## TODO: 1D (line), 2D (plane), 3D modes } input_socket_sets: typ.ClassVar = { 'Freq Domain': { - 'Freqs': sockets.PhysicalFreqSocketDef( - is_array=True, + 'Freqs': sockets.ExprSocketDef( + active_kind=ct.FlowKind.LazyArrayRange, + physical_type=spux.PhysicalType.Freq, + default_unit=spux.THz, + default_min=374.7406, ## 800nm + default_max=1498.962, ## 200nm + default_steps=100, ), }, 'Time Domain': { - 'Rec Start': sockets.PhysicalTimeSocketDef(), - 'Rec Stop': sockets.PhysicalTimeSocketDef(default_value=200 * spux.fs), - 'Samples/Time': sockets.IntegerNumberSocketDef( + 'Time Range': sockets.ExprSocketDef( + active_kind=ct.FlowKind.LazyArrayRange, + physical_type=spux.PhysicalType.Time, + default_unit=spu.picosecond, + default_min=0, + default_max=10, + default_steps=2, + ), + 'Temporal Subdivs': sockets.ExprSocketDef( + mathtype=spux.MathType.Integer, default_value=100, ), }, @@ -56,7 +79,7 @@ class EHFieldMonitorNode(base.MaxwellSimNode): } #################### - # - Output Sockets + # - Output #################### @events.computes_output_socket( 'Freq Monitor', @@ -64,7 +87,7 @@ class EHFieldMonitorNode(base.MaxwellSimNode): input_sockets={ 'Center', 'Size', - 'Samples/Space', + 'Spatial Subdivs', 'Freqs', }, input_socket_kinds={ @@ -93,12 +116,12 @@ class EHFieldMonitorNode(base.MaxwellSimNode): center=input_sockets['Center'], size=input_sockets['Size'], name=props['sim_node_name'], - interval_space=tuple(input_sockets['Samples/Space']), + interval_space=tuple(input_sockets['Spatial Subdivs']), freqs=input_sockets['Freqs'].realize().values, ) #################### - # - Preview - Changes to Input Sockets + # - Preview #################### @events.on_value_changed( socket_name={'Center', 'Size'}, diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py index 79334c5..899473a 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py @@ -1,6 +1,7 @@ import typing as typ import sympy as sp +import sympy.physics.units as spu import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes @@ -23,23 +24,43 @@ class PowerFluxMonitorNode(base.MaxwellSimNode): # - Sockets #################### input_sockets: typ.ClassVar = { - 'Center': sockets.PhysicalPoint3DSocketDef(), - 'Size': sockets.PhysicalSize3DSocketDef(), - 'Samples/Space': sockets.Integer3DVectorSocketDef( - default_value=sp.Matrix([10, 10, 10]) + 'Center': sockets.ExprSocketDef( + shape=(3,), + physical_type=spux.PhysicalType.Length, + ), + 'Size': sockets.ExprSocketDef( + shape=(3,), + physical_type=spux.PhysicalType.Length, + ), + 'Samples/Space': sockets.ExprSocketDef( + shape=(3,), + mathtype=spux.MathType.Integer, + default_value=sp.Matrix([10, 10, 10]), ), 'Direction': sockets.BoolSocketDef(), } input_socket_sets: typ.ClassVar = { 'Freq Domain': { - 'Freqs': sockets.PhysicalFreqSocketDef( - is_array=True, + 'Freqs': sockets.ExprSocketDef( + active_kind=ct.FlowKind.LazyArrayRange, + physical_type=spux.PhysicalType.Freq, + default_unit=spux.THz, + default_min=374.7406, ## 800nm + default_max=1498.962, ## 200nm + default_steps=100, ), }, 'Time Domain': { - 'Rec Start': sockets.PhysicalTimeSocketDef(), - 'Rec Stop': sockets.PhysicalTimeSocketDef(default_value=200 * spux.fs), - 'Samples/Time': sockets.IntegerNumberSocketDef( + 'Time Range': sockets.ExprSocketDef( + active_kind=ct.FlowKind.LazyArrayRange, + physical_type=spux.PhysicalType.Time, + default_unit=spu.picosecond, + default_min=0, + default_max=10, + default_steps=2, + ), + 'Samples/Time': sockets.ExprSocketDef( + mathtype=spux.MathType.Integer, default_value=100, ), }, diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py index 9cf279f..ce2cc68 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py @@ -2,7 +2,7 @@ import typing as typ import tidy3d as td -from blender_maxwell.utils import analyze_geonodes, logger +from blender_maxwell.utils import bl_cache, logger from ... import bl_socket_map, managed_objs, sockets from ... import contracts as ct @@ -20,9 +20,9 @@ class GeoNodesStructureNode(base.MaxwellSimNode): # - Sockets #################### input_sockets: typ.ClassVar = { + 'GeoNodes': sockets.BlenderGeoNodesSocketDef(), 'Medium': sockets.MaxwellMediumSocketDef(), 'Center': sockets.PhysicalPoint3DSocketDef(), - 'GeoNodes': sockets.BlenderGeoNodesSocketDef(), } output_sockets: typ.ClassVar = { 'Structure': sockets.MaxwellStructureSocketDef(), @@ -34,7 +34,7 @@ class GeoNodesStructureNode(base.MaxwellSimNode): } #################### - # - Event Methods + # - Output #################### @events.computes_output_socket( 'Structure', @@ -46,9 +46,10 @@ class GeoNodesStructureNode(base.MaxwellSimNode): input_sockets: dict, managed_objs: dict, ) -> td.Structure: + """Computes a triangle-mesh based Tidy3D structure, by manually copying mesh data from Blender to a `td.TriangleMesh`.""" # Simulate Input Value Change ## This ensures that the mesh has been re-computed. - self.on_input_changed() + self.on_input_socket_changed() ## TODO: mesh_as_arrays might not take the Center into account. ## - Alternatively, Tidy3D might have a way to transform? @@ -62,96 +63,109 @@ class GeoNodesStructureNode(base.MaxwellSimNode): ) #################### - # - Event Methods + # - Events: Preview Active Changed #################### @events.on_value_changed( - socket_name={'GeoNodes', 'Center'}, prop_name='preview_active', - any_loose_input_socket=True, - run_on_init=True, - # Pass Data props={'preview_active'}, + input_sockets={'Center'}, + managed_objs={'mesh'}, + ) + def on_preview_changed(self, props, input_sockets) -> None: + """Enables/disables previewing of the GeoNodes-driven mesh, regardless of whether a particular GeoNodes tree is chosen.""" + mesh = managed_objs['mesh'] + + # No Mesh: Create Empty Object + ## Ensures that when there is mesh data, it'll be correctly previewed. + ## Bit of a workaround - the idea is usually to make the MObj as needed. + if not mesh.exists: + center = input_sockets['Center'] + _ = mesh.bl_object(location=center) + + # Push Preview State to Managed Mesh + if props['preview_active']: + mesh.show_preview() + else: + mesh.hide_preview() + + #################### + # - Events: GN Input Changed + #################### + @events.on_value_changed( + socket_name={'Center'}, + any_loose_input_socket=True, + # Pass Data managed_objs={'mesh', 'modifier'}, input_sockets={'Center', 'GeoNodes'}, all_loose_input_sockets=True, unit_systems={'BlenderUnits': ct.UNITS_BLENDER}, scale_input_sockets={'Center': 'BlenderUnits'}, ) - def on_input_changed( - self, - props: dict, - managed_objs: dict, - input_sockets: dict, - loose_input_sockets: dict, - unit_systems: dict, + def on_input_socket_changed( + self, input_sockets, loose_input_sockets, unit_systems ) -> None: - # No GeoNodes: Remove Modifier (if any) - if (geonodes := input_sockets['GeoNodes']) is None: - if ( - managed_objs['modifier'].name - in managed_objs['mesh'].bl_object().modifiers.keys().copy() - ): - managed_objs['modifier'].free_from_bl_object( - managed_objs['mesh'].bl_object() - ) + """Pushes any change in GeoNodes-bound input sockets to the GeoNodes modifier. - # Reset Loose Input Sockets - self.loose_input_sockets = {} - return + Also pushes the `Center:Value` socket to govern the object's center in 3D space. + """ + geonodes = input_sockets['GeoNodes'] + has_geonodes = not ct.FlowSignal.check(geonodes) - # No Loose Input Sockets: Create from GeoNodes Interface - ## TODO: Other reasons to trigger re-filling loose_input_sockets. - if not loose_input_sockets: - # Retrieve the GeoNodes Interface - geonodes_interface = analyze_geonodes.interface( - input_sockets['GeoNodes'], direc='INPUT' + if has_geonodes: + mesh = managed_objs['mesh'] + modifier = managed_objs['modifier'] + center = input_sockets['Center'] + unit_system = unit_systems['BlenderUnits'] + + # Push Loose Input Values to GeoNodes Modifier + modifier.bl_modifier( + mesh.bl_object(location=center), + 'NODES', + { + 'node_group': geonodes, + 'inputs': loose_input_sockets, + 'unit_system': unit_system, + }, ) + #################### + # - Events: GN Tree Changed + #################### + @events.on_value_changed( + socket_name={'GeoNodes'}, + # Pass Data + managed_objs={'mesh', 'modifier'}, + input_sockets={'GeoNodes', 'Center'}, + ) + def on_input_changed( + self, + managed_objs: dict, + input_sockets: dict, + ) -> None: + """Declares new loose input sockets in response to a new GeoNodes tree (if any).""" + geonodes = input_sockets['GeoNodes'] + has_geonodes = not ct.FlowSignal.check(geonodes) + + if has_geonodes: + mesh = managed_objs['mesh'] + modifier = managed_objs['modifier'] + # Fill the Loose Input Sockets + ## -> The SocketDefs contain the default values from the interface. log.info( 'Initializing GeoNodes Structure Node "%s" from GeoNodes Group "%s"', self.bl_label, str(geonodes), ) - self.loose_input_sockets = { - socket_name: bl_socket_map.socket_def_from_bl_socket(iface_socket)() - for socket_name, iface_socket in geonodes_interface.items() - } + self.loose_input_sockets = bl_socket_map.sockets_from_geonodes(geonodes) - # Set Loose Input Sockets to Interface (Default) Values - ## Changing socket.value invokes recursion of this function. - ## The else: below ensures that only one push occurs. - ## (well, one push per .value set, which simplifies to one push) - log.info( - 'Setting Loose Input Sockets of "%s" to GeoNodes Defaults', - self.bl_label, - ) - for socket_name in self.loose_input_sockets: - socket = self.inputs[socket_name] - socket.value = bl_socket_map.read_bl_socket_default_value( - geonodes_interface[socket_name], - unit_systems['BlenderUnits'], - allow_unit_not_in_unit_system=True, - ) - log.info( - 'Set Loose Input Sockets of "%s" to: %s', - self.bl_label, - str(self.loose_input_sockets), - ) - else: - # Push Loose Input Values to GeoNodes Modifier - managed_objs['modifier'].bl_modifier( - managed_objs['mesh'].bl_object(location=input_sockets['Center']), - 'NODES', - { - 'node_group': input_sockets['GeoNodes'], - 'unit_system': unit_systems['BlenderUnits'], - 'inputs': loose_input_sockets, - }, - ) - # Push Preview State - if props['preview_active']: - managed_objs['mesh'].show_preview() + ## -> The loose socket creation triggers 'on_input_socket_changed' + + elif self.loose_input_sockets: + self.loose_input_sockets = {} + + if modifier.name in mesh.bl_object().modifiers.keys().copy(): + modifier.free_from_bl_object(mesh.bl_object()) #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/__init__.py index 0feb383..946f53d 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/__init__.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/__init__.py @@ -1,11 +1,11 @@ from blender_maxwell.utils import logger from .. import contracts as ct -from . import basic, blender, maxwell, physical, tidy3d +from . import basic, blender, expr, maxwell, physical, tidy3d from .scan_socket_defs import scan_for_socket_defs log = logger.get(__name__) -sockets_modules = [basic, physical, blender, maxwell, tidy3d] +sockets_modules = [basic, blender, expr, maxwell, physical, tidy3d] #################### # - Scan for SocketDefs @@ -28,21 +28,24 @@ for socket_type in ct.SocketType: ): log.warning('Missing SocketDef for %s', socket_type.value) + #################### # - Exports #################### BL_REGISTER = [ *basic.BL_REGISTER, - *physical.BL_REGISTER, *blender.BL_REGISTER, + *expr.BL_REGISTER, *maxwell.BL_REGISTER, + *physical.BL_REGISTER, *tidy3d.BL_REGISTER, ] __all__ = [ 'basic', - 'physical', 'blender', + 'expr', 'maxwell', + 'physical', 'tidy3d', ] + [socket_def_type.__name__ for socket_def_type in SOCKET_DEFS.values()] diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py index ed33796..9708a6a 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py @@ -5,10 +5,8 @@ from types import MappingProxyType import bpy import pydantic as pyd -import sympy as sp from blender_maxwell.utils import bl_cache, logger, serialize -from blender_maxwell.utils import extra_sympy_units as spux from .. import contracts as ct @@ -126,7 +124,6 @@ class MaxwellSimSocket(bpy.types.NodeSocket): socket_color: tuple # Options - use_units: bool = False use_prelock: bool = False use_info_draw: bool = False @@ -210,35 +207,6 @@ class MaxwellSimSocket(bpy.types.NodeSocket): 'active_kind', bpy.props.StringProperty, default=str(ct.FlowKind.Value) ) - # Configure Use of Units - if cls.use_units: - if not (socket_units := ct.SOCKET_UNITS.get(cls.socket_type)): - msg = f'{cls.socket_type}: Tried to define "use_units", but there is no unit for {cls.socket_type} defined in "contracts.SOCKET_UNITS"' - raise RuntimeError(msg) - - cls.set_prop( - 'active_unit', - bpy.props.EnumProperty, - name='Unit', - items=[ - (unit_name, spux.sp_to_str(unit_value), sp.srepr(unit_value)) - for unit_name, unit_value in socket_units['values'].items() - ], - default=socket_units['default'], - ) - cls.set_prop( - 'prev_active_unit', - bpy.props.StringProperty, - default=socket_units['default'], - ) - - #################### - # - Units - #################### - @property - def prev_unit(self) -> sp.Expr: - return self.possible_units[self.prev_active_unit] - #################### # - Property Event: On Update #################### @@ -250,36 +218,47 @@ class MaxwellSimSocket(bpy.types.NodeSocket): """ self.display_shape = ( 'SQUARE' if self.active_kind == ct.FlowKind.LazyValueRange else 'CIRCLE' - ) + ('_DOT' if self.use_units else '') + ) # + ('_DOT' if self.use_units else '') ## TODO: Valid Active Kinds should be a subset/subenum(?) of FlowKind + def on_socket_prop_changed(self, prop_name: str) -> None: + """Called when a property has been updated. + + Notes: + Can be overridden if a socket needs to respond to a property change. + + **Always prefer using node events when possible**. + Think very carefully before using this, and use it with the greatest of care. + + Attributes: + prop_name: The name of the property that was changed. + """ + def on_prop_changed(self, prop_name: str, _: bpy.types.Context) -> None: """Called when a property has been updated. Contrary to `node.on_prop_changed()`, socket-specific callbacks are baked into this function: - **Active Kind** (`self.active_kind`): Sets the socket shape to reflect the active `FlowKind`. - - **Unit** (`self.unit`): Corrects the internal `FlowKind` representation to match the new unit. Attributes: prop_name: The name of the property that was changed. """ - # Property: Active Kind - if prop_name == 'active_kind': - self._on_active_kind_changed() - elif prop_name == 'unit': - self._on_unit_changed() - - # Valid Properties - elif hasattr(self, prop_name): + if hasattr(self, prop_name): # Invalidate UI BLField Caches if prop_name in self.ui_blfields: setattr(self, prop_name, bl_cache.Signal.InvalidateCache) + # Property Callbacks: Active Kind + if prop_name == 'active_kind': + self._on_active_kind_changed() + + # Property Callbacks: Per-Socket + self.on_socket_prop_changed(prop_name) + # Trigger Event self.trigger_event(ct.FlowEvent.DataChanged) - # Undefined Properties else: msg = f'Property {prop_name} not defined on socket {self.bl_label} ({self.socket_type})' raise RuntimeError(msg) @@ -760,7 +739,6 @@ class MaxwellSimSocket(bpy.types.NodeSocket): - **Locked** (`self.locked`): The UI will be unusable. - **Linked** (`self.is_linked`): Only the socket label will display. - - **Use Units** (`self.use_units`): The currently active unit will display as a dropdown menu. - **Use Prelock** (`self.use_prelock`): The "prelock" UI drawn with `self.draw_prelock()`, which shows **regardless of `self.locked`**. - **FlowKind**: The `FlowKind`-specific UI corresponding to the current `self.active_kind`. @@ -787,17 +765,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket): if self.is_linked: self.draw_input_label_row(row, text) else: - # User Label Row (incl. Units) - if self.use_units: - split = row.split(factor=0.6, align=True) - - _row = split.row(align=True) - self.draw_label_row(_row, text) - - _col = split.column(align=True) - _col.prop(self, 'active_unit', text='') - else: - self.draw_label_row(row, text) + self.draw_label_row(row, text) # User Prelock Row row = col.row(align=False) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/__init__.py index 0228867..8009b11 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/__init__.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/__init__.py @@ -1,20 +1,16 @@ from . import any as any_socket from . import bool as bool_socket -from . import expr, file_path, string, data +from . import file_path, string AnySocketDef = any_socket.AnySocketDef -DataSocketDef = data.DataSocketDef BoolSocketDef = bool_socket.BoolSocketDef StringSocketDef = string.StringSocketDef FilePathSocketDef = file_path.FilePathSocketDef -ExprSocketDef = expr.ExprSocketDef BL_REGISTER = [ *any_socket.BL_REGISTER, - *data.BL_REGISTER, *bool_socket.BL_REGISTER, *string.BL_REGISTER, *file_path.BL_REGISTER, - *expr.BL_REGISTER, ] diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/blender/geonodes.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/blender/geonodes.py index 988f6a8..d7cf5ad 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/blender/geonodes.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/blender/geonodes.py @@ -70,8 +70,8 @@ class BlenderGeoNodesBLSocket(base.MaxwellSimSocket): # - Default Value #################### @property - def value(self) -> bpy.types.NodeTree | None: - return self.raw_value + def value(self) -> bpy.types.NodeTree | ct.FlowSignal: + return self.raw_value if self.raw_value is not None else ct.FlowSignal.NoFlow @value.setter def value(self, value: bpy.types.NodeTree) -> None: diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py index 7a1be91..ab5b285 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py @@ -7,16 +7,14 @@ import sympy as sp from blender_maxwell.utils import bl_cache, logger from blender_maxwell.utils import extra_sympy_units as spux -from ... import contracts as ct -from .. import base +from .. import contracts as ct +from . import base ## TODO: This is a big node, and there's a lot to get right. -## - Dynamically adjust the value when the user changes the unit in the UI. ## - Dynamically adjust socket color in response to, especially, the unit dimension. ## - Iron out the meaning of display shapes. ## - Generally pay attention to validity checking; it's make or break. -## - For array generation, it may pay to have both a symbolic expression (producing output according to `size` as usual) denoting how to actually make values, and how many. Enables ex. easy symbolic -## - For array generation, it may pay to have both a symbolic expression (producing output according to `size` as usual) +## - For array generation, it may pay to have both a symbolic expression (producing output according to `shape` as usual) denoting how to actually make values, and how many. Enables ex. easy symbolic plots. log = logger.get(__name__) @@ -69,24 +67,15 @@ class ExprBLSocket(base.MaxwellSimSocket): #################### # - Properties #################### - size: typ.Literal[None, 2, 3] = bl_cache.BLField(None, prop_ui=True) + shape: tuple[int, ...] | None = bl_cache.BLField(None) mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real, prop_ui=True) + physical_type: spux.PhysicalType | None = bl_cache.BLField(None) symbols: frozenset[spux.Symbol] = bl_cache.BLField(frozenset()) - ## Units - unit_dim: spux.UnitDimension | None = bl_cache.BLField(None) active_unit: enum.Enum = bl_cache.BLField( None, enum_cb=lambda self, _: self.search_units(), prop_ui=True ) - ## Info Display - show_info_columns: bool = bl_cache.BLField(False, prop_ui=True) - info_columns: InfoDisplayCol = bl_cache.BLField( - {InfoDisplayCol.MathType, InfoDisplayCol.Unit}, - prop_ui=True, - enum_many=True, - ) - # UI: Value ## Expression raw_value_spstr: str = bl_cache.BLField('', prop_ui=True) @@ -94,7 +83,7 @@ class ExprBLSocket(base.MaxwellSimSocket): raw_value_int: int = bl_cache.BLField(0, prop_ui=True) raw_value_rat: Int2 = bl_cache.BLField((0, 1), prop_ui=True) raw_value_float: float = bl_cache.BLField(0.0, float_prec=4, prop_ui=True) - raw_value_complex: Float2 = bl_cache.BLField((0, 1), float_prec=4, prop_ui=True) + raw_value_complex: Float2 = bl_cache.BLField((0.0, 0.0), float_prec=4, prop_ui=True) ## 2D raw_value_int2: Int2 = bl_cache.BLField((0, 0), prop_ui=True) raw_value_rat2: Int22 = bl_cache.BLField(((0, 1), (0, 1)), prop_ui=True) @@ -105,7 +94,9 @@ class ExprBLSocket(base.MaxwellSimSocket): ## 3D raw_value_int3: Int3 = bl_cache.BLField((0, 0, 0), prop_ui=True) raw_value_rat3: Int32 = bl_cache.BLField(((0, 1), (0, 1), (0, 1)), prop_ui=True) - raw_value_float3: Float3 = bl_cache.BLField((0.0, 0.0), float_prec=4, prop_ui=True) + raw_value_float3: Float3 = bl_cache.BLField( + (0.0, 0.0, 0.0), float_prec=4, prop_ui=True + ) raw_value_complex3: Float32 = bl_cache.BLField( ((0.0, 0.0), (0.0, 0.0), (0.0, 0.0)), float_prec=4, prop_ui=True ) @@ -123,6 +114,14 @@ class ExprBLSocket(base.MaxwellSimSocket): ((0.0, 0.0), (1.0, 1.0)), float_prec=4, prop_ui=True ) + # UI: Info + show_info_columns: bool = bl_cache.BLField(False, prop_ui=True) + info_columns: InfoDisplayCol = bl_cache.BLField( + {InfoDisplayCol.MathType, InfoDisplayCol.Unit}, + prop_ui=True, + enum_many=True, + ) + #################### # - Computed: Raw Expressions #################### @@ -142,35 +141,63 @@ class ExprBLSocket(base.MaxwellSimSocket): # - Computed: Units #################### def search_units(self, _: bpy.types.Context) -> list[ct.BLEnumElement]: - if self.unit_dim is not None: + if self.physical_type is not None: return [ (sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i) - for i, unit in enumerate(spux.unit_dim_units(self.unit_dim)) + for i, unit in enumerate(self.physical_type.valid_units) ] return [] - @property + @bl_cache.cached_bl_property() def unit(self) -> spux.Unit | None: - if self.active_unit != 'NONE': + """Gets the current active unit. + + Returns: + The current active `sympy` unit. + + If the socket expression is unitless, this returns `None`. + """ + if self.active_unit is not None: return spux.unit_str_to_unit(self.active_unit) return None @unit.setter def unit(self, unit: spux.Unit) -> None: - valid_units = spux.unit_dim_units(self.unit_dim) - if unit in valid_units: + """Set the unit, without touching the `raw_*` UI properties. + + Notes: + To set a new unit, **and** convert the `raw_*` UI properties to the new unit, use `self.convert_unit()` instead. + """ + if unit in self.physical_type.valid_units: self.active_unit = sp.sstr(unit) - msg = f'Tried to set invalid unit {unit} (unit dim "{self.unit_dim}" only supports "{valid_units}")' + msg = f'Tried to set invalid unit {unit} (physical type "{self.physical_type}" only supports "{self.physical_type.valid_units}")' raise ValueError(msg) + def convert_unit(self, unit_to: spux.Unit) -> None: + if self.active_kind == ct.FlowKind.Value: + current_value = self.value + self.unit = unit_to + self.value = current_value + elif self.active_kind == ct.FlowKind.LazyArrayRange: + current_lazy_array_range = self.lazy_array_range + self.unit = unit_to + self.lazy_array_range = current_lazy_array_range + + #################### + # - Property Callback + #################### + def on_socket_prop_changed(self, prop_name: str) -> None: + if prop_name == 'unit' and self.active_unit is not None: + self.convert_unit(spux.unit_str_to_unit(self.active_unit)) + #################### # - Methods #################### def _parse_expr_info( self, expr: spux.SympyExpr - ) -> tuple[spux.MathType, typ.Literal[None, 2, 3], spux.UnitDimension]: + ) -> tuple[spux.MathType, tuple[int, ...] | None, spux.UnitDimension]: # Parse MathType mathtype = spux.MathType.from_expr(expr) if self.mathtype != mathtype: @@ -188,18 +215,12 @@ class ExprBLSocket(base.MaxwellSimSocket): raise ValueError(msg) # Parse Dimensions - size = spux.parse_size(expr) - if size != self.size: - msg = f'Expr {expr} has {size} dimensions, which is incompatible with the expr socket ({self.size} dimensions)' + shape = spux.parse_shape(expr) + if shape != self.shape: + msg = f'Expr {expr} has shape {shape}, which is incompatible with the expr socket (shape {self.shape})' raise ValueError(msg) - # Parse Unit Dimension - unit_dim = spux.parse_unit_dim(expr) - if unit_dim != self.unit_dim: - msg = f'Expr {expr} has unit dimension {unit_dim}, which is incompatible with socket unit dimension {self.unit_dim}' - raise ValueError(msg) - - return mathtype, size, unit_dim + return mathtype, shape def _to_raw_value(self, expr: spux.SympyExpr): if self.unit is not None: @@ -212,7 +233,7 @@ class ExprBLSocket(base.MaxwellSimSocket): locals={sym.name: sym for sym in self.symbols}, strict=False, convert_xor=True, - ).subs(spux.ALL_UNIT_SYMBOLS) * (self.unit if self.unit is not None else 1) + ).subs(spux.UNIT_BY_SYMBOL) * (self.unit if self.unit is not None else 1) # Try Parsing and Returning the Expression try: @@ -245,7 +266,7 @@ class ExprBLSocket(base.MaxwellSimSocket): Return: The expression defined by the socket, in the socket's unit. """ - if self.symbols: + if self.symbols or self.shape not in [None, (2,), (3,)]: expr = self.raw_value_sp if expr is None: return ct.FlowSignal.FlowPending @@ -266,7 +287,7 @@ class ExprBLSocket(base.MaxwellSimSocket): self.raw_value_complex[0] + sp.I * self.raw_value_complex[1] ), }, - 2: { + (2,): { MT_Z: lambda: sp.Matrix([Z(i) for i in self.raw_value_int2]), MT_Q: lambda: sp.Matrix([Q(q[0], q[1]) for q in self.raw_value_rat2]), MT_R: lambda: sp.Matrix([R(r) for r in self.raw_value_float2]), @@ -274,7 +295,7 @@ class ExprBLSocket(base.MaxwellSimSocket): [c[0] + sp.I * c[1] for c in self.raw_value_complex2] ), }, - 3: { + (3,): { MT_Z: lambda: sp.Matrix([Z(i) for i in self.raw_value_int3]), MT_Q: lambda: sp.Matrix([Q(q[0], q[1]) for q in self.raw_value_rat3]), MT_R: lambda: sp.Matrix([R(r) for r in self.raw_value_float3]), @@ -282,7 +303,7 @@ class ExprBLSocket(base.MaxwellSimSocket): [c[0] + sp.I * c[1] for c in self.raw_value_complex3] ), }, - }[self.size][self.mathtype]() * (self.unit if self.unit is not None else 1) + }[self.shape][self.mathtype]() * (self.unit if self.unit is not None else 1) @value.setter def value(self, expr: spux.SympyExpr) -> None: @@ -291,8 +312,8 @@ class ExprBLSocket(base.MaxwellSimSocket): Notes: Called to set the internal `FlowKind.Value` of this socket. """ - mathtype, size, unit_dim = self._parse_expr_info(expr) - if self.symbols: + mathtype, shape = self._parse_expr_info(expr) + if self.symbols or self.shape not in [None, (2,), (3,)]: self.raw_value_spstr = sp.sstr(expr) else: @@ -300,7 +321,7 @@ class ExprBLSocket(base.MaxwellSimSocket): MT_Q = spux.MathType.Rational MT_R = spux.MathType.Real MT_C = spux.MathType.Complex - if size is None: + if shape is None: if mathtype == MT_Z: self.raw_value_int = self._to_raw_value(expr) elif mathtype == MT_Q: @@ -309,7 +330,7 @@ class ExprBLSocket(base.MaxwellSimSocket): self.raw_value_float = self._to_raw_value(expr) elif mathtype == MT_C: self.raw_value_complex = self._to_raw_value(expr) - elif size == 2: + elif shape == (2,): if mathtype == MT_Z: self.raw_value_int2 = self._to_raw_value(expr) elif mathtype == MT_Q: @@ -318,7 +339,7 @@ class ExprBLSocket(base.MaxwellSimSocket): self.raw_value_float2 = self._to_raw_value(expr) elif mathtype == MT_C: self.raw_value_complex2 = self._to_raw_value(expr) - elif size == 3: + elif shape == (3,): if mathtype == MT_Z: self.raw_value_int3 = self._to_raw_value(expr) elif mathtype == MT_Q: @@ -341,9 +362,35 @@ class ExprBLSocket(base.MaxwellSimSocket): Return: The range of lengths, which uses no symbols. """ + if self.symbols: + return ct.LazyArrayRangeFlow( + start=self.raw_min_sp, + stop=self.raw_max_sp, + steps=self.steps, + scaling='lin', + unit=self.unit, + symbols=self.symbols, + ) + + MT_Z = spux.MathType.Integer + MT_Q = spux.MathType.Rational + MT_R = spux.MathType.Real + MT_C = spux.MathType.Complex + Z = sp.Integer + Q = sp.Rational + R = sp.RealNumber + min_bound, max_bound = { + MT_Z: lambda: [Z(bound) for bound in self.raw_range_int], + MT_Q: lambda: [Q(bound[0], bound[1]) for bound in self.raw_range_rat], + MT_R: lambda: [R(bound) for bound in self.raw_range_float], + MT_C: lambda: [ + bound[0] + sp.I * bound[1] for bound in self.raw_range_complex + ], + }[self.mathtype]() + return ct.LazyArrayRangeFlow( - start=sp.S(self.min_value) * self.unit, - stop=sp.S(self.max_value) * self.unit, + start=min_bound, + stop=max_bound, steps=self.steps, scaling='lin', unit=self.unit, @@ -356,25 +403,74 @@ class ExprBLSocket(base.MaxwellSimSocket): Notes: Called to compute the internal `FlowKind.LazyArrayRange` of this socket. """ - self.min_value = spux.sympy_to_python( - spux.scale_to_unit(value.start * value.unit, self.unit) - ) - self.max_value = spux.sympy_to_python( - spux.scale_to_unit(value.stop * value.unit, self.unit) - ) self.steps = value.steps + self.unit = value.unit + + if self.symbols: + self.raw_min_spstr = sp.sstr(value.start) + self.raw_max_spstr = sp.sstr(value.stop) + + else: + MT_Z = spux.MathType.Integer + MT_Q = spux.MathType.Rational + MT_R = spux.MathType.Real + MT_C = spux.MathType.Complex + + if value.mathtype == MT_Z: + self.raw_range_int = [ + self._to_raw_value(bound) for bound in [value.start, value.stop] + ] + elif value.mathtype == MT_Q: + self.raw_range_rat = [ + self._to_raw_value(bound) for bound in [value.start, value.stop] + ] + elif value.mathtype == MT_R: + self.raw_range_float = [ + self._to_raw_value(bound) for bound in [value.start, value.stop] + ] + elif value.mathtype == MT_C: + self.raw_range_complex = [ + self._to_raw_value(bound) for bound in [value.start, value.stop] + ] #################### - # - FlowKind: LazyValueFunc + # - FlowKind: LazyValueFunc (w/Params if Constant) #################### @property def lazy_value_func(self) -> ct.LazyValueFuncFlow: + # Lazy Value: Arbitrary Expression + if self.symbols or self.shape not in [None, (2,), (3,)]: + return ct.LazyValueFuncFlow( + func=sp.lambdify(self.symbols, self.value, 'jax'), + func_args=[spux.MathType.from_expr(sym) for sym in self.symbols], + supports_jax=True, + ) + + # Lazy Value: Constant + ## -> A very simple function, which takes a single argument. + ## -> What will be passed is a unit-scaled/stripped, pytype-converted Expr:Value. + ## -> Until then, the user can utilize this LVF in a function composition chain. return ct.LazyValueFuncFlow( - func=sp.lambdify(self.symbols, self.value, 'jax'), - func_args=[spux.sympy_to_python_type(sym) for sym in self.symbols], + func=lambda v: v, + func_args=[ + self.physical_type if self.physical_type is not None else self.mathtype + ], supports_jax=True, ) + @property + def params(self) -> ct.ParamsFlow: + # Params Value: Symbolic + ## -> The Expr socket does not declare actual values for the symbols. + ## -> Those values must come from elsewhere. + ## -> If someone tries to load them anyway, tell them 'NoFlow'. + if self.symbols or self.shape not in [None, (2,), (3,)]: + return ct.FlowSignal.NoFlow + + # Params Value: Constant + ## -> Simply pass the Expr:Value as parameter. + return ct.ParamsFlow(func_args=[self.value]) + #################### # - FlowKind: Array #################### @@ -396,7 +492,7 @@ class ExprBLSocket(base.MaxwellSimSocket): def info(self) -> ct.ArrayFlow: return ct.InfoFlow( output_name='_', ## TODO: Something else - output_shape=(self.size,) if self.size is not None else None, + output_shape=self.shape, output_mathtype=self.mathtype, output_unit=self.unit, ) @@ -416,6 +512,16 @@ class ExprBLSocket(base.MaxwellSimSocket): #################### # - UI #################### + def draw_label_row(self, row: bpy.types.UILayout, text) -> None: + if self.active_unit is not None: + split = row.split(factor=0.6, align=True) + + _row = split.row(align=True) + _row.label(text=text) + + _col = split.column(align=True) + _col.prop(self, 'active_unit', text='') + def draw_value(self, col: bpy.types.UILayout) -> None: # Property Interface if self.symbols: @@ -426,7 +532,7 @@ class ExprBLSocket(base.MaxwellSimSocket): MT_Q = spux.MathType.Rational MT_R = spux.MathType.Real MT_C = spux.MathType.Complex - if self.size is None: + if self.shape is None: if self.mathtype == MT_Z: col.prop(self, self.blfields['raw_value_int'], text='') elif self.mathtype == MT_Q: @@ -435,7 +541,7 @@ class ExprBLSocket(base.MaxwellSimSocket): col.prop(self, self.blfields['raw_value_float'], text='') elif self.mathtype == MT_C: col.prop(self, self.blfields['raw_value_complex'], text='') - elif self.size == 2: + elif self.shape == (2,): if self.mathtype == MT_Z: col.prop(self, self.blfields['raw_value_int2'], text='') elif self.mathtype == MT_Q: @@ -444,7 +550,7 @@ class ExprBLSocket(base.MaxwellSimSocket): col.prop(self, self.blfields['raw_value_float2'], text='') elif self.mathtype == MT_C: col.prop(self, self.blfields['raw_value_complex2'], text='') - elif self.size == 3: + elif self.shape == (3,): if self.mathtype == MT_Z: col.prop(self, self.blfields['raw_value_int3'], text='') elif self.mathtype == MT_Q: @@ -579,34 +685,57 @@ class ExprSocketDef(base.SocketDef): ct.FlowKind.Value ) - # Properties - size: typ.Literal[None, 2, 3] = None + # Socket Interface + ## TODO: __hash__ like socket method based on these? + shape: tuple[int, ...] | None = None mathtype: spux.MathType = spux.MathType.Real + physical_type: spux.PhysicalType | None = None symbols: frozenset[spux.Symbol] = frozenset() - ## Units - unit_dim: spux.UnitDimension | None = None - ## Info Display - show_info_columns: bool = False + + # Socket Units + default_unit: spux.Unit | None = None + + # FlowKind: Value + default_value: spux.SympyExpr = sp.S(0) + + # FlowKind: LazyArrayRange + default_min: spux.SympyExpr = sp.S(0) + default_max: spux.SympyExpr = sp.S(1) + default_steps: int = 2 + ## TODO: Configure lin/log/... scaling (w/enumprop in UI) ## TODO: Buncha validation :) - # Defaults - default_unit: spux.Unit | None = None - default_value: spux.SympyExpr = sp.S(1) - default_min: spux.SympyExpr = sp.S(0) - default_max: spux.SympyExpr = sp.S(1) - default_steps: spux.SympyExpr = sp.S(2) + # UI + show_info_columns: bool = False def init(self, bl_socket: ExprBLSocket) -> None: bl_socket.active_kind = self.active_kind - bl_socket.size = self.size - bl_socket.mathtype = self.size - bl_socket.symbols = self.symbols - bl_socket.unit_dim = self.size - bl_socket.unit = self.symbols - bl_socket.show_info_columns = self.show_info_columns - bl_socket.value = self.default + # Socket Interface + bl_socket.shape = self.shape + bl_socket.mathtype = self.mathtype + bl_socket.physical_type = self.physical_type + bl_socket.symbols = self.symbols + + # Socket Units + if self.default_unit is not None: + bl_socket.unit = self.default_unit + + # FlowKind: Value + bl_socket.value = self.default_value + + # FlowKind: LazyArrayRange + bl_socket.lazy_array_range = ct.LazyArrayRangeFlow( + start=self.default_min, + stop=self.default_max, + steps=self.default_steps, + scaling='lin', + unit=self.default_unit, + ) + + # UI + bl_socket.show_info_columns = self.show_info_columns #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/__init__.py index 42e3a25..b65cb63 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/__init__.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/__init__.py @@ -1,49 +1,51 @@ -from . import bound_cond, bound_conds +from . import ( + bound_cond, + bound_conds, + fdtd_sim, + fdtd_sim_data, + medium, + medium_non_linearity, + monitor, + monitor_data, + sim_domain, + sim_grid, + sim_grid_axis, + source, + structure, + temporal_shape, +) MaxwellBoundCondSocketDef = bound_cond.MaxwellBoundCondSocketDef MaxwellBoundCondsSocketDef = bound_conds.MaxwellBoundCondsSocketDef - -from . import medium, medium_non_linearity - +MaxwellFDTDSimSocketDef = fdtd_sim.MaxwellFDTDSimSocketDef +MaxwellFDTDSimDataSocketDef = fdtd_sim_data.MaxwellFDTDSimDataSocketDef MaxwellMediumSocketDef = medium.MaxwellMediumSocketDef MaxwellMediumNonLinearitySocketDef = ( medium_non_linearity.MaxwellMediumNonLinearitySocketDef ) - -from . import source, temporal_shape - -MaxwellSourceSocketDef = source.MaxwellSourceSocketDef -MaxwellTemporalShapeSocketDef = temporal_shape.MaxwellTemporalShapeSocketDef - -from . import structure - -MaxwellStructureSocketDef = structure.MaxwellStructureSocketDef - -from . import monitor - MaxwellMonitorSocketDef = monitor.MaxwellMonitorSocketDef - -from . import fdtd_sim, fdtd_sim_data, sim_domain, sim_grid, sim_grid_axis - -MaxwellFDTDSimSocketDef = fdtd_sim.MaxwellFDTDSimSocketDef -MaxwellFDTDSimDataSocketDef = fdtd_sim_data.MaxwellFDTDSimDataSocketDef +MaxwellMonitorDataSocketDef = monitor_data.MaxwellMonitorDataSocketDef +MaxwellSimDomainSocketDef = sim_domain.MaxwellSimDomainSocketDef MaxwellSimGridSocketDef = sim_grid.MaxwellSimGridSocketDef MaxwellSimGridAxisSocketDef = sim_grid_axis.MaxwellSimGridAxisSocketDef -MaxwellSimDomainSocketDef = sim_domain.MaxwellSimDomainSocketDef +MaxwellSourceSocketDef = source.MaxwellSourceSocketDef +MaxwellStructureSocketDef = structure.MaxwellStructureSocketDef +MaxwellTemporalShapeSocketDef = temporal_shape.MaxwellTemporalShapeSocketDef BL_REGISTER = [ *bound_cond.BL_REGISTER, *bound_conds.BL_REGISTER, - *medium.BL_REGISTER, - *medium_non_linearity.BL_REGISTER, - *source.BL_REGISTER, - *temporal_shape.BL_REGISTER, - *structure.BL_REGISTER, - *monitor.BL_REGISTER, *fdtd_sim.BL_REGISTER, *fdtd_sim_data.BL_REGISTER, + *medium.BL_REGISTER, + *medium_non_linearity.BL_REGISTER, + *monitor.BL_REGISTER, + *monitor_data.BL_REGISTER, + *sim_domain.BL_REGISTER, *sim_grid.BL_REGISTER, *sim_grid_axis.BL_REGISTER, - *sim_domain.BL_REGISTER, + *source.BL_REGISTER, + *structure.BL_REGISTER, + *temporal_shape.BL_REGISTER, ] diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/monitor_data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/monitor_data.py new file mode 100644 index 0000000..35fd788 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/monitor_data.py @@ -0,0 +1,25 @@ +from ... import contracts as ct +from .. import base + + +class MaxwellMonitorDataBLSocket(base.MaxwellSimSocket): + socket_type = ct.SocketType.MaxwellMonitorData + bl_label = 'Maxwell Monitor Data' + + +#################### +# - Socket Configuration +#################### +class MaxwellMonitorDataSocketDef(base.SocketDef): + socket_type: ct.SocketType = ct.SocketType.MaxwellMonitorData + + def init(self, bl_socket: MaxwellMonitorDataBLSocket) -> None: + pass + + +#################### +# - Blender Registration +#################### +BL_REGISTER = [ + MaxwellMonitorDataBLSocket, +] diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/__init__.py index f7120bf..02f88e6 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/__init__.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/__init__.py @@ -1,9 +1,9 @@ -from . import pol, unit_system +from . import pol # , unit_system PhysicalPolSocketDef = pol.PhysicalPolSocketDef BL_REGISTER = [ - *unit_system.BL_REGISTER, + # *unit_system.BL_REGISTER, *pol.BL_REGISTER, ] diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py index 3f8c7c2..af99c9c 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py @@ -3,12 +3,12 @@ import sympy as sp import sympy.physics.optics.polarization as spo_pol import sympy.physics.units as spu -from blender_maxwell.utils.pydantic_sympy import SympyExpr +from blender_maxwell.utils import extra_sympy_units as spux from ... import contracts as ct from .. import base -StokesVector = SympyExpr +StokesVector = spux.SympyExpr class PhysicalPolBLSocket(base.MaxwellSimSocket): diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/unit_system.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/unit_system.py index c20c63d..662a2b4 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/unit_system.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/unit_system.py @@ -1,6 +1,6 @@ import bpy -from blender_maxwell.utils.pydantic_sympy import SympyExpr +#from blender_maxwell.utils.pydantic_sympy import SympyExpr from ... import contracts as ct from .. import base diff --git a/src/blender_maxwell/utils/__init__.py b/src/blender_maxwell/utils/__init__.py index bba10ed..32271c9 100644 --- a/src/blender_maxwell/utils/__init__.py +++ b/src/blender_maxwell/utils/__init__.py @@ -1,16 +1,22 @@ from ..nodeps.utils import blender_type_enum, pydeps from . import ( - analyze_geonodes, + bl_cache, extra_sympy_units, + image_ops, logger, - pydantic_sympy, + sci_constants, + serialize, + staticproperty, ) __all__ = [ - 'pydeps', - 'analyze_geonodes', 'blender_type_enum', + 'pydeps', + 'bl_cache', 'extra_sympy_units', + 'image_ops', 'logger', - 'pydantic_sympy', + 'sci_constants', + 'serialize', + 'staticproperty', ] diff --git a/src/blender_maxwell/utils/analyze_geonodes.py b/src/blender_maxwell/utils/analyze_geonodes.py deleted file mode 100644 index 413592b..0000000 --- a/src/blender_maxwell/utils/analyze_geonodes.py +++ /dev/null @@ -1,30 +0,0 @@ -import typing as typ - -import bpy - -INVALID_BL_SOCKET_TYPES = { - 'NodeSocketGeometry', -} - - -def interface( - geonodes: bpy.types.GeometryNodeTree, ## TODO: bpy type - direc: typ.Literal['INPUT', 'OUTPUT'], -): - """Returns 'valid' GeoNodes interface sockets. - - - The Blender socket type is not something invalid (ex. "Geometry"). - - The socket has a default value. - - The socket's direction (input/output) matches the requested direction. - """ - return { - interface_item_name: bl_interface_socket - for interface_item_name, bl_interface_socket in ( - geonodes.interface.items_tree.items() - ) - if ( - bl_interface_socket.socket_type not in INVALID_BL_SOCKET_TYPES - and hasattr(bl_interface_socket, 'default_value') - and bl_interface_socket.in_out == direc - ) - } diff --git a/src/blender_maxwell/utils/bl_cache.py b/src/blender_maxwell/utils/bl_cache.py index b541360..3464793 100644 --- a/src/blender_maxwell/utils/bl_cache.py +++ b/src/blender_maxwell/utils/bl_cache.py @@ -10,6 +10,7 @@ import uuid from pathlib import Path import bpy +import numpy as np from blender_maxwell import contracts as ct from blender_maxwell.utils import logger, serialize @@ -515,6 +516,7 @@ class BLField: use_prop_update: Configures the BLField to run `bl_instance.on_prop_changed(attr_name)` whenever value is set. This is done by setting the `update` method. enum_cb: Method used to generate new enum elements whenever `Signal.ResetEnum` is presented. + matrix_rowmajor: Blender's UI stores matrices flattened, """ log.debug( @@ -528,8 +530,8 @@ class BLField: ## Static self._prop_ui = prop_ui self._prop_flags = prop_flags - self._min = abs_min - self._max = abs_max + self._abs_min = abs_min + self._abs_max = abs_max self._soft_min = soft_min self._soft_max = soft_max self._float_step = float_step @@ -545,6 +547,12 @@ class BLField: self._str_cb = str_cb self._enum_cb = enum_cb + ## Vector/Matrix Identity + ## -> Matrix Shape assists in the workaround for Matrix Display Bug + self._is_vector = False + self._is_matrix = False + self._matrix_shape = None + ## HUGE TODO: Persist these self._str_cb_cache = {} self._enum_cb_cache = {} @@ -637,6 +645,7 @@ class BLField: ## Reusable Snippets def _add_min_max_kwargs(): + nonlocal kwargs_prop ## I've heard legends of needing this! kwargs_prop |= {'min': self._abs_min} if self._abs_min is not None else {} kwargs_prop |= {'max': self._abs_max} if self._abs_max is not None else {} kwargs_prop |= ( @@ -647,6 +656,7 @@ class BLField: ) def _add_float_kwargs(): + nonlocal kwargs_prop kwargs_prop |= ( {'step': self._float_step} if self._float_step is not None else {} ) @@ -684,6 +694,7 @@ class BLField: default_value = self._default_value BLProp = bpy.props.BoolVectorProperty kwargs_prop |= {'size': len(typ.get_args(AttrType))} + self._is_vector = True ## Vector Int elif typ.get_origin(AttrType) is tuple and all( @@ -693,6 +704,7 @@ class BLField: BLProp = bpy.props.IntVectorProperty _add_min_max_kwargs() kwargs_prop |= {'size': len(typ.get_args(AttrType))} + self._is_vector = True ## Vector Float elif typ.get_origin(AttrType) is tuple and all( @@ -703,6 +715,59 @@ class BLField: _add_min_max_kwargs() _add_float_kwargs() kwargs_prop |= {'size': len(typ.get_args(AttrType))} + self._is_vector = True + + ## Matrix Bool + elif typ.get_origin(AttrType) is tuple and all( + all(V is bool for V in typ.get_args(T)) for T in typ.get_args(AttrType) + ): + # Workaround for Matrix Display Bug + ## - Also requires __get__ support to read consistently. + rows = len(typ.get_args(AttrType)) + cols = len(typ.get_args(typ.get_args(AttrType)[0])) + default_value = ( + np.array(self._default_value, dtype=bool) + .flatten() + .reshape([cols, rows]) + ).tolist() + BLProp = bpy.props.BoolVectorProperty + kwargs_prop |= {'size': (cols, rows), 'subtype': 'MATRIX'} + ## 'size' has column-major ordering (Matrix Display Bug). + self._is_matrix = True + self._matrix_shape = (rows, cols) + + ## Matrix Int + elif typ.get_origin(AttrType) is tuple and all( + all(V is int for V in typ.get_args(T)) for T in typ.get_args(AttrType) + ): + _add_min_max_kwargs() + rows = len(typ.get_args(AttrType)) + cols = len(typ.get_args(typ.get_args(AttrType)[0])) + default_value = ( + np.array(self._default_value, dtype=int).flatten().reshape([cols, rows]) + ).tolist() + BLProp = bpy.props.IntVectorProperty + kwargs_prop |= {'size': (cols, rows), 'subtype': 'MATRIX'} + self._is_matrix = True + self._matrix_shape = (rows, cols) + + ## Matrix Float + elif typ.get_origin(AttrType) is tuple and all( + all(V is float for V in typ.get_args(T)) for T in typ.get_args(AttrType) + ): + _add_min_max_kwargs() + _add_float_kwargs() + rows = len(typ.get_args(AttrType)) + cols = len(typ.get_args(typ.get_args(AttrType)[0])) + default_value = ( + np.array(self._default_value, dtype=float) + .flatten() + .reshape([cols, rows]) + ).tolist() + BLProp = bpy.props.FloatVectorProperty + kwargs_prop |= {'size': (cols, rows), 'subtype': 'MATRIX'} + self._is_matrix = True + self._matrix_shape = (rows, cols) ## Generic String elif AttrType is str: @@ -732,7 +797,7 @@ class BLField: } ## StrEnum - elif issubclass(AttrType, enum.StrEnum): + elif inspect.isclass(AttrType) and issubclass(AttrType, enum.StrEnum): default_value = self._default_value BLProp = bpy.props.EnumProperty kwargs_prop |= { @@ -792,6 +857,7 @@ class BLField: ) ## TODO: Mine description from owner class __doc__ # Define Property Getter + ## Serialized properties need to deserialize in the getter. if prop_is_serialized: def getter(_self: BLInstance) -> AttrType: @@ -802,6 +868,7 @@ class BLField: return getattr(_self, bl_attr_name) # Define Property Setter + ## Serialized properties need to serialize in the setter. if prop_is_serialized: def setter(_self: BLInstance, value: AttrType) -> None: @@ -821,7 +888,40 @@ class BLField: def __get__( self, bl_instance: BLInstance | None, owner: type[BLInstance] ) -> typ.Any: - return self._cached_bl_property.__get__(bl_instance, owner) + value = self._cached_bl_property.__get__(bl_instance, owner) + + # enum.Enum: Cast Auto-Injected Dynamic Enum 'NONE' -> None + ## As far a Blender is concerned, dynamic enum props can't be empty. + ## -> Well, they can... But bad things happen. So they can't. + ## So in the interest of the user's sanity, we always ensure one entry. + ## -> This one entry always has the one, same, id: 'NONE'. + ## Of course, we often want to check for this "there was nothing" case. + ## -> Aka, we want to do a `None` check, semantically speaking. + ## -> ...But because it's a special thingy, we must check 'NONE'? + ## Nonsense. Let the user just check `None`, as Guido intended. + if self._enum_cb is not None and value == 'NONE': + ## TODO: Perhaps check if the unsafe callback was actually []. + ## -> In case the user themselves want to return 'NONE'. + ## -> Why would they do this? Because they are users! + return None + + # Sized Vectors/Matrices + ## Why not just yeet back a np.array? + ## -> Type-annotating a shaped numpy array is... "rough". + ## -> Type-annotation tuple[] of known shape is super easy. + ## -> Even list[] won't do; its size varies, after all! + ## -> Reject modernity. Return to tuple[]. + if self._is_vector: + ## -> tuple()ify the np.array to respect tuple[] type annotation. + return tuple(np.array(value)) + + if self._is_matrix: + # Matrix Display Bug: Correctly Read Row-Major Values w/Reshape + return tuple( + map(tuple, np.array(value).flatten().reshape(self._matrix_shape)) + ) + + return value def __set__(self, bl_instance: BLInstance | None, value: typ.Any) -> None: if value == Signal.ResetEnumItems: diff --git a/src/blender_maxwell/utils/extra_sympy_units.py b/src/blender_maxwell/utils/extra_sympy_units.py index a127f80..5af6ab7 100644 --- a/src/blender_maxwell/utils/extra_sympy_units.py +++ b/src/blender_maxwell/utils/extra_sympy_units.py @@ -1,7 +1,7 @@ """Declares useful sympy units and functions, to make it easier to work with `sympy` as the basis for a unit-aware system. Attributes: - ALL_UNIT_SYMBOLS: Maps all abbreviated Sympy symbols to their corresponding Sympy unit. + UNIT_BY_SYMBOL: Maps all abbreviated Sympy symbols to their corresponding Sympy unit. This is essential for parsing string expressions that use units, since a pure parse of ex. `a*m + m` would not otherwise be able to differentiate between `sp.Symbol(m)` and `spu.meter`. SympyType: A simple union of valid `sympy` types, used to check whether arbitrary objects should be handled using `sympy` functions. For simple `isinstance` checks, this should be preferred, as it is most performant. @@ -11,20 +11,35 @@ Attributes: """ import enum -import itertools +import functools import typing as typ +from fractions import Fraction +import jax +import jax.numpy as jnp import pydantic as pyd import sympy as sp import sympy.physics.units as spu import typing_extensions as typx from pydantic_core import core_schema as pyd_core_schema -SympyType = sp.Basic | sp.Expr | sp.MatrixBase | sp.MutableDenseMatrix | spu.Quantity +from blender_maxwell import contracts as ct + +SympyType = ( + sp.Basic + | sp.Expr + | sp.MatrixBase + | sp.MutableDenseMatrix + | spu.Quantity + | spu.Dimension +) +#################### +# - Math Type +#################### class MathType(enum.StrEnum): - """Set identities encompassing common mathematical objects.""" + """Type identifiers that encompass common sets of mathematical objects.""" Bool = enum.auto() Integer = enum.auto() @@ -46,6 +61,7 @@ class MathType(enum.StrEnum): @staticmethod def from_expr(sp_obj: SympyType) -> type: + ## TODO: Support for sp.Matrix if isinstance(sp_obj, sp.logic.boolalg.Boolean): return MathType.Bool if sp_obj.is_integer: @@ -57,7 +73,7 @@ class MathType(enum.StrEnum): if sp_obj.is_complex: return MathType.Complex - msg = "Can't determine MathType from sympy object: {sp_obj}" + msg = f"Can't determine MathType from sympy object: {sp_obj}" raise ValueError(msg) @staticmethod @@ -67,24 +83,18 @@ class MathType(enum.StrEnum): int: MathType.Integer, float: MathType.Real, complex: MathType.Complex, - # jnp.int32: MathType.Integer, - # jnp.int64: MathType.Integer, - # jnp.float32: MathType.Real, - # jnp.float64: MathType.Real, - # jnp.complex64: MathType.Complex, - # jnp.complex128: MathType.Complex, - # jnp.bool_: MathType.Bool, }[dtype] - @staticmethod - def to_dtype(value: typ.Self) -> type: + @property + def pytype(self) -> type: + MT = MathType return { - MathType.Bool: bool, - MathType.Integer: int, - MathType.Rational: float, - MathType.Real: float, - MathType.Complex: complex, - }[value] + MT.Bool: bool, + MT.Integer: int, + MT.Rational: float, + MT.Real: float, + MT.Complex: complex, + }[self] @staticmethod def to_str(value: typ.Self) -> type: @@ -96,6 +106,116 @@ class MathType(enum.StrEnum): MathType.Complex: 'ℂ', }[value] + @staticmethod + def to_name(value: typ.Self) -> str: + return MathType.to_str(value) + + @staticmethod + def to_icon(value: typ.Self) -> str: + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + return ( + str(self), + MathType.to_name(self), + MathType.to_name(self), + MathType.to_icon(self), + i, + ) + + +class NumberSize1D(enum.StrEnum): + """Valid 1D-constrained shape.""" + + Scalar = enum.auto() + Vec2 = enum.auto() + Vec3 = enum.auto() + Vec4 = enum.auto() + + @staticmethod + def to_name(value: typ.Self) -> str: + NS = NumberSize1D + return { + NS.Scalar: 'Scalar', + NS.Vec2: '2D', + NS.Vec3: '3D', + NS.Vec4: '4D', + }[value] + + @staticmethod + def to_icon(value: typ.Self) -> str: + NS = NumberSize1D + return { + NS.Scalar: '', + NS.Vec2: '', + NS.Vec3: '', + NS.Vec4: '', + }[value] + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + return ( + str(self), + NumberSize1D.to_name(self), + NumberSize1D.to_name(self), + NumberSize1D.to_icon(self), + i, + ) + + @staticmethod + def supports_shape(shape: tuple[int, ...] | None): + return shape is None or (len(shape) == 1 and shape[0] in [2, 3]) + + @staticmethod + def from_shape(shape: tuple[typ.Literal[2, 3]] | None) -> typ.Self: + NS = NumberSize1D + return { + None: NS.Scalar, + (2,): NS.Vec2, + (3,): NS.Vec3, + (4,): NS.Vec3, + }[shape] + + @property + def shape(self): + NS = NumberSize1D + return { + NS.Scalar: None, + NS.Vec2: (2,), + NS.Vec3: (3,), + NS.Vec3: (4,), + }[self] + + +#################### +# - Unit Dimensions +#################### +class DimsMeta(type): + def __getattr__(cls, attr: str) -> spu.Dimension: + if ( + attr in spu.definitions.dimension_definitions.__dir__() + and not attr.startswith('__') + ): + return getattr(spu.definitions.dimension_definitions, attr) + + raise AttributeError(name=attr, obj=Dims) + + +class Dims(metaclass=DimsMeta): + """Access `sympy.physics.units` dimensions with less hassle. + + Any unit dimension available in `sympy.physics.units.definitions.dimension_definitions` can be accessed as an attribute of `Dims`. + + An `AttributeError` is raised if the unit cannot be found in `sympy`. + + Examples: + The objects returned are a direct alias to `sympy`, with less hassle: + ```python + assert Dims.length == ( + sympy.physics.units.definitions.dimension_definitions.length + ) + ``` + """ + #################### # - Units @@ -140,34 +260,16 @@ petahertz.set_global_relative_scale_factor(spu.peta, spu.hertz) exahertz = EHz = spu.Quantity('exahertz', abbrev='EHz') exahertz.set_global_relative_scale_factor(spu.exa, spu.hertz) +# Pressure +millibar = mbar = spu.Quantity('millibar', abbrev='mbar') +millibar.set_global_relative_scale_factor(spu.milli, spu.bar) -#################### -# - Sympy Printer -#################### -_SYMPY_EXPR_PRINTER_STR = sp.printing.str.StrPrinter( - settings={ - 'abbrev': True, - } -) +hectopascal = hPa = spu.Quantity('hectopascal', abbrev='hPa') # noqa: N816 +hectopascal.set_global_relative_scale_factor(spu.hecto, spu.pascal) - -def sp_to_str(sp_obj: SympyType) -> str: - """Converts a sympy object to an output-oriented string (w/abbreviated units), using a dedicated StrPrinter. - - This should be used whenever a **string for UI use** is needed from a `sympy` object. - - Notes: - This should **NOT** be used in cases where the string will be `sp.sympify()`ed back into a sympy expression. - For such cases, rely on `sp.srepr()`, which uses an _explicit_ representation. - - Parameters: - sp_obj: The `sympy` object to convert to a string. - - Returns: - A string representing the expression for human use. - _The string is not re-encodable to the expression._ - """ - return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj) +UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = { + unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity) +} | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)} #################### @@ -175,7 +277,7 @@ def sp_to_str(sp_obj: SympyType) -> str: #################### ## TODO: Caching w/srepr'ed expression. ## TODO: An LFU cache could do better than an LRU. -def uses_units(expr: sp.Expr) -> bool: +def uses_units(sp_obj: SympyType) -> bool: """Determines if an expression uses any units. Notes: @@ -190,9 +292,10 @@ def uses_units(expr: sp.Expr) -> bool: Returns: Whether or not there are units used within the expression. """ - return any( - isinstance(subexpr, spu.Quantity) for subexpr in sp.postorder_traversal(expr) - ) + return sp_obj.has(spu.Quantity) + # return any( + # isinstance(subexpr, spu.Quantity) for subexpr in sp.postorder_traversal(sp_obj) + # ) ## TODO: Caching w/srepr'ed expression. @@ -222,127 +325,11 @@ def get_units(expr: sp.Expr) -> set[spu.Quantity]: } -#################### -# - Sympy Expression Typing -#################### -ALL_UNIT_SYMBOLS: dict[sp.Symbol, spu.Quantity] = { - unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity) -} | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)} +def parse_shape(sp_obj: SympyType) -> int | None: + if isinstance(sp_obj, sp.MatrixBase): + return sp_obj.shape - -#################### -# - Units <-> Scalars -#################### -def scale_to_unit(expr: sp.Expr, unit: spu.Quantity) -> sp.Expr: - """Convert an expression that uses units to a different unit, then strip all units. - - This is used whenever the unitless part of an expression is needed, but guaranteed expressed in a particular unit, aka. **unit system normalization**. - - Notes: - The unitless output is still an `sp.Expr`, which may contain ex. symbols. - - If you know that the output **should** work as a corresponding Python type (ex. `sp.Integer` vs. `int`), but it doesn't, you can use `sympy_to_python()` to produce a pure-Python type. - In this way, with a little care, broad compatiblity can be bridged between the `sympy.physics.units` unit system and the wider Python ecosystem. - - Parameters: - expr: The unit-containing expression to convert. - unit_to: The unit that is converted to. - - Returns: - The unitless part of `expr`, after scaling the entire expression to `unit`. - - Raises: - ValueError: If the result of unit-conversion and -stripping still has units, as determined by `uses_units()`. - """ - ## TODO: An LFU cache could do better than an LRU. - unitless_expr = spu.convert_to(expr, unit) / unit - if not uses_units(unitless_expr): - return unitless_expr - - msg = f'Expression "{expr}" was scaled to the unit "{unit}" with the expectation that the result would be unitless, but the result "{unitless_expr}" has units "{get_units(unitless_expr)}"' - raise ValueError(msg) - - -def scaling_factor(unit_from: spu.Quantity, unit_to: spu.Quantity) -> sp.Number: - """Compute the numerical scaling factor imposed on the unitless part of the expression when converting from one unit to another. - - Parameters: - unit_from: The unit that is converted from. - unit_to: The unit that is converted to. - - Returns: - The numerical scaling factor between the two units. - - Raises: - ValueError: If the two units don't share a common dimension. - """ - if unit_from.dimension == unit_to.dimension: - return scale_to_unit(unit_from, unit_to) - - msg = f"Dimension of unit_from={unit_from} ({unit_from.dimension}) doesn't match the dimension of unit_to={unit_to} ({unit_to.dimension}); therefore, there is no scaling factor between them" - raise ValueError(msg) - - -#################### -# - Sympy -> Python -#################### -## TODO: Integrate SympyExpr for constraining to the output types. -def sympy_to_python_type(sym: sp.Symbol) -> type: - """Retrieve the Python type that is implied by a scalar `sympy` symbol. - - Arguments: - sym: A scalar sympy symbol. - - Returns: - A pure Python type. - """ - if sym.is_integer: - return int - if sym.is_rational or sym.is_real: - return float - if sym.is_complex: - return complex - - msg = f'Cannot find Python type for sympy symbol "{sym}". Check the assumptions on the expr (current expr assumptions: "{sym._assumptions}")' # noqa: SLF001 - raise ValueError(msg) - - -def sympy_to_python(scalar: sp.Basic) -> int | float | complex | tuple | list: - """Convert a scalar sympy expression to the directly corresponding Python type. - - Arguments: - scalar: A sympy expression that has no symbols, but is expressed as a Sympy type. - For expressions that are equivalent to a scalar (ex. "(2a + a)/a"), you must simplify the expression with ex. `sp.simplify()` before passing to this parameter. - - Returns: - A pure Python type that directly corresponds to the input scalar expression. - """ - if isinstance(scalar, sp.MatrixBase): - list_2d = [[sympy_to_python(el) for el in row] for row in scalar.tolist()] - - # Detect Row / Column Vector - ## When it's "actually" a 1D structure, flatten and return as tuple. - if 1 in scalar.shape: - return tuple(itertools.chain.from_iterable(list_2d)) - - return list_2d - if scalar.is_integer: - return int(scalar) - if scalar.is_rational or scalar.is_real: - return float(scalar) - if scalar.is_complex: - return complex(scalar) - - msg = f'Cannot convert sympy scalar expression "{scalar}" to a Python type. Check the assumptions on the expr (current expr assumptions: "{scalar._assumptions}")' # noqa: SLF001 - raise ValueError(msg) - - -def pretty_symbol(sym: sp.Symbol) -> str: - return f'{sym.name} ∈ ' + ( - 'ℂ' - if sym.is_complex - else ('ℝ' if sym.is_real else ('ℤ' if sym.is_integer else '?')) - ) + return None #################### @@ -399,32 +386,42 @@ class _SympyExpr: raise ValueError(msg) from ex # Substitute Symbol -> Quantity - return expr.subs(ALL_UNIT_SYMBOLS) + return expr.subs(UNIT_BY_SYMBOL) - # def validate_from_expr(sp_obj: SympyType) -> SympyType: - # """Validate that a `sympy` object is a `SympyType`. + def validate_from_pytype( + sp_pytype: int | Fraction | float | complex, + ) -> SympyType | typ.Any: + """Parse and validate a pure Python type. - # In the static sense, this is a dummy function. + Parameters: + sp_str: A stringified `sympy` object, that will be parsed to a sympy type. + Before use, `isinstance(expr_str, str)` is checked. + If the object isn't a string, then the validation will be skipped. - # Parameters: - # sp_obj: A `sympy` object. + Returns: + Either a `sympy` object, if the input is parseable, or the same untouched object. - # Returns: - # The `sympy` object. + Raises: + ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression. + """ + # Constrain to String + if not isinstance(sp_pytype, int | Fraction | float | complex): + return sp_pytype - # Raises: - # ValueError: If `sp_obj` is not a `sympy` object. - # """ - # if not (isinstance(sp_obj, SympyType)): - # msg = f'Value {sp_obj} is not a `sympy` expression' - # raise ValueError(msg) + if isinstance(sp_pytype, int): + return sp.Integer(sp_pytype) + if isinstance(sp_pytype, Fraction): + return sp.Rational(sp_pytype.numerator, sp_pytype.denominator) + if isinstance(sp_pytype, float): + return sp.Float(sp_pytype) - # return sp_obj + # sp_pytype => Complex + return sp_pytype.real + sp.I * sp_pytype.imag sympy_expr_schema = pyd_core_schema.chain_schema( [ pyd_core_schema.no_info_plain_validator_function(validate_from_str), - # pyd_core_schema.no_info_plain_validator_function(validate_from_expr), + pyd_core_schema.no_info_plain_validator_function(validate_from_pytype), pyd_core_schema.is_instance_schema(SympyType), ] ) @@ -441,6 +438,7 @@ SympyExpr = typx.Annotated[ sp.Basic, ## Treat all sympy types as sp.Basic _SympyExpr, ] +## TODO: The type game between SympyType, SympyExpr, and the various flavors of ConstrSympyExpr(), is starting to be a bit much. Let's consolidate. def ConstrSympyExpr( # noqa: N802, PLR0913 @@ -503,10 +501,10 @@ def ConstrSympyExpr( # noqa: N802, PLR0913 ) if allowed_structures and not any( { + 'scalar': True, 'matrix': isinstance(expr, sp.MatrixBase), }[allowed_set] for allowed_set in allowed_structures - if allowed_structures != 'scalar' ): msgs.add( f"allowed_structures={allowed_structures} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))" @@ -566,6 +564,12 @@ IntSymbol: typ.TypeAlias = ConstrSympyExpr( allowed_sets={'integer'}, max_symbols=1, ) +RationalSymbol: typ.TypeAlias = ConstrSympyExpr( + allow_variables=True, + allow_units=False, + allowed_sets={'integer', 'rational'}, + max_symbols=1, +) RealSymbol: typ.TypeAlias = ConstrSympyExpr( allow_variables=True, allow_units=False, @@ -581,9 +585,10 @@ ComplexSymbol: typ.TypeAlias = ConstrSympyExpr( Symbol: typ.TypeAlias = IntSymbol | RealSymbol | ComplexSymbol # Unit +UnitDimension: typ.TypeAlias = SympyExpr ## Actually spu.Dimension + ## Technically a "unit expression", which includes compound types. -## Support for this is the killer feature compared to spu.Quantity. -UnitDimension: typ.TypeAlias = spu.Dimension +## Support for this is the reason to prefer over raw spu.Quantity. Unit: typ.TypeAlias = ConstrSympyExpr( allow_variables=False, allow_units=True, @@ -634,3 +639,656 @@ Real3DVector: typ.TypeAlias = ConstrSympyExpr( allowed_structures={'matrix'}, allowed_matrix_shapes={(3, 1)}, ) + + +#################### +# - Sympy Utilities: Printing +#################### +_SYMPY_EXPR_PRINTER_STR = sp.printing.str.StrPrinter( + settings={ + 'abbrev': True, + } +) + + +def sp_to_str(sp_obj: SympyExpr) -> str: + """Converts a sympy object to an output-oriented string (w/abbreviated units), using a dedicated StrPrinter. + + This should be used whenever a **string for UI use** is needed from a `sympy` object. + + Notes: + This should **NOT** be used in cases where the string will be `sp.sympify()`ed back into a sympy expression. + For such cases, rely on `sp.srepr()`, which uses an _explicit_ representation. + + Parameters: + sp_obj: The `sympy` object to convert to a string. + + Returns: + A string representing the expression for human use. + _The string is not re-encodable to the expression._ + """ + return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj) + + +def pretty_symbol(sym: sp.Symbol) -> str: + return f'{sym.name} ∈ ' + ( + 'ℂ' + if sym.is_complex + else ('ℝ' if sym.is_real else ('ℤ' if sym.is_integer else '?')) + ) + + +#################### +# - Unit Utilities +#################### +def scale_to_unit(sp_obj: SympyType, unit: spu.Quantity) -> Number: + """Convert an expression that uses units to a different unit, then strip all units, leaving only a unitless `sympy` value. + + This is used whenever the unitless part of an expression is needed, but guaranteed expressed in a particular unit, aka. **unit system normalization**. + + Notes: + The unitless output is still an `sp.Expr`, which may contain ex. symbols. + + If you know that the output **should** work as a corresponding Python type (ex. `sp.Integer` vs. `int`), but it doesn't, you can use `sympy_to_python()` to produce a pure-Python type. + In this way, with a little care, broad compatiblity can be bridged between the `sympy.physics.units` unit system and the wider Python ecosystem. + + Parameters: + expr: The unit-containing expression to convert. + unit_to: The unit that is converted to. + + Returns: + The unitless part of `expr`, after scaling the entire expression to `unit`. + + Raises: + ValueError: If the result of unit-conversion and -stripping still has units, as determined by `uses_units()`. + """ + ## TODO: An LFU cache could do better than an LRU. + unitless_expr = spu.convert_to(sp_obj, unit) / unit + if not uses_units(unitless_expr): + return unitless_expr + + msg = f'Sympy object "{sp_obj}" was scaled to the unit "{unit}" with the expectation that the result would be unitless, but the result "{unitless_expr}" has units "{get_units(unitless_expr)}"' + raise ValueError(msg) + + +def scaling_factor(unit_from: spu.Quantity, unit_to: spu.Quantity) -> Number: + """Compute the numerical scaling factor imposed on the unitless part of the expression when converting from one unit to another. + + Parameters: + unit_from: The unit that is converted from. + unit_to: The unit that is converted to. + + Returns: + The numerical scaling factor between the two units. + + Raises: + ValueError: If the two units don't share a common dimension. + """ + if unit_from.dimension == unit_to.dimension: + return scale_to_unit(unit_from, unit_to) + + msg = f"Dimension of unit_from={unit_from} ({unit_from.dimension}) doesn't match the dimension of unit_to={unit_to} ({unit_to.dimension}); therefore, there is no scaling factor between them" + raise ValueError(msg) + + +_UNIT_STR_MAP = {sym.name: unit for sym, unit in UNIT_BY_SYMBOL.items()} + + +@functools.cache +def unit_str_to_unit(unit_str: str) -> Unit | None: + if unit_str in _UNIT_STR_MAP: + return _UNIT_STR_MAP[unit_str] + + msg = 'No valid unit for unit string {unit_str}' + raise ValueError(msg) + + +#################### +# - "Physical" Type +#################### +class PhysicalType(enum.StrEnum): + """Type identifiers for expressions with both `MathType` and a unit, aka a "physical" type.""" + + # Global + Time = enum.auto() + Angle = enum.auto() + SolidAngle = enum.auto() + ## TODO: Some kind of 3D-specific orientation ex. a quaternion + Freq = enum.auto() + AngFreq = enum.auto() ## rad*hertz + # Cartesian + Length = enum.auto() + Area = enum.auto() + Volume = enum.auto() + # Mechanical + Vel = enum.auto() + Accel = enum.auto() + Mass = enum.auto() + Force = enum.auto() + Pressure = enum.auto() + # Energy + Work = enum.auto() ## joule + Power = enum.auto() ## watt + PowerFlux = enum.auto() ## watt + Temp = enum.auto() + # Electrodynamics + Current = enum.auto() ## ampere + CurrentDensity = enum.auto() + Charge = enum.auto() ## coulomb + Voltage = enum.auto() + Capacitance = enum.auto() ## farad + Impedance = enum.auto() ## ohm + Conductance = enum.auto() ## siemens + Conductivity = enum.auto() ## siemens / length + MFlux = enum.auto() ## weber + MFluxDensity = enum.auto() ## tesla + Inductance = enum.auto() ## henry + EField = enum.auto() + HField = enum.auto() + # Luminal + LumIntensity = enum.auto() + LumFlux = enum.auto() + Luminance = enum.auto() + Illuminance = enum.auto() + # Optics + OrdinaryWaveVector = enum.auto() + AngularWaveVector = enum.auto() + PoyntingVector = enum.auto() + + @property + def unit_dim(self): + PT = PhysicalType + return { + # Global + PT.Time: Dims.time, + PT.Angle: Dims.angle, + PT.SolidAngle: Dims.steradian, ## MISSING + PT.Freq: Dims.frequency, + PT.AngFreq: Dims.angle * Dims.frequency, + # Cartesian + PT.Length: Dims.length, + PT.Area: Dims.length**2, + PT.Volume: Dims.length**3, + # Mechanical + PT.Vel: Dims.length / Dims.time, + PT.Accel: Dims.length / Dims.time**2, + PT.Mass: Dims.mass, + PT.Force: Dims.force, + PT.Pressure: Dims.pressure, + # Energy + PT.Work: Dims.energy, + PT.Power: Dims.power, + PT.PowerFlux: Dims.power / Dims.length**2, + PT.Temp: Dims.temperature, + # Electrodynamics + PT.Current: Dims.current, + PT.CurrentDensity: Dims.current / Dims.length**2, + PT.Charge: Dims.charge, + PT.Voltage: Dims.voltage, + PT.Capacitance: Dims.capacitance, + PT.Impedance: Dims.impedance, + PT.Conductance: Dims.conductance, + PT.Conductivity: Dims.conductance / Dims.length, + PT.MFlux: Dims.magnetic_flux, + PT.MFluxDensity: Dims.magnetic_density, + PT.Inductance: Dims.inductance, + PT.EField: Dims.voltage / Dims.length, + PT.HField: Dims.current / Dims.length, + # Luminal + PT.LumIntensity: Dims.luminous_intensity, + PT.LumFlux: Dims.luminous_intensity * Dims.steradian, + PT.Illuminance: Dims.luminous_intensity / Dims.length**2, + # Optics + PT.OrdinaryWaveVector: Dims.frequency, + PT.AngularWaveVector: Dims.angle * Dims.frequency, + PT.PoyntingVector: Dims.power / Dims.length**2, + } + + @property + def default_unit(self) -> list[Unit]: + PT = PhysicalType + return { + # Global + PT.Time: spu.picosecond, + PT.Angle: spu.radian, + PT.SolidAngle: spu.steradian, + PT.Freq: terahertz, + PT.AngFreq: spu.radian * terahertz, + # Cartesian + PT.Length: spu.micrometer, + PT.Area: spu.um**2, + PT.Volume: spu.um**3, + # Mechanical + PT.Vel: spu.um / spu.second, + PT.Accel: spu.um / spu.second, + PT.Mass: spu.microgram, + PT.Force: micronewton, + PT.Pressure: millibar, + # Energy + PT.Work: spu.joule, + PT.Power: spu.watt, + PT.PowerFlux: spu.watt / spu.meter**2, + PT.Temp: spu.kelvin, + # Electrodynamics + PT.Current: spu.ampere, + PT.CurrentDensity: spu.ampere / spu.meter**2, + PT.Charge: spu.coulomb, + PT.Voltage: spu.volt, + PT.Capacitance: spu.farad, + PT.Impedance: spu.ohm, + PT.Conductance: spu.siemens, + PT.Conductivity: spu.siemens / spu.micrometer, + PT.MFlux: spu.weber, + PT.MFluxDensity: spu.tesla, + PT.Inductance: spu.henry, + PT.EField: spu.volt / spu.micrometer, + PT.HField: spu.ampere / spu.micrometer, + # Luminal + PT.LumIntensity: spu.candela, + PT.LumFlux: spu.candela * spu.steradian, + PT.Illuminance: spu.candela / spu.meter**2, + # Optics + PT.OrdinaryWaveVector: terahertz, + PT.AngularWaveVector: spu.radian * terahertz, + }[self] + + @property + def valid_units(self) -> list[Unit]: + PT = PhysicalType + return { + # Global + PT.Time: [ + femtosecond, + spu.picosecond, + spu.nanosecond, + spu.microsecond, + spu.millisecond, + spu.second, + spu.minute, + spu.hour, + spu.day, + ], + PT.Angle: [ + spu.radian, + spu.degree, + ], + PT.SolidAngle: [ + spu.steradian, + ], + PT.Freq: ( + _valid_freqs := [ + spu.hertz, + kilohertz, + megahertz, + gigahertz, + terahertz, + petahertz, + exahertz, + ] + ), + PT.AngFreq: [spu.radian * _unit for _unit in _valid_freqs], + # Cartesian + PT.Length: ( + _valid_lens := [ + spu.picometer, + spu.angstrom, + spu.nanometer, + spu.micrometer, + spu.millimeter, + spu.centimeter, + spu.meter, + spu.inch, + spu.foot, + spu.yard, + spu.mile, + ] + ), + PT.Area: [_unit**2 for _unit in _valid_lens], + PT.Volume: [_unit**3 for _unit in _valid_lens], + # Mechanical + PT.Vel: [_unit / spu.second for _unit in _valid_lens], + PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens], + PT.Mass: [ + spu.electron_rest_mass, + spu.dalton, + spu.microgram, + spu.milligram, + spu.gram, + spu.kilogram, + spu.metric_ton, + ], + PT.Force: [ + spu.kg * spu.meter / spu.second**2, + nanonewton, + micronewton, + millinewton, + spu.newton, + ], + PT.Pressure: [ + millibar, + spu.bar, + spu.pascal, + hectopascal, + spu.atmosphere, + spu.psi, + spu.mmHg, + spu.torr, + ], + # Energy + PT.Work: [ + spu.electronvolt, + spu.joule, + ], + PT.Power: [ + spu.watt, + ], + PT.PowerFlux: [ + spu.watt / spu.meter**2, + ], + PT.Temp: [ + spu.kelvin, + ], + # Electrodynamics + PT.Current: [ + spu.ampere, + ], + PT.CurrentDensity: [ + spu.ampere / spu.meter**2, + ], + PT.Charge: [ + spu.coulomb, + ], + PT.Voltage: [ + spu.volt, + ], + PT.Capacitance: [ + spu.farad, + ], + PT.Impedance: [ + spu.ohm, + ], + PT.Conductance: [ + spu.siemens, + ], + PT.Conductivity: [ + spu.siemens / spu.micrometer, + spu.siemens / spu.meter, + ], + PT.MFlux: [ + spu.weber, + ], + PT.MFluxDensity: [ + spu.tesla, + ], + PT.Inductance: [ + spu.henry, + ], + PT.EField: [ + spu.volt / spu.micrometer, + spu.volt / spu.meter, + ], + PT.HField: [ + spu.ampere / spu.micrometer, + spu.ampere / spu.meter, + ], + # Luminal + PT.LumIntensity: [ + spu.candela, + ], + PT.LumFlux: [ + spu.candela * spu.steradian, + ], + PT.Illuminance: [ + spu.candela / spu.meter**2, + ], + # Optics + PT.OrdinaryWaveVector: _valid_freqs, + PT.AngularWaveVector: [spu.radian * _unit for _unit in _valid_freqs], + }[self] + + @staticmethod + def from_unit(unit: Unit) -> list[Unit]: + for physical_type in list[PhysicalType]: + if unit in physical_type.valid_units: + return physical_type + + msg = f'No PhysicalType found for unit {unit}' + raise ValueError(msg) + + @property + def valid_shapes(self): + PT = PhysicalType + overrides = { + # Cartesian + PT.Length: [None, (2,), (3,)], + # Mechanical + PT.Vel: [None, (2,), (3,)], + PT.Accel: [None, (2,), (3,)], + PT.Force: [None, (2,), (3,)], + # Energy + PT.Work: [None, (2,), (3,)], + PT.PowerFlux: [None, (2,), (3,)], + # Electrodynamics + PT.CurrentDensity: [None, (2,), (3,)], + PT.MFluxDensity: [None, (2,), (3,)], + PT.EField: [None, (2,), (3,)], + PT.HField: [None, (2,), (3,)], + # Luminal + PT.LumFlux: [None, (2,), (3,)], + # Optics + PT.OrdinaryWaveVector: [None, (2,), (3,)], + PT.AngularWaveVector: [None, (2,), (3,)], + PT.PoyntingVector: [None, (2,), (3,)], + } + + return overrides.get(self, [None]) + + @property + def valid_mathtypes(self) -> list[MathType]: + """Returns a list of valid mathematical types, especially whether it can be real- or complex-valued. + + Generally, all unit quantities are real, in the algebraic mathematical sense. + However, in electrodynamics especially, it becomes enormously useful to bake in a _rotational component_ as an imaginary value, be it simply to model phase or oscillation-oriented dampening. + This imaginary part has physical meaning, which can be expressed using the same mathematical formalism associated with unit systems. + In general, the value is a phasor. + + While it is difficult to arrive at a well-defined way of saying, "this is when a quantity is complex", an attempt has been made to form a sensible baseline based on when phasor math may apply. + + Notes: + - **Freq**/**AngFreq**: The imaginary part represents growth/dampening of the oscillation. + - **Current**/**Voltage**: The imaginary part represents the phase. + This also holds for any downstream units. + - **Charge**: Generally, it is real. + However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: + - **Conductance**: The imaginary part represents the extinction, in the Drude-model sense. + - **Poynting**: The imaginary part represents the oscillation in the power flux over time. + + """ + MT = MathType + PT = PhysicalType + overrides = { + # Cartesian + PT.Freq: [MT.Real, MT.Complex], ## Im -> Growth/Damping + PT.AngFreq: [MT.Real, MT.Complex], ## Im -> Growth/Damping + # Mechanical + # Energy + # Electrodynamics + PT.Current: [MT.Real, MT.Complex], ## Im -> Phase + PT.CurrentDensity: [MT.Real, MT.Complex], ## Im -> Phase + PT.Charge: [MT.Real, MT.Complex], ## Im -> Phase + PT.Voltage: [MT.Real, MT.Complex], ## Im -> Phase + PT.Capacitance: [MT.Real, MT.Complex], ## Im -> Phase + PT.Impedance: [MT.Real, MT.Complex], ## Im -> Reactance + PT.Inductance: [MT.Real, MT.Complex], ## Im -> Extinction + PT.Conductance: [MT.Real, MT.Complex], ## Im -> Extinction + PT.Conductivity: [MT.Real, MT.Complex], ## Im -> Extinction + PT.MFlux: [MT.Real, MT.Complex], ## Im -> Phase + PT.MFluxDensity: [MT.Real, MT.Complex], ## Im -> Phase + PT.EField: [MT.Real, MT.Complex], ## Im -> Phase + PT.HField: [MT.Real, MT.Complex], ## Im -> Phase + # Luminal + # Optics + PT.OrdinaryWaveVector: [MT.Real, MT.Complex], ## Im -> Phase + PT.AngularWaveVector: [MT.Real, MT.Complex], ## Im -> Phase + PT.PoyntingVector: [MT.Real, MT.Complex], ## Im -> Reactive Power + } + + return overrides.get(self, [MT.Real]) + + @staticmethod + def to_name(value: typ.Self) -> str: + return sp_to_str(value.unit_dim) + + @staticmethod + def to_icon(value: typ.Self) -> str: + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + PT = PhysicalType + return ( + str(self), + PT.to_name(self), + PT.to_name(self), + PT.to_icon(self), + i, + ) + + +#################### +# - Standard Unit Systems +#################### +UnitSystem: typ.TypeAlias = dict[PhysicalType, Unit] + +_PT = PhysicalType +UNITS_SI: UnitSystem = { + # Global + _PT.Time: spu.second, + _PT.Angle: spu.radian, + _PT.SolidAngle: spu.steradian, + _PT.Freq: spu.hertz, + _PT.AngFreq: spu.radian * spu.hertz, + # Cartesian + _PT.Length: spu.meter, + _PT.Area: spu.meter**2, + _PT.Volume: spu.meter**3, + # Mechanical + _PT.Vel: spu.meter / spu.second, + _PT.Accel: spu.meter / spu.second**2, + _PT.Mass: spu.kilogram, + _PT.Force: spu.newton, + # Energy + _PT.Work: spu.joule, + _PT.Power: spu.watt, + _PT.PowerFlux: spu.watt / spu.meter**2, + _PT.Temp: spu.kelvin, + # Electrodynamics + _PT.Current: spu.ampere, + _PT.CurrentDensity: spu.ampere / spu.meter**2, + _PT.Capacitance: spu.farad, + _PT.Impedance: spu.ohm, + _PT.Conductance: spu.siemens, + _PT.Conductivity: spu.siemens / spu.meter, + _PT.MFlux: spu.weber, + _PT.MFluxDensity: spu.tesla, + _PT.Inductance: spu.henry, + _PT.EField: spu.volt / spu.meter, + _PT.HField: spu.ampere / spu.meter, + # Luminal + _PT.LumIntensity: spu.candela, + _PT.LumFlux: lumen, + _PT.Illuminance: spu.lux, + # Optics + _PT.OrdinaryWaveVector: spu.hertz, + _PT.AngularWaveVector: spu.radian * spu.hertz, + _PT.PoyntingVector: spu.watt / spu.meter**2, +} + + +#################### +# - Sympy Utilities: Cast to Python +#################### +def sympy_to_python( + scalar: sp.Basic, use_jax_array: bool = False +) -> int | float | complex | tuple | jax.Array: + """Convert a scalar sympy expression to the directly corresponding Python type. + + Arguments: + scalar: A sympy expression that has no symbols, but is expressed as a Sympy type. + For expressions that are equivalent to a scalar (ex. "(2a + a)/a"), you must simplify the expression with ex. `sp.simplify()` before passing to this parameter. + + Returns: + A pure Python type that directly corresponds to the input scalar expression. + """ + if isinstance(scalar, sp.MatrixBase): + # Detect Single Column Vector + ## --> Flatten to Single Row Vector + if len(scalar.shape) == 2 and scalar.shape[1] == 1: + _scalar = scalar.T + else: + _scalar = scalar + + # Convert to Tuple of Tuples + matrix = tuple( + [tuple([sympy_to_python(el) for el in row]) for row in _scalar.tolist()] + ) + + # Detect Single Row Vector + ## --> This could be because the scalar had it. + ## --> This could also be because we flattened a column vector. + ## Either way, we should strip the pointless dimensions. + if len(matrix) == 1: + return matrix[0] if not use_jax_array else jnp.array(matrix[0]) + + return matrix if not use_jax_array else jnp.array(matrix) + if scalar.is_integer: + return int(scalar) + if scalar.is_rational or scalar.is_real: + return float(scalar) + if scalar.is_complex: + return complex(scalar) + + msg = f'Cannot convert sympy scalar expression "{scalar}" to a Python type. Check the assumptions on the expr (current expr assumptions: "{scalar._assumptions}")' # noqa: SLF001 + raise ValueError(msg) + + +#################### +# - Convert to Unit System +#################### +def _flat_unit_system_units(unit_system: UnitSystem) -> SympyExpr: + return list(unit_system.values()) + + +def convert_to_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr: + """Convert an expression to the units of a given unit system, with appropriate scaling.""" + return spu.convert_to(sp_obj, _flat_unit_system_units(unit_system)) + + +def scale_to_unit_system( + sp_obj: SympyExpr, unit_system: UnitSystem, use_jax_array: bool = False +) -> int | float | complex | tuple | jax.Array: + """Convert an expression to the units of a given unit system, then strip all units of the unit system. + + Unit stripping is "dumb": Substitute any `sympy` object in `unit_system.values()` with `1`. + Afterwards, it is converted to an appropriate Python type. + + Notes: + For stability, and performance, reasons, this should only be used at the very last stage. + + Regarding performance: **This is not a fast function**. + + Parameters: + sp_obj: An arbitrary sympy object, presumably with units. + unit_system: A unit system mapping `PhysicalType` to particular choices of (compound) units. + Note that, in this context, only `unit_system.values()` is used. + + Returns: + An appropriate pure Python type, after scaling to the unit system and stripping all units away. + + If the returned type is array-like, and `use_jax_array` is specified, then (and **only** then) will a `jax.Array` be returned instead of a nested `tuple`. + """ + return sympy_to_python( + convert_to_unit_system(sp_obj, unit_system).subs( + {unit: 1 for unit in unit_system.values()} + ), + use_jax_array=use_jax_array, + ) diff --git a/src/blender_maxwell/utils/jarray.py b/src/blender_maxwell/utils/jarray.py deleted file mode 100644 index 859e2a3..0000000 --- a/src/blender_maxwell/utils/jarray.py +++ /dev/null @@ -1,61 +0,0 @@ -import dataclasses -import typing as typ -from types import MappingProxyType - -import jax -import jax.numpy as jnp - -# import jaxtyping as jtyp -import sympy.physics.units as spu -import xarray - -from . import logger - -log = logger.get(__name__) - -DimName: typ.TypeAlias = str -Number: typ.TypeAlias = int | float | complex -NumberRange: typ.TypeAlias = jax.Array - - -@dataclasses.dataclass(kw_only=True) -class JArray: - """Very simple wrapper for JAX arrays, which includes information about the dimension names and bounds.""" - - array: jax.Array - dims: dict[DimName, NumberRange] - dim_units: dict[DimName, spu.Quantity] - - #################### - # - Constructor - #################### - @classmethod - def from_xarray( - cls, - xarr: xarray.DataArray, - dim_units: dict[DimName, spu.Quantity] = MappingProxyType({}), - sort_axis: int = -1, - ) -> typ.Self: - return cls( - array=jnp.sort(jnp.array(xarr.data), axis=sort_axis), - dims={ - dim_name: jnp.array(xarr.get_index(dim_name).values) - for dim_name in xarr.dims - }, - dim_units={dim_name: dim_units.get(dim_name) for dim_name in xarr.dims}, - ) - - def idx(self, dim_name: DimName, dim_value: Number) -> int: - found_idx = jnp.searchsorted(self.dims[dim_name], dim_value) - if found_idx == 0: - return found_idx - if found_idx == len(self.dims[dim_name]): - return found_idx - 1 - - left = self.dims[dim_name][found_idx - 1] - right = self.dims[dim_name][found_idx - 1] - return found_idx - 1 if (dim_value - left) <= (right - dim_value) else found_idx - - @property - def dtype(self) -> jnp.dtype: - return self.array.dtype diff --git a/src/blender_maxwell/utils/pydantic_sympy.py b/src/blender_maxwell/utils/pydantic_sympy.py deleted file mode 100644 index 781c0b1..0000000 --- a/src/blender_maxwell/utils/pydantic_sympy.py +++ /dev/null @@ -1,164 +0,0 @@ -import typing as typ - -import pydantic as pyd -import sympy as sp -import sympy.physics.units as spu -import typing_extensions as typx -from pydantic_core import core_schema as pyd_core_schema - -from . import extra_sympy_units as spux - -#################### -# - Missing Basics -#################### -AllowedSympyExprs = sp.Expr | sp.MatrixBase | sp.MutableDenseMatrix -Complex = typx.Annotated[ - complex, - pyd.GetPydanticSchema( - lambda tp, handler: pyd_core_schema.no_info_after_validator_function( - lambda x: x, handler(tp) - ) - ), -] - - -#################### -# - Custom Pydantic Type for sp.Expr -#################### -class _SympyExpr: - @classmethod - def __get_pydantic_core_schema__( - cls, - _source_type: AllowedSympyExprs, - _handler: pyd.GetCoreSchemaHandler, - ) -> pyd_core_schema.CoreSchema: - def validate_from_str(value: str) -> AllowedSympyExprs: - if not isinstance(value, str): - return value - - try: - expr = sp.sympify(value) - except ValueError as ex: - msg = f'Value {value} is not a `sympify`able string' - raise ValueError(msg) from ex - - return expr.subs(spux.ALL_UNIT_SYMBOLS) - - def validate_from_expr(value: AllowedSympyExprs) -> AllowedSympyExprs: - if not (isinstance(value, sp.Expr | sp.MatrixBase)): - msg = f'Value {value} is not a `sympy` expression' - raise ValueError(msg) - - return value - - sympy_expr_schema = pyd_core_schema.chain_schema( - [ - pyd_core_schema.no_info_plain_validator_function(validate_from_str), - pyd_core_schema.no_info_plain_validator_function(validate_from_expr), - pyd_core_schema.is_instance_schema(AllowedSympyExprs), - ] - ) - return pyd_core_schema.json_or_python_schema( - json_schema=sympy_expr_schema, - python_schema=sympy_expr_schema, - serialization=pyd_core_schema.plain_serializer_function_ser_schema( - lambda instance: sp.srepr(instance) - ), - ) - - -#################### -# - Configurable Expression Validation -#################### -SympyExpr = typx.Annotated[ - AllowedSympyExprs, - _SympyExpr, -] - - -def ConstrSympyExpr( - # Feature Class - allow_variables: bool = True, - allow_units: bool = True, - # Structure Class - allowed_sets: set[typ.Literal['integer', 'rational', 'real', 'complex']] - | None = None, - allowed_structures: set[typ.Literal['scalar', 'matrix']] | None = None, - # Element Class - allowed_symbols: set[sp.Symbol] | None = None, - allowed_units: set[spu.Quantity] | None = None, - # Shape Class - allowed_matrix_shapes: set[tuple[int, int]] | None = None, -): - ## See `sympy` predicates: - ## - - def validate_expr(expr: AllowedSympyExprs): - if not (isinstance(expr, sp.Expr | sp.MatrixBase),): - ## NOTE: Must match AllowedSympyExprs union elements. - msg = f"expr '{expr}' is not an allowed Sympy expression ({AllowedSympyExprs})" - raise ValueError(msg) - - msgs = set() - - # Validate Feature Class - if (not allow_variables) and (len(expr.free_symbols) > 0): - msgs.add( - f'allow_variables={allow_variables} does not match expression {expr}.' - ) - if (not allow_units) and spux.uses_units(expr): - msgs.add(f'allow_units={allow_units} does not match expression {expr}.') - - # Validate Structure Class - if ( - allowed_sets - and isinstance(expr, sp.Expr) - and not any( - { - 'integer': expr.is_integer, - 'rational': expr.is_rational, - 'real': expr.is_real, - 'complex': expr.is_complex, - }[allowed_set] - for allowed_set in allowed_sets - ) - ): - msgs.add( - f"allowed_sets={allowed_sets} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))" - ) - if allowed_structures and not any( - { - 'matrix': isinstance(expr, sp.MatrixBase), - }[allowed_set] - for allowed_set in allowed_structures - if allowed_structures != 'scalar' - ): - msgs.add( - f"allowed_structures={allowed_structures} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))" - ) - - # Validate Element Class - if allowed_symbols and expr.free_symbols.issubset(allowed_symbols): - msgs.add( - f'allowed_symbols={allowed_symbols} does not match expression {expr}' - ) - if allowed_units and spux.get_units(expr).issubset(allowed_units): - msgs.add(f'allowed_units={allowed_units} does not match expression {expr}') - - # Validate Shape Class - if ( - allowed_matrix_shapes and isinstance(expr, sp.MatrixBase) - ) and expr.shape not in allowed_matrix_shapes: - msgs.add( - f'allowed_matrix_shapes={allowed_matrix_shapes} does not match expression {expr} with shape {expr.shape}' - ) - - # Error or Return - if msgs: - raise ValueError(str(msgs)) - return expr - - return typx.Annotated[ - AllowedSympyExprs, - _SympyExpr, - pyd.AfterValidator(validate_expr), - ] diff --git a/src/blender_maxwell/utils/serialize.py b/src/blender_maxwell/utils/serialize.py index 6e3dfc3..cf9dac1 100644 --- a/src/blender_maxwell/utils/serialize.py +++ b/src/blender_maxwell/utils/serialize.py @@ -141,7 +141,7 @@ def _dec_hook(_type: type, obj: NaivelyEncodableType) -> typ.Any: is_representation(obj) and obj[0] == TypeID.SympyType ): obj_value = obj[2] - return sp.sympify(obj_value).subs(spux.ALL_UNIT_SYMBOLS) + return sp.sympify(obj_value).subs(spux.UNIT_BY_SYMBOL) if hasattr(_type, 'parse_as_msgspec') and ( is_representation(obj) and obj[0] in [TypeID.SocketDef, TypeID.ManagedObj]