refactor: Revamped serialization (non-working)

main
Sofus Albert Høgsbro Rose 2024-04-15 17:43:06 +02:00
parent 3def85e24f
commit 4f6bd8e990
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
14 changed files with 276 additions and 192 deletions

View File

@ -547,9 +547,9 @@ We need support for arbitrary objects, but still backed by the persistance seman
### Parallel Features ### Parallel Features
- [x] Move serialization work to a `utils`. - [x] Move serialization work to a `utils`.
- [ ] Also make ENCODER a function that can shortcut the easy cases. - [x] Also make ENCODER a function that can shortcut the easy cases.
- [ ] For serializeability, let the encoder/decoder be able to make use of an optional `.msgspec_encodable()` and similar decoder respectively, and add support for these in the ENCODER/DECODER functions. - [x] For serializeability, let the encoder/decoder be able to make use of an optional `.msgspec_encodable()` and similar decoder respectively, and add support for these in the ENCODER/DECODER functions.
- [ ] Define a superclass for `SocketDef` and make everyone inherit from it - [x] Define a superclass for `SocketDef` and make everyone inherit from it
- [ ] Collect with a `BL_SOCKET_DEFS` object, instead of manually from `__init__.py`s - [ ] Collect with a `BL_SOCKET_DEFS` object, instead of manually from `__init__.py`s
- [ ] Add support for `.msgspec_*()` methods, so that we remove the dependency on sockets from the serialization module. - [ ] Add support for `.msgspec_*()` methods, so that we remove the dependency on sockets from the serialization module.

View File

