From 4f6bd8e990a09da73967aff4764bccb67d53b594 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sofus=20Albert=20H=C3=B8gsbro=20Rose?= Date: Mon, 15 Apr 2024 17:43:06 +0200 Subject: [PATCH] refactor: Revamped serialization (non-working) --- TODO.md | 6 +- .../node_trees/maxwell_sim_nodes/bl_cache.py | 29 ++- .../contracts/schemas/managed_obj.py | 22 -- .../contracts/schemas/managed_obj_def.py | 10 - .../contracts/schemas/socket_def.py | 20 +- .../managed_objs/__init__.py | 5 +- .../maxwell_sim_nodes/managed_objs/base.py | 58 +++++ .../managed_objs/managed_bl_image.py | 18 +- .../managed_objs/managed_bl_mesh.py | 3 +- .../managed_objs/managed_bl_modifier.py | 6 +- .../maxwell_sim_nodes/nodes/base.py | 45 ++-- .../maxwell_sim_nodes/sockets/base.py | 29 ++- .../utils/extra_sympy_units.py | 2 + src/blender_maxwell/utils/serialize.py | 215 ++++++++++-------- 14 files changed, 276 insertions(+), 192 deletions(-) create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/base.py diff --git a/TODO.md b/TODO.md index 3508517..4282714 100644 --- a/TODO.md +++ b/TODO.md @@ -547,9 +547,9 @@ We need support for arbitrary objects, but still backed by the persistance seman ### Parallel Features - [x] Move serialization work to a `utils`. -- [ ] Also make ENCODER a function that can shortcut the easy cases. -- [ ] For serializeability, let the encoder/decoder be able to make use of an optional `.msgspec_encodable()` and similar decoder respectively, and add support for these in the ENCODER/DECODER functions. -- [ ] Define a superclass for `SocketDef` and make everyone inherit from it +- [x] Also make ENCODER a function that can shortcut the easy cases. +- [x] For serializeability, let the encoder/decoder be able to make use of an optional `.msgspec_encodable()` and similar decoder respectively, and add support for these in the ENCODER/DECODER functions. +- [x] Define a superclass for `SocketDef` and make everyone inherit from it - [ ] Collect with a `BL_SOCKET_DEFS` object, instead of manually from `__init__.py`s - [ ] Add support for `.msgspec_*()` methods, so that we remove the dependency on sockets from the serialization module. diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/bl_cache.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/bl_cache.py index ecdfa43..0aa16a5 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/bl_cache.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/bl_cache.py @@ -65,7 +65,7 @@ class KeyedCache: self, func: typ.Callable, exclude: set[str], - serialize: set[str], + encode: set[str], ): # Function Information self.func: typ.Callable = func @@ -74,7 +74,7 @@ class KeyedCache: # Arg -> Key Information self.exclude: set[str] = exclude self.include: set[str] = set(self.func_sig.parameters.keys()) - exclude - self.serialize: set[str] = serialize + self.encode: set[str] = encode # Cache Information self.key_schema: tuple[str, ...] = tuple( @@ -102,8 +102,8 @@ class KeyedCache: [ ( arg_value - if arg_name not in self.serialize - else ENCODER.encode(arg_value) + if arg_name not in self.encode + else serialize.encode(arg_value) ) for arg_name, arg_value in arguments.items() if arg_name in self.include @@ -153,8 +153,8 @@ class KeyedCache: # Compute Keys to Invalidate arguments_hashable = { - arg_name: ENCODER.encode(arg_value) - if arg_name in self.serialize and arg_name not in wildcard_arguments + arg_name: serialize.encode(arg_value) + if arg_name in self.encode and arg_name not in wildcard_arguments else arg_value for arg_name, arg_value in arguments.items() } @@ -168,12 +168,12 @@ class KeyedCache: cache.pop(key) -def keyed_cache(exclude: set[str], serialize: set[str] = frozenset()) -> typ.Callable: +def keyed_cache(exclude: set[str], encode: set[str] = frozenset()) -> typ.Callable: def decorator(func: typ.Callable) -> typ.Callable: return KeyedCache( func, exclude=exclude, - serialize=serialize, + encode=encode, ) return decorator @@ -219,9 +219,6 @@ class CachedBLProperty: inspect.signature(getter_method).return_annotation if persist else None ) - # Check Non-Empty Type Annotation - ## For now, just presume that all types can be encoded/decoded. - # Check Non-Empty Type Annotation ## For now, just presume that all types can be encoded/decoded. if self._type is not None and self._type is inspect.Signature.empty: @@ -283,7 +280,7 @@ class CachedBLProperty: self._persist and (encoded_value := getattr(bl_instance, self._bl_prop_name)) != '' ): - value = decode_any(self._type, encoded_value) + value = serialize.decode(self._type, encoded_value) cache_nopersist[self._bl_prop_name] = value return value @@ -294,7 +291,7 @@ class CachedBLProperty: cache_nopersist[self._bl_prop_name] = value if self._persist: setattr( - bl_instance, self._bl_prop_name, ENCODER.encode(value).decode('utf-8') + bl_instance, self._bl_prop_name, serialize.encode(value).decode('utf-8') ) return value @@ -466,7 +463,7 @@ class BLField: raise TypeError(msg) # Define Blender Property (w/Update Sync) - encoded_default_value = ENCODER.encode(self._default_value).decode('utf-8') + encoded_default_value = serialize.encode(self._default_value).decode('utf-8') log.debug( '%s set to StringProperty w/default "%s" and no_update="%s"', bl_attr_name, @@ -487,14 +484,14 @@ class BLField: ## 2. Retrieve bpy.props.StringProperty string. ## 3. Decode using annotated type. def getter(_self: BLInstance) -> AttrType: - return decode_any(AttrType, getattr(_self, bl_attr_name)) + return serialize.decode(AttrType, getattr(_self, bl_attr_name)) ## Setter: ## 1. Initialize bpy.props.StringProperty to Default (if undefined). ## 3. Encode value (implicitly using the annotated type). ## 2. Set bpy.props.StringProperty string. def setter(_self: BLInstance, value: AttrType) -> None: - encoded_value = ENCODER.encode(value).decode('utf-8') + encoded_value = serialize.encode(value).decode('utf-8') log.debug( 'Writing BLField attr "%s" w/encoded value: %s', bl_attr_name, diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/schemas/managed_obj.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/schemas/managed_obj.py index a589239..e69de29 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/schemas/managed_obj.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/schemas/managed_obj.py @@ -1,22 +0,0 @@ -import typing as typ - -from ..bl import ManagedObjName -from ..managed_obj_type import ManagedObjType - - -class ManagedObj(typ.Protocol): - managed_obj_type: ManagedObjType - - def __init__( - self, - name: ManagedObjName, - ): ... - - @property - def name(self) -> str: ... - @name.setter - def name(self, value: str): ... - - def free(self): ... - - def bl_select(self): ... diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/schemas/managed_obj_def.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/schemas/managed_obj_def.py index 3405a81..e69de29 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/schemas/managed_obj_def.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/schemas/managed_obj_def.py @@ -1,10 +0,0 @@ -import typing as typ - -import pydantic as pyd - -from .managed_obj import ManagedObj - - -class ManagedObjDef(pyd.BaseModel): - mk: typ.Callable[[str], ManagedObj] - name_prefix: str = '' diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/schemas/socket_def.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/schemas/socket_def.py index 33763db..e9818ee 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/schemas/socket_def.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/schemas/socket_def.py @@ -1,12 +1,26 @@ +import abc import typing as typ import bpy +import pydantic as pyd +from .....utils import serialize from ..socket_types import SocketType -@typ.runtime_checkable -class SocketDef(typ.Protocol): +class SocketDef(pyd.BaseModel, abc.ABC): socket_type: SocketType - def init(self, bl_socket: bpy.types.NodeSocket) -> None: ... + @abc.abstractmethod + def init(self, bl_socket: bpy.types.NodeSocket) -> None: + """Initializes a real Blender node socket from this socket definition.""" + + #################### + # - Serialization + #################### + def dump_as_msgspec(self) -> serialize.NaiveRepresentation: + return [serialize.TypeID.SocketDef, self.__class__.__name__, self.model_dump()] + + @staticmethod + def parse_as_msgspec(obj: serialize.NaiveRepresentation) -> typ.Self: + return SocketDef.__subclasses__[obj[1]](**obj[2]) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/__init__.py index 661143c..f3c0422 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/__init__.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/__init__.py @@ -1,5 +1,7 @@ import typing as typ +from .base import ManagedObj + # from .managed_bl_empty import ManagedBLEmpty from .managed_bl_image import ManagedBLImage @@ -10,9 +12,8 @@ from .managed_bl_mesh import ManagedBLMesh # from .managed_bl_volume import ManagedBLVolume from .managed_bl_modifier import ManagedBLModifier -ManagedObj: typ.TypeAlias = ManagedBLImage | ManagedBLMesh | ManagedBLModifier - __all__ = [ + 'ManagedObj', #'ManagedBLEmpty', 'ManagedBLImage', #'ManagedBLCollection', diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/base.py new file mode 100644 index 0000000..47c6222 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/base.py @@ -0,0 +1,58 @@ +import abc +import typing as typ + +from ....utils import serialize +from .. import contracts as ct + + +class ManagedObj(abc.ABC): + managed_obj_type: ct.ManagedObjType + + @abc.abstractmethod + def __init__( + self, + name: ct.ManagedObjName, + ): + """Initializes the managed object with a unique name.""" + + #################### + # - Properties + #################### + @property + @abc.abstractmethod + def name(self) -> str: + """Retrieve the name of the managed object.""" + + @name.setter + @abc.abstractmethod + def name(self, value: str) -> None: + """Retrieve the name of the managed object.""" + + #################### + # - Methods + #################### + @abc.abstractmethod + def free(self) -> None: + """Cleanup the resources managed by the managed object.""" + + @abc.abstractmethod + def bl_select(self) -> None: + """Select the managed object in Blender, if such an operation makes sense.""" + + @abc.abstractmethod + def hide_preview(self) -> None: + """Select the managed object in Blender, if such an operation makes sense.""" + + #################### + # - Serialization + #################### + def dump_as_msgspec(self) -> serialize.NaiveRepresentation: + return [ + serialize.TypeID.ManagedObj, + self.__class__.__name__, + (self.managed_obj_type, self.name), + ] + + @staticmethod + def parse_as_msgspec(obj: serialize.NaiveRepresentation) -> typ.Self: + return ManagedObj.__subclasses__[obj[1]](**obj[2]) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py index ad99c0d..8c4dc9c 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_image.py @@ -12,6 +12,7 @@ import typing_extensions as typx from ....utils import logger from .. import contracts as ct +from . import base log = logger.get(__name__) @@ -76,7 +77,7 @@ def rgba_image_from_xyzf(xyz_freq, colormap: str | None = None): return rgba_image_from_xyzf__grayscale(xyz_freq) -class ManagedBLImage(ct.schemas.ManagedObj): +class ManagedBLImage(base.ManagedObj): managed_obj_type = ct.ManagedObjType.ManagedBLImage _bl_image_name: str @@ -181,6 +182,9 @@ class ManagedBLImage(ct.schemas.ManagedObj): if bl_image := bpy.data.images.get(self.name): self.preview_space.image = bl_image + def hide_preview(self) -> None: + self.preview_space.image = None + #################### # - Image Geometry #################### @@ -269,12 +273,12 @@ class ManagedBLImage(ct.schemas.ManagedObj): ) # log.debug('Computed MPL Geometry (%f)', time.perf_counter() - time_start) - #log.debug( - # 'Creating MPL Axes (aspect=%f, width=%f, height=%f)', - # aspect_ratio, - # _width_inches, - # _height_inches, - #) + # log.debug( + # 'Creating MPL Axes (aspect=%f, width=%f, height=%f)', + # aspect_ratio, + # _width_inches, + # _height_inches, + # ) # Create MPL Figure, Axes, and Compute Figure Geometry fig, ax = plt.subplots( figsize=[_width_inches, _height_inches], diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_mesh.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_mesh.py index 5402d49..cff2487 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_mesh.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_mesh.py @@ -7,6 +7,7 @@ import numpy as np from ....utils import logger from .. import contracts as ct from .managed_bl_collection import managed_collection, preview_collection +from . import base log = logger.get(__name__) @@ -14,7 +15,7 @@ log = logger.get(__name__) #################### # - BLMesh #################### -class ManagedBLMesh(ct.schemas.ManagedObj): +class ManagedBLMesh(base.ManagedObj): managed_obj_type = ct.ManagedObjType.ManagedBLMesh _bl_object_name: str | None = None diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py index 2e84607..e5a4edd 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py @@ -8,6 +8,7 @@ import typing_extensions as typx from ....utils import analyze_geonodes, logger from .. import bl_socket_map from .. import contracts as ct +from . import base log = logger.get(__name__) @@ -160,7 +161,7 @@ def write_modifier( #################### # - ManagedObj #################### -class ManagedBLModifier(ct.schemas.ManagedObj): +class ManagedBLModifier(base.ManagedObj): managed_obj_type = ct.ManagedObjType.ManagedBLModifier _modifier_name: str | None = None @@ -185,6 +186,9 @@ class ManagedBLModifier(ct.schemas.ManagedObj): def __init__(self, name: str): self.name = name + def bl_select(self) -> None: pass + def hide_preview(self) -> None: pass + #################### # - Deallocation #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py index e86cd6a..cc0ddb0 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py @@ -55,8 +55,8 @@ class MaxwellSimNode(bpy.types.Node): presets: typ.ClassVar = MappingProxyType({}) # Managed Objects - managed_obj_defs: typ.ClassVar[ - dict[ct.ManagedObjName, ct.schemas.ManagedObjDef] + managed_obj_types: typ.ClassVar[ + dict[ct.ManagedObjName, type[_managed_objs.ManagedObj]] ] = MappingProxyType({}) #################### @@ -222,7 +222,7 @@ class MaxwellSimNode(bpy.types.Node): #################### @events.on_value_changed( prop_name='sim_node_name', - props={'sim_node_name', 'managed_objs', 'managed_obj_defs'}, + props={'sim_node_name', 'managed_objs', 'managed_obj_types'}, ) def _on_sim_node_name_changed(self, props: dict): log.info( @@ -233,9 +233,8 @@ class MaxwellSimNode(bpy.types.Node): ) # Set Name of Managed Objects - for mobj_id, mobj in props['managed_objs'].items(): - mobj_def = props['managed_obj_defs'][mobj_id] - mobj.name = mobj_def.name_prefix + props['sim_node_name'] + for mobj in props['managed_objs'].values(): + mobj.name = props['sim_node_name'] @events.on_value_changed(prop_name='active_socket_set') def _on_socket_set_changed(self): @@ -282,9 +281,7 @@ class MaxwellSimNode(bpy.types.Node): def _on_preview_changed(self, props): if not props['preview_active']: for mobj in self.managed_objs.values(): - if isinstance(mobj, _managed_objs.ManagedBLMesh): - ## TODO: This is a Workaround - mobj.hide_preview() + mobj.hide_preview() @events.on_enable_lock() def _on_enabled_lock(self): @@ -460,33 +457,17 @@ class MaxwellSimNode(bpy.types.Node): #################### # - Managed Objects #################### - managed_bl_meshes: dict[str, _managed_objs.ManagedBLMesh] = bl_cache.BLField({}) - managed_bl_images: dict[str, _managed_objs.ManagedBLImage] = bl_cache.BLField({}) - managed_bl_modifiers: dict[str, _managed_objs.ManagedBLModifier] = bl_cache.BLField( - {} - ) - - @bl_cache.cached_bl_property( - persist=False - ) ## Disable broken ManagedObj union DECODER + @bl_cache.cached_bl_property(persist=True) def managed_objs(self) -> dict[str, _managed_objs.ManagedObj]: """Access the managed objects defined on this node. Persistent cache ensures that the managed objects are only created on first access, even across file reloads. """ - if self.managed_obj_defs: - if not ( - managed_objs := ( - self.managed_bl_meshes - | self.managed_bl_images - | self.managed_bl_modifiers - ) - ): - return { - mobj_name: mobj_def.mk(mobj_def.name_prefix + self.sim_node_name) - for mobj_name, mobj_def in self.managed_obj_defs.items() - } - return managed_objs + if self.managed_obj_types: + return { + mobj_name: mobj_type(self.sim_node_name) + for mobj_name, mobj_type in self.managed_obj_types.items() + } return {} @@ -564,7 +545,7 @@ class MaxwellSimNode(bpy.types.Node): #################### @bl_cache.keyed_cache( exclude={'self', 'optional'}, - serialize={'unit_system'}, + encode={'unit_system'}, ) def _compute_input( self, diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py index 2b40173..64193dd 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py @@ -1,17 +1,44 @@ +import abc import functools import typing as typ import bpy +import pydantic as pyd import sympy as sp -import sympy.physics.units as spu import typing_extensions as typx +from .....utils import serialize from ....utils import logger from .. import contracts as ct +from ..socket_types import SocketType log = logger.get(__name__) +#################### +# - SocketDef +#################### +class SocketDef(pyd.BaseModel, abc.ABC): + socket_type: SocketType + + @abc.abstractmethod + def init(self, bl_socket: bpy.types.NodeSocket) -> None: + """Initializes a real Blender node socket from this socket definition.""" + + #################### + # - Serialization + #################### + def dump_as_msgspec(self) -> serialize.NaiveRepresentation: + return [serialize.TypeID.SocketDef, self.__class__.__name__, self.model_dump()] + + @staticmethod + def parse_as_msgspec(obj: serialize.NaiveRepresentation) -> typ.Self: + return SocketDef.__subclasses__[obj[1]](**obj[2]) + + +#################### +# - SocketDef +#################### class MaxwellSimSocket(bpy.types.NodeSocket): # Fundamentals socket_type: ct.SocketType diff --git a/src/blender_maxwell/utils/extra_sympy_units.py b/src/blender_maxwell/utils/extra_sympy_units.py index 0d580a2..91ce27e 100644 --- a/src/blender_maxwell/utils/extra_sympy_units.py +++ b/src/blender_maxwell/utils/extra_sympy_units.py @@ -5,6 +5,8 @@ import typing as typ import sympy as sp import sympy.physics.units as spu +SympyType = sp.Basic | sp.Expr | sp.MatrixBase | sp.Quantity + #################### # - Useful Methods diff --git a/src/blender_maxwell/utils/serialize.py b/src/blender_maxwell/utils/serialize.py index 6e09dc2..5a9c3ec 100644 --- a/src/blender_maxwell/utils/serialize.py +++ b/src/blender_maxwell/utils/serialize.py @@ -1,57 +1,131 @@ +""" + +Attributes: + NaiveEncodableType: + See for details. +""" + +import dataclasses +import datetime as dt +import decimal +import enum import functools import typing as typ +import uuid import msgspec import sympy as sp -import sympy.physics.units as spu from . import extra_sympy_units as spux from . import logger log = logger.get(__name__) -EncodableValue: typ.TypeAlias = typ.Any ## msgspec-compatible +#################### +# - Serialization Types +#################### +NaivelyEncodableType: typ.TypeAlias = ( + None + | bool + | int + | float + | str + | bytes + | bytearray + ## NO SUPPORT: + # | memoryview + | tuple + | list + | dict + | set + | frozenset + ## NO SUPPORT: + # | typ.Literal + | typ.Collection + ## NO SUPPORT: + # | typ.Sequence ## -> list + # | typ.MutableSequence ## -> list + # | typ.AbstractSet ## -> set + # | typ.MutableSet ## -> set + # | typ.Mapping ## -> dict + # | typ.MutableMapping ## -> dict + | typ.TypedDict + | typ.NamedTuple + | dt.datetime + | dt.date + | dt.time + | dt.timedelta + | uuid.UUID + | decimal.Decimal + ## NO SUPPORT: + # | enum.Enum + | enum.IntEnum + | enum.Flag + | enum.IntFlag + | dataclasses.dataclass + | typ.Optional + | typ.Union + | typ.NewType + | typ.TypeAlias + | typ.TypeAliasType + | typ.Generic + | typ.TypeVar + | typ.Final + | msgspec.Raw + ## NO SUPPORT: + # | msgspec.UNSET +) +_NaivelyEncodableTypeSet = frozenset(typ.get_args(NaivelyEncodableType)) + + +class TypeID(enum.StrEnum): + Complex: str = '!type=complex' + SympyType: str = '!type=sympytype' + SocketDef: str = '!type=socketdef' + ManagedObj: str = '!type=managedobj' + + +NaiveRepresentation: typ.TypeAlias = list[TypeID, str | None, typ.Any] + + +def is_representation(obj: NaivelyEncodableType) -> bool: + return isinstance(obj, list) and obj[0] in set(TypeID) and len(obj) == 3 # noqa: PLR2004 + #################### -# - (De)Serialization +# - Serialization Hooks #################### -EncodedComplex: typ.TypeAlias = tuple[float, float] | list[float, float] -EncodedSympy: typ.TypeAlias = str -EncodedManagedObj: typ.TypeAlias = tuple[str, str] | list[str, str] -EncodedPydanticModel: typ.TypeAlias = tuple[str, str] | list[str, str] - - -def _enc_hook(obj: typ.Any) -> EncodableValue: +def _enc_hook(obj: typ.Any) -> NaivelyEncodableType: """Translates types not natively supported by `msgspec`, to an encodable form supported by `msgspec`. Parameters: - obj: The object of arbitrary type to transform into an encodable value. + obj: The object of arbitrary type to transform into an encodable value. Returns: - A value encodable by `msgspec`. + A value encodable by `msgspec`. Raises: - NotImplementedError: When the type transformation hasn't been implemented. + NotImplementedError: When the type transformation hasn't been implemented. """ if isinstance(obj, complex): - return (obj.real, obj.imag) - if isinstance(obj, sp.Basic | sp.MatrixBase | sp.Expr | spu.Quantity): - return sp.srepr(obj) - if isinstance(obj, managed_objs.ManagedObj): - return (obj.name, obj.__class__.__name__) - if isinstance(obj, ct.schemas.SocketDef): - return (obj.model_dump(), obj.__class__.__name__) + return ['!type=complex', None, (obj.real, obj.imag)] + + if isinstance(obj, spux.SympyType): + return ['!type=sympytype', None, sp.srepr(obj)] + + if hasattr(obj, 'dump_as_msgspec'): + return obj.dump_as_msgspec() msg = f'Can\'t encode "{obj}" of type {type(obj)}' raise NotImplementedError(msg) -def _dec_hook(_type: type, obj: EncodableValue) -> typ.Any: +def _dec_hook(_type: type, obj: NaivelyEncodableType) -> typ.Any: """Translates the `msgspec`-encoded form of an object back to its true form. Parameters: _type: The type to transform the `msgspec`-encoded object back into. - obj: The encoded object of to transform back into an encodable value. + obj: The naively decoded object to transform back into its actual type. Returns: A value encodable by `msgspec`. @@ -59,72 +133,31 @@ def _dec_hook(_type: type, obj: EncodableValue) -> typ.Any: Raises: NotImplementedError: When the type transformation hasn't been implemented. """ - if _type is complex and isinstance(obj, EncodedComplex): - return complex(obj[0], obj[1]) - if ( - _type is sp.Basic - and isinstance(obj, EncodedSympy) - or _type is sp.Expr - and isinstance(obj, EncodedSympy) - or _type is sp.MatrixBase - and isinstance(obj, EncodedSympy) - or _type is spu.Quantity - and isinstance(obj, EncodedSympy) + if _type is complex or (is_representation(obj) and obj[0] == TypeID.Complex): + obj_value = obj[2] + return complex(obj_value[0], obj_value[1]) + + if _type in typ.get_args(spux.SympyType) or ( + is_representation(obj) and obj[0] == TypeID.SympyType ): - return sp.sympify(obj).subs(spux.ALL_UNIT_SYMBOLS) - if ( - _type is managed_objs.ManagedBLMesh - and isinstance(obj, EncodedManagedObj) - or _type is managed_objs.ManagedBLImage - and isinstance(obj, EncodedManagedObj) - or _type is managed_objs.ManagedBLModifier - and isinstance(obj, EncodedManagedObj) - ): - return { - 'ManagedBLMesh': managed_objs.ManagedBLMesh, - 'ManagedBLImage': managed_objs.ManagedBLImage, - 'ManagedBLModifier': managed_objs.ManagedBLModifier, - }[obj[1]](obj[0]) - if _type is ct.schemas.SocketDef: - return getattr(sockets, obj[1])(**obj[0]) + obj_value = obj[2] + return sp.sympify(obj_value).subs(spux.ALL_UNIT_SYMBOLS) + + if hasattr(obj, 'parse_as_msgspec'): + return _type.parse_as_msgspec(obj) msg = f'Can\'t decode "{obj}" to type {type(obj)}' raise NotImplementedError(msg) -ENCODER = msgspec.json.Encoder(enc_hook=_enc_hook, order='deterministic') - -_DECODERS: dict[type, msgspec.json.Decoder] = { - complex: msgspec.json.Decoder(type=complex, dec_hook=_dec_hook), - sp.Basic: msgspec.json.Decoder(type=sp.Basic, dec_hook=_dec_hook), - sp.Expr: msgspec.json.Decoder(type=sp.Expr, dec_hook=_dec_hook), - sp.MatrixBase: msgspec.json.Decoder(type=sp.MatrixBase, dec_hook=_dec_hook), - spu.Quantity: msgspec.json.Decoder(type=spu.Quantity, dec_hook=_dec_hook), - managed_objs.ManagedBLMesh: msgspec.json.Decoder( - type=managed_objs.ManagedBLMesh, - dec_hook=_dec_hook, - ), - managed_objs.ManagedBLImage: msgspec.json.Decoder( - type=managed_objs.ManagedBLImage, - dec_hook=_dec_hook, - ), - managed_objs.ManagedBLModifier: msgspec.json.Decoder( - type=managed_objs.ManagedBLModifier, - dec_hook=_dec_hook, - ), - # managed_objs.ManagedObj: msgspec.json.Decoder( - # type=managed_objs.ManagedObj, dec_hook=_dec_hook - # ), ## Doesn't work b/c unions are not explicit - ct.schemas.SocketDef: msgspec.json.Decoder( - type=ct.schemas.SocketDef, - dec_hook=_dec_hook, - ), -} -_DECODER_FALLBACK: msgspec.json.Decoder = msgspec.json.Decoder(dec_hook=_dec_hook) +#################### +# - Global Encoders / Decoders +#################### +_ENCODER = msgspec.json.Encoder(enc_hook=_enc_hook, order='deterministic') @functools.cache -def DECODER(_type: type) -> msgspec.json.Decoder: # noqa: N802 +def _DECODER(_type: type) -> msgspec.json.Decoder: # noqa: N802 """Retrieve a suitable `msgspec.json.Decoder` by-type. Parameters: @@ -133,21 +166,15 @@ def DECODER(_type: type) -> msgspec.json.Decoder: # noqa: N802 Returns: A suitable decoder. """ - if (decoder := _DECODERS.get(_type)) is not None: - return decoder - - return _DECODER_FALLBACK + return msgspec.json.Decoder(type=_type, dec_hook=_dec_hook) -def decode_any(_type: type, obj: str) -> typ.Any: - naive_decode = DECODER(_type).decode(obj) - if _type == dict[str, ct.schemas.SocketDef]: - return { - socket_name: getattr(sockets, socket_def_list[1])(**socket_def_list[0]) - for socket_name, socket_def_list in naive_decode.items() - } +#################### +# - Encoder / Decoder Functions +#################### +def encode(obj: typ.Any) -> bytes: + return _ENCODER.encode(obj) - log.critical( - 'Naive Decode of "%s" to "%s" (%s)', str(obj), str(naive_decode), str(_type) - ) - return naive_decode + +def decode(_type: type, obj: str | bytes) -> typ.Any: + return _DECODER(_type).decode(obj)