@ -65,7 +65,7 @@ class KeyedCache:
self, self,
func: typ.Callable, func: typ.Callable,
exclude: set[str], exclude: set[str],
serialize: set[str], encode: set[str],
): ):
# Function Information # Function Information
self.func: typ.Callable = func self.func: typ.Callable = func
@ -74,7 +74,7 @@ class KeyedCache:
# Arg -> Key Information # Arg -> Key Information
self.exclude: set[str] = exclude self.exclude: set[str] = exclude
self.include: set[str] = set(self.func_sig.parameters.keys()) - exclude self.include: set[str] = set(self.func_sig.parameters.keys()) - exclude
self.serialize: set[str] = serialize self.encode: set[str] = encode
# Cache Information # Cache Information
self.key_schema: tuple[str, ...] = tuple( self.key_schema: tuple[str, ...] = tuple(
@ -102,8 +102,8 @@ class KeyedCache:
[ [
( (
arg_value arg_value
if arg_name not in self.serialize if arg_name not in self.encode
else ENCODER.encode(arg_value) else serialize.encode(arg_value)
) )
for arg_name, arg_value in arguments.items() for arg_name, arg_value in arguments.items()
if arg_name in self.include if arg_name in self.include
@ -153,8 +153,8 @@ class KeyedCache:
# Compute Keys to Invalidate # Compute Keys to Invalidate
arguments_hashable = { arguments_hashable = {
arg_name: ENCODER.encode(arg_value) arg_name: serialize.encode(arg_value)
if arg_name in self.serialize and arg_name not in wildcard_arguments if arg_name in self.encode and arg_name not in wildcard_arguments
else arg_value else arg_value
for arg_name, arg_value in arguments.items() for arg_name, arg_value in arguments.items()
} }
@ -168,12 +168,12 @@ class KeyedCache:
cache.pop(key) cache.pop(key)
def keyed_cache(exclude: set[str], serialize: set[str] = frozenset()) -> typ.Callable: def keyed_cache(exclude: set[str], encode: set[str] = frozenset()) -> typ.Callable:
def decorator(func: typ.Callable) -> typ.Callable: def decorator(func: typ.Callable) -> typ.Callable:
return KeyedCache( return KeyedCache(
func, func,
exclude=exclude, exclude=exclude,
serialize=serialize, encode=encode,
) )
return decorator return decorator
@ -219,9 +219,6 @@ class CachedBLProperty:
inspect.signature(getter_method).return_annotation if persist else None inspect.signature(getter_method).return_annotation if persist else None
) )
# Check Non-Empty Type Annotation
## For now, just presume that all types can be encoded/decoded.
# Check Non-Empty Type Annotation # Check Non-Empty Type Annotation
## For now, just presume that all types can be encoded/decoded. ## For now, just presume that all types can be encoded/decoded.
if self._type is not None and self._type is inspect.Signature.empty: if self._type is not None and self._type is inspect.Signature.empty:
@ -283,7 +280,7 @@ class CachedBLProperty:
self._persist self._persist
and (encoded_value := getattr(bl_instance, self._bl_prop_name)) != '' and (encoded_value := getattr(bl_instance, self._bl_prop_name)) != ''
): ):
value = decode_any(self._type, encoded_value) value = serialize.decode(self._type, encoded_value)
cache_nopersist[self._bl_prop_name] = value cache_nopersist[self._bl_prop_name] = value
return value return value
@ -294,7 +291,7 @@ class CachedBLProperty:
cache_nopersist[self._bl_prop_name] = value cache_nopersist[self._bl_prop_name] = value
if self._persist: if self._persist:
setattr( setattr(
bl_instance, self._bl_prop_name, ENCODER.encode(value).decode('utf-8') bl_instance, self._bl_prop_name, serialize.encode(value).decode('utf-8')
) )
return value return value
@ -466,7 +463,7 @@ class BLField:
raise TypeError(msg) raise TypeError(msg)
# Define Blender Property (w/Update Sync) # Define Blender Property (w/Update Sync)
encoded_default_value = ENCODER.encode(self._default_value).decode('utf-8') encoded_default_value = serialize.encode(self._default_value).decode('utf-8')
log.debug( log.debug(
'%s set to StringProperty w/default "%s" and no_update="%s"', '%s set to StringProperty w/default "%s" and no_update="%s"',
bl_attr_name, bl_attr_name,
@ -487,14 +484,14 @@ class BLField:
## 2. Retrieve bpy.props.StringProperty string. ## 2. Retrieve bpy.props.StringProperty string.
## 3. Decode using annotated type. ## 3. Decode using annotated type.
def getter(_self: BLInstance) -> AttrType: def getter(_self: BLInstance) -> AttrType:
return decode_any(AttrType, getattr(_self, bl_attr_name)) return serialize.decode(AttrType, getattr(_self, bl_attr_name))
## Setter: ## Setter:
## 1. Initialize bpy.props.StringProperty to Default (if undefined). ## 1. Initialize bpy.props.StringProperty to Default (if undefined).
## 3. Encode value (implicitly using the annotated type). ## 3. Encode value (implicitly using the annotated type).
## 2. Set bpy.props.StringProperty string. ## 2. Set bpy.props.StringProperty string.
def setter(_self: BLInstance, value: AttrType) -> None: def setter(_self: BLInstance, value: AttrType) -> None:
encoded_value = ENCODER.encode(value).decode('utf-8') encoded_value = serialize.encode(value).decode('utf-8')
log.debug( log.debug(
'Writing BLField attr "%s" w/encoded value: %s', 'Writing BLField attr "%s" w/encoded value: %s',
bl_attr_name, bl_attr_name,

View File

@ -1,22 +0,0 @@
import typing as typ
from ..bl import ManagedObjName
from ..managed_obj_type import ManagedObjType
class ManagedObj(typ.Protocol):
managed_obj_type: ManagedObjType
def __init__(
self,
name: ManagedObjName,
): ...
@property
def name(self) -> str: ...
@name.setter
def name(self, value: str): ...
def free(self): ...
def bl_select(self): ...

View File

@ -1,10 +0,0 @@
import typing as typ
import pydantic as pyd
from .managed_obj import ManagedObj
class ManagedObjDef(pyd.BaseModel):
mk: typ.Callable[[str], ManagedObj]
name_prefix: str = ''

View File

@ -1,12 +1,26 @@
import abc
import typing as typ import typing as typ
import bpy import bpy
import pydantic as pyd
from .....utils import serialize
from ..socket_types import SocketType from ..socket_types import SocketType
@typ.runtime_checkable class SocketDef(pyd.BaseModel, abc.ABC):
class SocketDef(typ.Protocol):
socket_type: SocketType socket_type: SocketType
def init(self, bl_socket: bpy.types.NodeSocket) -> None: ... @abc.abstractmethod
def init(self, bl_socket: bpy.types.NodeSocket) -> None:
"""Initializes a real Blender node socket from this socket definition."""
####################
# - Serialization
####################
def dump_as_msgspec(self) -> serialize.NaiveRepresentation:
return [serialize.TypeID.SocketDef, self.__class__.__name__, self.model_dump()]
@staticmethod
def parse_as_msgspec(obj: serialize.NaiveRepresentation) -> typ.Self:
return SocketDef.__subclasses__[obj[1]](**obj[2])

View File

@ -1,5 +1,7 @@
import typing as typ import typing as typ
from .base import ManagedObj
# from .managed_bl_empty import ManagedBLEmpty # from .managed_bl_empty import ManagedBLEmpty
from .managed_bl_image import ManagedBLImage from .managed_bl_image import ManagedBLImage
@ -10,9 +12,8 @@ from .managed_bl_mesh import ManagedBLMesh
# from .managed_bl_volume import ManagedBLVolume # from .managed_bl_volume import ManagedBLVolume
from .managed_bl_modifier import ManagedBLModifier from .managed_bl_modifier import ManagedBLModifier
ManagedObj: typ.TypeAlias = ManagedBLImage | ManagedBLMesh | ManagedBLModifier
__all__ = [ __all__ = [
'ManagedObj',
#'ManagedBLEmpty', #'ManagedBLEmpty',
'ManagedBLImage', 'ManagedBLImage',
#'ManagedBLCollection', #'ManagedBLCollection',

View File

@ -0,0 +1,58 @@
import abc
import typing as typ
from ....utils import serialize
from .. import contracts as ct
class ManagedObj(abc.ABC):
managed_obj_type: ct.ManagedObjType
@abc.abstractmethod
def __init__(
self,
name: ct.ManagedObjName,
):
"""Initializes the managed object with a unique name."""
####################
# - Properties
####################
@property
@abc.abstractmethod
def name(self) -> str:
"""Retrieve the name of the managed object."""
@name.setter
@abc.abstractmethod
def name(self, value: str) -> None:
"""Retrieve the name of the managed object."""
####################
# - Methods
####################
@abc.abstractmethod
def free(self) -> None:
"""Cleanup the resources managed by the managed object."""
@abc.abstractmethod
def bl_select(self) -> None:
"""Select the managed object in Blender, if such an operation makes sense."""
@abc.abstractmethod
def hide_preview(self) -> None:
"""Select the managed object in Blender, if such an operation makes sense."""
####################
# - Serialization
####################
def dump_as_msgspec(self) -> serialize.NaiveRepresentation:
return [
serialize.TypeID.ManagedObj,
self.__class__.__name__,
(self.managed_obj_type, self.name),
]
@staticmethod
def parse_as_msgspec(obj: serialize.NaiveRepresentation) -> typ.Self:
return ManagedObj.__subclasses__[obj[1]](**obj[2])

View File

@ -12,6 +12,7 @@ import typing_extensions as typx
from ....utils import logger from ....utils import logger
from .. import contracts as ct from .. import contracts as ct
from . import base
log = logger.get(__name__) log = logger.get(__name__)
@ -76,7 +77,7 @@ def rgba_image_from_xyzf(xyz_freq, colormap: str | None = None):
return rgba_image_from_xyzf__grayscale(xyz_freq) return rgba_image_from_xyzf__grayscale(xyz_freq)
class ManagedBLImage(ct.schemas.ManagedObj): class ManagedBLImage(base.ManagedObj):
managed_obj_type = ct.ManagedObjType.ManagedBLImage managed_obj_type = ct.ManagedObjType.ManagedBLImage
_bl_image_name: str _bl_image_name: str
@ -181,6 +182,9 @@ class ManagedBLImage(ct.schemas.ManagedObj):
if bl_image := bpy.data.images.get(self.name): if bl_image := bpy.data.images.get(self.name):
self.preview_space.image = bl_image self.preview_space.image = bl_image
def hide_preview(self) -> None:
self.preview_space.image = None
#################### ####################
# - Image Geometry # - Image Geometry
#################### ####################
@ -269,12 +273,12 @@ class ManagedBLImage(ct.schemas.ManagedObj):
) )
# log.debug('Computed MPL Geometry (%f)', time.perf_counter() - time_start) # log.debug('Computed MPL Geometry (%f)', time.perf_counter() - time_start)
#log.debug( # log.debug(
# 'Creating MPL Axes (aspect=%f, width=%f, height=%f)', # 'Creating MPL Axes (aspect=%f, width=%f, height=%f)',
# aspect_ratio, # aspect_ratio,
# _width_inches, # _width_inches,
# _height_inches, # _height_inches,
#) # )
# Create MPL Figure, Axes, and Compute Figure Geometry # Create MPL Figure, Axes, and Compute Figure Geometry
fig, ax = plt.subplots( fig, ax = plt.subplots(
figsize=[_width_inches, _height_inches], figsize=[_width_inches, _height_inches],

View File

@ -7,6 +7,7 @@ import numpy as np
from ....utils import logger from ....utils import logger
from .. import contracts as ct from .. import contracts as ct
from .managed_bl_collection import managed_collection, preview_collection from .managed_bl_collection import managed_collection, preview_collection
from . import base
log = logger.get(__name__) log = logger.get(__name__)
@ -14,7 +15,7 @@ log = logger.get(__name__)
#################### ####################
# - BLMesh # - BLMesh
#################### ####################
class ManagedBLMesh(ct.schemas.ManagedObj): class ManagedBLMesh(base.ManagedObj):
managed_obj_type = ct.ManagedObjType.ManagedBLMesh managed_obj_type = ct.ManagedObjType.ManagedBLMesh
_bl_object_name: str | None = None _bl_object_name: str | None = None

View File

@ -8,6 +8,7 @@ import typing_extensions as typx
from ....utils import analyze_geonodes, logger from ....utils import analyze_geonodes, logger
from .. import bl_socket_map from .. import bl_socket_map
from .. import contracts as ct from .. import contracts as ct
from . import base
log = logger.get(__name__) log = logger.get(__name__)
@ -160,7 +161,7 @@ def write_modifier(
#################### ####################
# - ManagedObj # - ManagedObj
#################### ####################
class ManagedBLModifier(ct.schemas.ManagedObj): class ManagedBLModifier(base.ManagedObj):
managed_obj_type = ct.ManagedObjType.ManagedBLModifier managed_obj_type = ct.ManagedObjType.ManagedBLModifier
_modifier_name: str | None = None _modifier_name: str | None = None
@ -185,6 +186,9 @@ class ManagedBLModifier(ct.schemas.ManagedObj):
def __init__(self, name: str): def __init__(self, name: str):
self.name = name self.name = name
def bl_select(self) -> None: pass
def hide_preview(self) -> None: pass
#################### ####################
# - Deallocation # - Deallocation
#################### ####################

View File

@ -55,8 +55,8 @@ class MaxwellSimNode(bpy.types.Node):
presets: typ.ClassVar = MappingProxyType({}) presets: typ.ClassVar = MappingProxyType({})
# Managed Objects # Managed Objects
managed_obj_defs: typ.ClassVar[ managed_obj_types: typ.ClassVar[
dict[ct.ManagedObjName, ct.schemas.ManagedObjDef] dict[ct.ManagedObjName, type[_managed_objs.ManagedObj]]
] = MappingProxyType({}) ] = MappingProxyType({})
#################### ####################
@ -222,7 +222,7 @@ class MaxwellSimNode(bpy.types.Node):
#################### ####################
@events.on_value_changed( @events.on_value_changed(
prop_name='sim_node_name', prop_name='sim_node_name',
props={'sim_node_name', 'managed_objs', 'managed_obj_defs'}, props={'sim_node_name', 'managed_objs', 'managed_obj_types'},
) )
def _on_sim_node_name_changed(self, props: dict): def _on_sim_node_name_changed(self, props: dict):
log.info( log.info(
@ -233,9 +233,8 @@ class MaxwellSimNode(bpy.types.Node):
) )
# Set Name of Managed Objects # Set Name of Managed Objects
for mobj_id, mobj in props['managed_objs'].items(): for mobj in props['managed_objs'].values():
mobj_def = props['managed_obj_defs'][mobj_id] mobj.name = props['sim_node_name']
mobj.name = mobj_def.name_prefix + props['sim_node_name']
@events.on_value_changed(prop_name='active_socket_set') @events.on_value_changed(prop_name='active_socket_set')
def _on_socket_set_changed(self): def _on_socket_set_changed(self):
@ -282,9 +281,7 @@ class MaxwellSimNode(bpy.types.Node):
def _on_preview_changed(self, props): def _on_preview_changed(self, props):
if not props['preview_active']: if not props['preview_active']:
for mobj in self.managed_objs.values(): for mobj in self.managed_objs.values():
if isinstance(mobj, _managed_objs.ManagedBLMesh): mobj.hide_preview()
## TODO: This is a Workaround
mobj.hide_preview()
@events.on_enable_lock() @events.on_enable_lock()
def _on_enabled_lock(self): def _on_enabled_lock(self):
@ -460,33 +457,17 @@ class MaxwellSimNode(bpy.types.Node):
#################### ####################
# - Managed Objects # - Managed Objects
#################### ####################
managed_bl_meshes: dict[str, _managed_objs.ManagedBLMesh] = bl_cache.BLField({}) @bl_cache.cached_bl_property(persist=True)
managed_bl_images: dict[str, _managed_objs.ManagedBLImage] = bl_cache.BLField({})
managed_bl_modifiers: dict[str, _managed_objs.ManagedBLModifier] = bl_cache.BLField(
{}
)
@bl_cache.cached_bl_property(
persist=False
) ## Disable broken ManagedObj union DECODER
def managed_objs(self) -> dict[str, _managed_objs.ManagedObj]: def managed_objs(self) -> dict[str, _managed_objs.ManagedObj]:
"""Access the managed objects defined on this node. """Access the managed objects defined on this node.
Persistent cache ensures that the managed objects are only created on first access, even across file reloads. Persistent cache ensures that the managed objects are only created on first access, even across file reloads.
""" """
if self.managed_obj_defs: if self.managed_obj_types:
if not ( return {
managed_objs := ( mobj_name: mobj_type(self.sim_node_name)
self.managed_bl_meshes for mobj_name, mobj_type in self.managed_obj_types.items()
| self.managed_bl_images }
| self.managed_bl_modifiers
)
):
return {
mobj_name: mobj_def.mk(mobj_def.name_prefix + self.sim_node_name)
for mobj_name, mobj_def in self.managed_obj_defs.items()
}
return managed_objs
return {} return {}
@ -564,7 +545,7 @@ class MaxwellSimNode(bpy.types.Node):
#################### ####################
@bl_cache.keyed_cache( @bl_cache.keyed_cache(
exclude={'self', 'optional'}, exclude={'self', 'optional'},
serialize={'unit_system'}, encode={'unit_system'},
) )
def _compute_input( def _compute_input(
self, self,

View File

@ -1,17 +1,44 @@
import abc
import functools import functools
import typing as typ import typing as typ
import bpy import bpy
import pydantic as pyd
import sympy as sp import sympy as sp
import sympy.physics.units as spu
import typing_extensions as typx import typing_extensions as typx
from .....utils import serialize
from ....utils import logger from ....utils import logger
from .. import contracts as ct from .. import contracts as ct
from ..socket_types import SocketType
log = logger.get(__name__) log = logger.get(__name__)
####################
# - SocketDef
####################
class SocketDef(pyd.BaseModel, abc.ABC):
socket_type: SocketType
@abc.abstractmethod
def init(self, bl_socket: bpy.types.NodeSocket) -> None:
"""Initializes a real Blender node socket from this socket definition."""
####################
# - Serialization
####################
def dump_as_msgspec(self) -> serialize.NaiveRepresentation:
return [serialize.TypeID.SocketDef, self.__class__.__name__, self.model_dump()]
@staticmethod
def parse_as_msgspec(obj: serialize.NaiveRepresentation) -> typ.Self:
return SocketDef.__subclasses__[obj[1]](**obj[2])
####################
# - SocketDef
####################
class MaxwellSimSocket(bpy.types.NodeSocket): class MaxwellSimSocket(bpy.types.NodeSocket):
# Fundamentals # Fundamentals
socket_type: ct.SocketType socket_type: ct.SocketType

View File

@ -5,6 +5,8 @@ import typing as typ
import sympy as sp import sympy as sp
import sympy.physics.units as spu import sympy.physics.units as spu
SympyType = sp.Basic | sp.Expr | sp.MatrixBase | sp.Quantity
#################### ####################
# - Useful Methods # - Useful Methods

View File

@ -1,57 +1,131 @@
"""
Attributes:
NaiveEncodableType:
See <https://jcristharif.com/msgspec/supported-types.html> for details.
"""
import dataclasses
import datetime as dt
import decimal
import enum
import functools import functools
import typing as typ import typing as typ
import uuid
import msgspec import msgspec
import sympy as sp import sympy as sp
import sympy.physics.units as spu
from . import extra_sympy_units as spux from . import extra_sympy_units as spux
from . import logger from . import logger
log = logger.get(__name__) log = logger.get(__name__)
EncodableValue: typ.TypeAlias = typ.Any ## msgspec-compatible ####################
# - Serialization Types
####################
NaivelyEncodableType: typ.TypeAlias = (
None
| bool
| int
| float
| str
| bytes
| bytearray
## NO SUPPORT:
# | memoryview
| tuple
| list
| dict
| set
| frozenset
## NO SUPPORT:
# | typ.Literal
| typ.Collection
## NO SUPPORT:
# | typ.Sequence ## -> list
# | typ.MutableSequence ## -> list
# | typ.AbstractSet ## -> set
# | typ.MutableSet ## -> set
# | typ.Mapping ## -> dict
# | typ.MutableMapping ## -> dict
| typ.TypedDict
| typ.NamedTuple
| dt.datetime
| dt.date
| dt.time
| dt.timedelta
| uuid.UUID
| decimal.Decimal
## NO SUPPORT:
# | enum.Enum
| enum.IntEnum
| enum.Flag
| enum.IntFlag
| dataclasses.dataclass
| typ.Optional
| typ.Union
| typ.NewType
| typ.TypeAlias
| typ.TypeAliasType
| typ.Generic
| typ.TypeVar
| typ.Final
| msgspec.Raw
## NO SUPPORT:
# | msgspec.UNSET
)
_NaivelyEncodableTypeSet = frozenset(typ.get_args(NaivelyEncodableType))
class TypeID(enum.StrEnum):
Complex: str = '!type=complex'
SympyType: str = '!type=sympytype'
SocketDef: str = '!type=socketdef'
ManagedObj: str = '!type=managedobj'
NaiveRepresentation: typ.TypeAlias = list[TypeID, str | None, typ.Any]
def is_representation(obj: NaivelyEncodableType) -> bool:
return isinstance(obj, list) and obj[0] in set(TypeID) and len(obj) == 3 # noqa: PLR2004
#################### ####################
# - (De)Serialization # - Serialization Hooks
#################### ####################
EncodedComplex: typ.TypeAlias = tuple[float, float] | list[float, float] def _enc_hook(obj: typ.Any) -> NaivelyEncodableType:
EncodedSympy: typ.TypeAlias = str
EncodedManagedObj: typ.TypeAlias = tuple[str, str] | list[str, str]
EncodedPydanticModel: typ.TypeAlias = tuple[str, str] | list[str, str]
def _enc_hook(obj: typ.Any) -> EncodableValue:
"""Translates types not natively supported by `msgspec`, to an encodable form supported by `msgspec`. """Translates types not natively supported by `msgspec`, to an encodable form supported by `msgspec`.
Parameters: Parameters:
obj: The object of arbitrary type to transform into an encodable value. obj: The object of arbitrary type to transform into an encodable value.
Returns: Returns:
A value encodable by `msgspec`. A value encodable by `msgspec`.
Raises: Raises:
NotImplementedError: When the type transformation hasn't been implemented. NotImplementedError: When the type transformation hasn't been implemented.
""" """
if isinstance(obj, complex): if isinstance(obj, complex):
return (obj.real, obj.imag) return ['!type=complex', None, (obj.real, obj.imag)]
if isinstance(obj, sp.Basic | sp.MatrixBase | sp.Expr | spu.Quantity):
return sp.srepr(obj) if isinstance(obj, spux.SympyType):
if isinstance(obj, managed_objs.ManagedObj): return ['!type=sympytype', None, sp.srepr(obj)]
return (obj.name, obj.__class__.__name__)
if isinstance(obj, ct.schemas.SocketDef): if hasattr(obj, 'dump_as_msgspec'):
return (obj.model_dump(), obj.__class__.__name__) return obj.dump_as_msgspec()
msg = f'Can\'t encode "{obj}" of type {type(obj)}' msg = f'Can\'t encode "{obj}" of type {type(obj)}'
raise NotImplementedError(msg) raise NotImplementedError(msg)
def _dec_hook(_type: type, obj: EncodableValue) -> typ.Any: def _dec_hook(_type: type, obj: NaivelyEncodableType) -> typ.Any:
"""Translates the `msgspec`-encoded form of an object back to its true form. """Translates the `msgspec`-encoded form of an object back to its true form.
Parameters: Parameters:
_type: The type to transform the `msgspec`-encoded object back into. _type: The type to transform the `msgspec`-encoded object back into.
obj: The encoded object of to transform back into an encodable value. obj: The naively decoded object to transform back into its actual type.
Returns: Returns:
A value encodable by `msgspec`. A value encodable by `msgspec`.
@ -59,72 +133,31 @@ def _dec_hook(_type: type, obj: EncodableValue) -> typ.Any:
Raises: Raises:
NotImplementedError: When the type transformation hasn't been implemented. NotImplementedError: When the type transformation hasn't been implemented.
""" """
if _type is complex and isinstance(obj, EncodedComplex): if _type is complex or (is_representation(obj) and obj[0] == TypeID.Complex):
return complex(obj[0], obj[1]) obj_value = obj[2]
if ( return complex(obj_value[0], obj_value[1])
_type is sp.Basic
and isinstance(obj, EncodedSympy) if _type in typ.get_args(spux.SympyType) or (
or _type is sp.Expr is_representation(obj) and obj[0] == TypeID.SympyType
and isinstance(obj, EncodedSympy)
or _type is sp.MatrixBase
and isinstance(obj, EncodedSympy)
or _type is spu.Quantity
and isinstance(obj, EncodedSympy)
): ):
return sp.sympify(obj).subs(spux.ALL_UNIT_SYMBOLS) obj_value = obj[2]
if ( return sp.sympify(obj_value).subs(spux.ALL_UNIT_SYMBOLS)
_type is managed_objs.ManagedBLMesh
and isinstance(obj, EncodedManagedObj) if hasattr(obj, 'parse_as_msgspec'):
or _type is managed_objs.ManagedBLImage return _type.parse_as_msgspec(obj)
and isinstance(obj, EncodedManagedObj)
or _type is managed_objs.ManagedBLModifier
and isinstance(obj, EncodedManagedObj)
):
return {
'ManagedBLMesh': managed_objs.ManagedBLMesh,
'ManagedBLImage': managed_objs.ManagedBLImage,
'ManagedBLModifier': managed_objs.ManagedBLModifier,
}[obj[1]](obj[0])
if _type is ct.schemas.SocketDef:
return getattr(sockets, obj[1])(**obj[0])
msg = f'Can\'t decode "{obj}" to type {type(obj)}' msg = f'Can\'t decode "{obj}" to type {type(obj)}'
raise NotImplementedError(msg) raise NotImplementedError(msg)
ENCODER = msgspec.json.Encoder(enc_hook=_enc_hook, order='deterministic') ####################
# - Global Encoders / Decoders
_DECODERS: dict[type, msgspec.json.Decoder] = { ####################
complex: msgspec.json.Decoder(type=complex, dec_hook=_dec_hook), _ENCODER = msgspec.json.Encoder(enc_hook=_enc_hook, order='deterministic')
sp.Basic: msgspec.json.Decoder(type=sp.Basic, dec_hook=_dec_hook),
sp.Expr: msgspec.json.Decoder(type=sp.Expr, dec_hook=_dec_hook),
sp.MatrixBase: msgspec.json.Decoder(type=sp.MatrixBase, dec_hook=_dec_hook),
spu.Quantity: msgspec.json.Decoder(type=spu.Quantity, dec_hook=_dec_hook),
managed_objs.ManagedBLMesh: msgspec.json.Decoder(
type=managed_objs.ManagedBLMesh,
dec_hook=_dec_hook,
),
managed_objs.ManagedBLImage: msgspec.json.Decoder(
type=managed_objs.ManagedBLImage,
dec_hook=_dec_hook,
),
managed_objs.ManagedBLModifier: msgspec.json.Decoder(
type=managed_objs.ManagedBLModifier,
dec_hook=_dec_hook,
),
# managed_objs.ManagedObj: msgspec.json.Decoder(
# type=managed_objs.ManagedObj, dec_hook=_dec_hook
# ), ## Doesn't work b/c unions are not explicit
ct.schemas.SocketDef: msgspec.json.Decoder(
type=ct.schemas.SocketDef,
dec_hook=_dec_hook,
),
}
_DECODER_FALLBACK: msgspec.json.Decoder = msgspec.json.Decoder(dec_hook=_dec_hook)
@functools.cache @functools.cache
def DECODER(_type: type) -> msgspec.json.Decoder: # noqa: N802 def _DECODER(_type: type) -> msgspec.json.Decoder: # noqa: N802
"""Retrieve a suitable `msgspec.json.Decoder` by-type. """Retrieve a suitable `msgspec.json.Decoder` by-type.
Parameters: Parameters:
@ -133,21 +166,15 @@ def DECODER(_type: type) -> msgspec.json.Decoder: # noqa: N802
Returns: Returns:
A suitable decoder. A suitable decoder.
""" """
if (decoder := _DECODERS.get(_type)) is not None: return msgspec.json.Decoder(type=_type, dec_hook=_dec_hook)
return decoder
return _DECODER_FALLBACK
def decode_any(_type: type, obj: str) -> typ.Any: ####################
naive_decode = DECODER(_type).decode(obj) # - Encoder / Decoder Functions
if _type == dict[str, ct.schemas.SocketDef]: ####################
return { def encode(obj: typ.Any) -> bytes:
socket_name: getattr(sockets, socket_def_list[1])(**socket_def_list[0]) return _ENCODER.encode(obj)
for socket_name, socket_def_list in naive_decode.items()
}
log.critical(
'Naive Decode of "%s" to "%s" (%s)', str(obj), str(naive_decode), str(_type) def decode(_type: type, obj: str | bytes) -> typ.Any:
) return _DECODER(_type).decode(obj)
return naive_decode