refactor: Revamped serialization (non-working)
parent
3def85e24f
commit
4f6bd8e990
6
TODO.md
6
TODO.md
|
@ -547,9 +547,9 @@ We need support for arbitrary objects, but still backed by the persistance seman
|
|||
|
||||
### Parallel Features
|
||||
- [x] Move serialization work to a `utils`.
|
||||
- [ ] Also make ENCODER a function that can shortcut the easy cases.
|
||||
- [ ] For serializeability, let the encoder/decoder be able to make use of an optional `.msgspec_encodable()` and similar decoder respectively, and add support for these in the ENCODER/DECODER functions.
|
||||
- [ ] Define a superclass for `SocketDef` and make everyone inherit from it
|
||||
- [x] Also make ENCODER a function that can shortcut the easy cases.
|
||||
- [x] For serializeability, let the encoder/decoder be able to make use of an optional `.msgspec_encodable()` and similar decoder respectively, and add support for these in the ENCODER/DECODER functions.
|
||||
- [x] Define a superclass for `SocketDef` and make everyone inherit from it
|
||||
- [ ] Collect with a `BL_SOCKET_DEFS` object, instead of manually from `__init__.py`s
|
||||
- [ ] Add support for `.msgspec_*()` methods, so that we remove the dependency on sockets from the serialization module.
|
||||
|
||||
|
|
|
@ -65,7 +65,7 @@ class KeyedCache:
|
|||
self,
|
||||
func: typ.Callable,
|
||||
exclude: set[str],
|
||||
serialize: set[str],
|
||||
encode: set[str],
|
||||
):
|
||||
# Function Information
|
||||
self.func: typ.Callable = func
|
||||
|
@ -74,7 +74,7 @@ class KeyedCache:
|
|||
# Arg -> Key Information
|
||||
self.exclude: set[str] = exclude
|
||||
self.include: set[str] = set(self.func_sig.parameters.keys()) - exclude
|
||||
self.serialize: set[str] = serialize
|
||||
self.encode: set[str] = encode
|
||||
|
||||
# Cache Information
|
||||
self.key_schema: tuple[str, ...] = tuple(
|
||||
|
@ -102,8 +102,8 @@ class KeyedCache:
|
|||
[
|
||||
(
|
||||
arg_value
|
||||
if arg_name not in self.serialize
|
||||
else ENCODER.encode(arg_value)
|
||||
if arg_name not in self.encode
|
||||
else serialize.encode(arg_value)
|
||||
)
|
||||
for arg_name, arg_value in arguments.items()
|
||||
if arg_name in self.include
|
||||
|
@ -153,8 +153,8 @@ class KeyedCache:
|
|||
|
||||
# Compute Keys to Invalidate
|
||||
arguments_hashable = {
|
||||
arg_name: ENCODER.encode(arg_value)
|
||||
if arg_name in self.serialize and arg_name not in wildcard_arguments
|
||||
arg_name: serialize.encode(arg_value)
|
||||
if arg_name in self.encode and arg_name not in wildcard_arguments
|
||||
else arg_value
|
||||
for arg_name, arg_value in arguments.items()
|
||||
}
|
||||
|
@ -168,12 +168,12 @@ class KeyedCache:
|
|||
cache.pop(key)
|
||||
|
||||
|
||||
def keyed_cache(exclude: set[str], serialize: set[str] = frozenset()) -> typ.Callable:
|
||||
def keyed_cache(exclude: set[str], encode: set[str] = frozenset()) -> typ.Callable:
|
||||
def decorator(func: typ.Callable) -> typ.Callable:
|
||||
return KeyedCache(
|
||||
func,
|
||||
exclude=exclude,
|
||||
serialize=serialize,
|
||||
encode=encode,
|
||||
)
|
||||
|
||||
return decorator
|
||||
|
@ -219,9 +219,6 @@ class CachedBLProperty:
|
|||
inspect.signature(getter_method).return_annotation if persist else None
|
||||
)
|
||||
|
||||
# Check Non-Empty Type Annotation
|
||||
## For now, just presume that all types can be encoded/decoded.
|
||||
|
||||
# Check Non-Empty Type Annotation
|
||||
## For now, just presume that all types can be encoded/decoded.
|
||||
if self._type is not None and self._type is inspect.Signature.empty:
|
||||
|
@ -283,7 +280,7 @@ class CachedBLProperty:
|
|||
self._persist
|
||||
and (encoded_value := getattr(bl_instance, self._bl_prop_name)) != ''
|
||||
):
|
||||
value = decode_any(self._type, encoded_value)
|
||||
value = serialize.decode(self._type, encoded_value)
|
||||
cache_nopersist[self._bl_prop_name] = value
|
||||
return value
|
||||
|
||||
|
@ -294,7 +291,7 @@ class CachedBLProperty:
|
|||
cache_nopersist[self._bl_prop_name] = value
|
||||
if self._persist:
|
||||
setattr(
|
||||
bl_instance, self._bl_prop_name, ENCODER.encode(value).decode('utf-8')
|
||||
bl_instance, self._bl_prop_name, serialize.encode(value).decode('utf-8')
|
||||
)
|
||||
return value
|
||||
|
||||
|
@ -466,7 +463,7 @@ class BLField:
|
|||
raise TypeError(msg)
|
||||
|
||||
# Define Blender Property (w/Update Sync)
|
||||
encoded_default_value = ENCODER.encode(self._default_value).decode('utf-8')
|
||||
encoded_default_value = serialize.encode(self._default_value).decode('utf-8')
|
||||
log.debug(
|
||||
'%s set to StringProperty w/default "%s" and no_update="%s"',
|
||||
bl_attr_name,
|
||||
|
@ -487,14 +484,14 @@ class BLField:
|
|||
## 2. Retrieve bpy.props.StringProperty string.
|
||||
## 3. Decode using annotated type.
|
||||
def getter(_self: BLInstance) -> AttrType:
|
||||
return decode_any(AttrType, getattr(_self, bl_attr_name))
|
||||
return serialize.decode(AttrType, getattr(_self, bl_attr_name))
|
||||
|
||||
## Setter:
|
||||
## 1. Initialize bpy.props.StringProperty to Default (if undefined).
|
||||
## 3. Encode value (implicitly using the annotated type).
|
||||
## 2. Set bpy.props.StringProperty string.
|
||||
def setter(_self: BLInstance, value: AttrType) -> None:
|
||||
encoded_value = ENCODER.encode(value).decode('utf-8')
|
||||
encoded_value = serialize.encode(value).decode('utf-8')
|
||||
log.debug(
|
||||
'Writing BLField attr "%s" w/encoded value: %s',
|
||||
bl_attr_name,
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
import typing as typ
|
||||
|
||||
from ..bl import ManagedObjName
|
||||
from ..managed_obj_type import ManagedObjType
|
||||
|
||||
|
||||
class ManagedObj(typ.Protocol):
|
||||
managed_obj_type: ManagedObjType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: ManagedObjName,
|
||||
): ...
|
||||
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
@name.setter
|
||||
def name(self, value: str): ...
|
||||
|
||||
def free(self): ...
|
||||
|
||||
def bl_select(self): ...
|
|
@ -1,10 +0,0 @@
|
|||
import typing as typ
|
||||
|
||||
import pydantic as pyd
|
||||
|
||||
from .managed_obj import ManagedObj
|
||||
|
||||
|
||||
class ManagedObjDef(pyd.BaseModel):
|
||||
mk: typ.Callable[[str], ManagedObj]
|
||||
name_prefix: str = ''
|
|
@ -1,12 +1,26 @@
|
|||
import abc
|
||||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import pydantic as pyd
|
||||
|
||||
from .....utils import serialize
|
||||
from ..socket_types import SocketType
|
||||
|
||||
|
||||
@typ.runtime_checkable
|
||||
class SocketDef(typ.Protocol):
|
||||
class SocketDef(pyd.BaseModel, abc.ABC):
|
||||
socket_type: SocketType
|
||||
|
||||
def init(self, bl_socket: bpy.types.NodeSocket) -> None: ...
|
||||
@abc.abstractmethod
|
||||
def init(self, bl_socket: bpy.types.NodeSocket) -> None:
|
||||
"""Initializes a real Blender node socket from this socket definition."""
|
||||
|
||||
####################
|
||||
# - Serialization
|
||||
####################
|
||||
def dump_as_msgspec(self) -> serialize.NaiveRepresentation:
|
||||
return [serialize.TypeID.SocketDef, self.__class__.__name__, self.model_dump()]
|
||||
|
||||
@staticmethod
|
||||
def parse_as_msgspec(obj: serialize.NaiveRepresentation) -> typ.Self:
|
||||
return SocketDef.__subclasses__[obj[1]](**obj[2])
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import typing as typ
|
||||
|
||||
from .base import ManagedObj
|
||||
|
||||
# from .managed_bl_empty import ManagedBLEmpty
|
||||
from .managed_bl_image import ManagedBLImage
|
||||
|
||||
|
@ -10,9 +12,8 @@ from .managed_bl_mesh import ManagedBLMesh
|
|||
# from .managed_bl_volume import ManagedBLVolume
|
||||
from .managed_bl_modifier import ManagedBLModifier
|
||||
|
||||
ManagedObj: typ.TypeAlias = ManagedBLImage | ManagedBLMesh | ManagedBLModifier
|
||||
|
||||
__all__ = [
|
||||
'ManagedObj',
|
||||
#'ManagedBLEmpty',
|
||||
'ManagedBLImage',
|
||||
#'ManagedBLCollection',
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
import abc
|
||||
import typing as typ
|
||||
|
||||
from ....utils import serialize
|
||||
from .. import contracts as ct
|
||||
|
||||
|
||||
class ManagedObj(abc.ABC):
|
||||
managed_obj_type: ct.ManagedObjType
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
name: ct.ManagedObjName,
|
||||
):
|
||||
"""Initializes the managed object with a unique name."""
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Retrieve the name of the managed object."""
|
||||
|
||||
@name.setter
|
||||
@abc.abstractmethod
|
||||
def name(self, value: str) -> None:
|
||||
"""Retrieve the name of the managed object."""
|
||||
|
||||
####################
|
||||
# - Methods
|
||||
####################
|
||||
@abc.abstractmethod
|
||||
def free(self) -> None:
|
||||
"""Cleanup the resources managed by the managed object."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def bl_select(self) -> None:
|
||||
"""Select the managed object in Blender, if such an operation makes sense."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def hide_preview(self) -> None:
|
||||
"""Select the managed object in Blender, if such an operation makes sense."""
|
||||
|
||||
####################
|
||||
# - Serialization
|
||||
####################
|
||||
def dump_as_msgspec(self) -> serialize.NaiveRepresentation:
|
||||
return [
|
||||
serialize.TypeID.ManagedObj,
|
||||
self.__class__.__name__,
|
||||
(self.managed_obj_type, self.name),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def parse_as_msgspec(obj: serialize.NaiveRepresentation) -> typ.Self:
|
||||
return ManagedObj.__subclasses__[obj[1]](**obj[2])
|
|
@ -12,6 +12,7 @@ import typing_extensions as typx
|
|||
|
||||
from ....utils import logger
|
||||
from .. import contracts as ct
|
||||
from . import base
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
@ -76,7 +77,7 @@ def rgba_image_from_xyzf(xyz_freq, colormap: str | None = None):
|
|||
return rgba_image_from_xyzf__grayscale(xyz_freq)
|
||||
|
||||
|
||||
class ManagedBLImage(ct.schemas.ManagedObj):
|
||||
class ManagedBLImage(base.ManagedObj):
|
||||
managed_obj_type = ct.ManagedObjType.ManagedBLImage
|
||||
_bl_image_name: str
|
||||
|
||||
|
@ -181,6 +182,9 @@ class ManagedBLImage(ct.schemas.ManagedObj):
|
|||
if bl_image := bpy.data.images.get(self.name):
|
||||
self.preview_space.image = bl_image
|
||||
|
||||
def hide_preview(self) -> None:
|
||||
self.preview_space.image = None
|
||||
|
||||
####################
|
||||
# - Image Geometry
|
||||
####################
|
||||
|
@ -269,12 +273,12 @@ class ManagedBLImage(ct.schemas.ManagedObj):
|
|||
)
|
||||
# log.debug('Computed MPL Geometry (%f)', time.perf_counter() - time_start)
|
||||
|
||||
#log.debug(
|
||||
# 'Creating MPL Axes (aspect=%f, width=%f, height=%f)',
|
||||
# aspect_ratio,
|
||||
# _width_inches,
|
||||
# _height_inches,
|
||||
#)
|
||||
# log.debug(
|
||||
# 'Creating MPL Axes (aspect=%f, width=%f, height=%f)',
|
||||
# aspect_ratio,
|
||||
# _width_inches,
|
||||
# _height_inches,
|
||||
# )
|
||||
# Create MPL Figure, Axes, and Compute Figure Geometry
|
||||
fig, ax = plt.subplots(
|
||||
figsize=[_width_inches, _height_inches],
|
||||
|
|
|
@ -7,6 +7,7 @@ import numpy as np
|
|||
from ....utils import logger
|
||||
from .. import contracts as ct
|
||||
from .managed_bl_collection import managed_collection, preview_collection
|
||||
from . import base
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
@ -14,7 +15,7 @@ log = logger.get(__name__)
|
|||
####################
|
||||
# - BLMesh
|
||||
####################
|
||||
class ManagedBLMesh(ct.schemas.ManagedObj):
|
||||
class ManagedBLMesh(base.ManagedObj):
|
||||
managed_obj_type = ct.ManagedObjType.ManagedBLMesh
|
||||
_bl_object_name: str | None = None
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import typing_extensions as typx
|
|||
from ....utils import analyze_geonodes, logger
|
||||
from .. import bl_socket_map
|
||||
from .. import contracts as ct
|
||||
from . import base
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
@ -160,7 +161,7 @@ def write_modifier(
|
|||
####################
|
||||
# - ManagedObj
|
||||
####################
|
||||
class ManagedBLModifier(ct.schemas.ManagedObj):
|
||||
class ManagedBLModifier(base.ManagedObj):
|
||||
managed_obj_type = ct.ManagedObjType.ManagedBLModifier
|
||||
_modifier_name: str | None = None
|
||||
|
||||
|
@ -185,6 +186,9 @@ class ManagedBLModifier(ct.schemas.ManagedObj):
|
|||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
def bl_select(self) -> None: pass
|
||||
def hide_preview(self) -> None: pass
|
||||
|
||||
####################
|
||||
# - Deallocation
|
||||
####################
|
||||
|
|
|
@ -55,8 +55,8 @@ class MaxwellSimNode(bpy.types.Node):
|
|||
presets: typ.ClassVar = MappingProxyType({})
|
||||
|
||||
# Managed Objects
|
||||
managed_obj_defs: typ.ClassVar[
|
||||
dict[ct.ManagedObjName, ct.schemas.ManagedObjDef]
|
||||
managed_obj_types: typ.ClassVar[
|
||||
dict[ct.ManagedObjName, type[_managed_objs.ManagedObj]]
|
||||
] = MappingProxyType({})
|
||||
|
||||
####################
|
||||
|
@ -222,7 +222,7 @@ class MaxwellSimNode(bpy.types.Node):
|
|||
####################
|
||||
@events.on_value_changed(
|
||||
prop_name='sim_node_name',
|
||||
props={'sim_node_name', 'managed_objs', 'managed_obj_defs'},
|
||||
props={'sim_node_name', 'managed_objs', 'managed_obj_types'},
|
||||
)
|
||||
def _on_sim_node_name_changed(self, props: dict):
|
||||
log.info(
|
||||
|
@ -233,9 +233,8 @@ class MaxwellSimNode(bpy.types.Node):
|
|||
)
|
||||
|
||||
# Set Name of Managed Objects
|
||||
for mobj_id, mobj in props['managed_objs'].items():
|
||||
mobj_def = props['managed_obj_defs'][mobj_id]
|
||||
mobj.name = mobj_def.name_prefix + props['sim_node_name']
|
||||
for mobj in props['managed_objs'].values():
|
||||
mobj.name = props['sim_node_name']
|
||||
|
||||
@events.on_value_changed(prop_name='active_socket_set')
|
||||
def _on_socket_set_changed(self):
|
||||
|
@ -282,9 +281,7 @@ class MaxwellSimNode(bpy.types.Node):
|
|||
def _on_preview_changed(self, props):
|
||||
if not props['preview_active']:
|
||||
for mobj in self.managed_objs.values():
|
||||
if isinstance(mobj, _managed_objs.ManagedBLMesh):
|
||||
## TODO: This is a Workaround
|
||||
mobj.hide_preview()
|
||||
mobj.hide_preview()
|
||||
|
||||
@events.on_enable_lock()
|
||||
def _on_enabled_lock(self):
|
||||
|
@ -460,33 +457,17 @@ class MaxwellSimNode(bpy.types.Node):
|
|||
####################
|
||||
# - Managed Objects
|
||||
####################
|
||||
managed_bl_meshes: dict[str, _managed_objs.ManagedBLMesh] = bl_cache.BLField({})
|
||||
managed_bl_images: dict[str, _managed_objs.ManagedBLImage] = bl_cache.BLField({})
|
||||
managed_bl_modifiers: dict[str, _managed_objs.ManagedBLModifier] = bl_cache.BLField(
|
||||
{}
|
||||
)
|
||||
|
||||
@bl_cache.cached_bl_property(
|
||||
persist=False
|
||||
) ## Disable broken ManagedObj union DECODER
|
||||
@bl_cache.cached_bl_property(persist=True)
|
||||
def managed_objs(self) -> dict[str, _managed_objs.ManagedObj]:
|
||||
"""Access the managed objects defined on this node.
|
||||
|
||||
Persistent cache ensures that the managed objects are only created on first access, even across file reloads.
|
||||
"""
|
||||
if self.managed_obj_defs:
|
||||
if not (
|
||||
managed_objs := (
|
||||
self.managed_bl_meshes
|
||||
| self.managed_bl_images
|
||||
| self.managed_bl_modifiers
|
||||
)
|
||||
):
|
||||
return {
|
||||
mobj_name: mobj_def.mk(mobj_def.name_prefix + self.sim_node_name)
|
||||
for mobj_name, mobj_def in self.managed_obj_defs.items()
|
||||
}
|
||||
return managed_objs
|
||||
if self.managed_obj_types:
|
||||
return {
|
||||
mobj_name: mobj_type(self.sim_node_name)
|
||||
for mobj_name, mobj_type in self.managed_obj_types.items()
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
|
@ -564,7 +545,7 @@ class MaxwellSimNode(bpy.types.Node):
|
|||
####################
|
||||
@bl_cache.keyed_cache(
|
||||
exclude={'self', 'optional'},
|
||||
serialize={'unit_system'},
|
||||
encode={'unit_system'},
|
||||
)
|
||||
def _compute_input(
|
||||
self,
|
||||
|
|
|
@ -1,17 +1,44 @@
|
|||
import abc
|
||||
import functools
|
||||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import pydantic as pyd
|
||||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
import typing_extensions as typx
|
||||
|
||||
from .....utils import serialize
|
||||
from ....utils import logger
|
||||
from .. import contracts as ct
|
||||
from ..socket_types import SocketType
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
||||
####################
|
||||
# - SocketDef
|
||||
####################
|
||||
class SocketDef(pyd.BaseModel, abc.ABC):
|
||||
socket_type: SocketType
|
||||
|
||||
@abc.abstractmethod
|
||||
def init(self, bl_socket: bpy.types.NodeSocket) -> None:
|
||||
"""Initializes a real Blender node socket from this socket definition."""
|
||||
|
||||
####################
|
||||
# - Serialization
|
||||
####################
|
||||
def dump_as_msgspec(self) -> serialize.NaiveRepresentation:
|
||||
return [serialize.TypeID.SocketDef, self.__class__.__name__, self.model_dump()]
|
||||
|
||||
@staticmethod
|
||||
def parse_as_msgspec(obj: serialize.NaiveRepresentation) -> typ.Self:
|
||||
return SocketDef.__subclasses__[obj[1]](**obj[2])
|
||||
|
||||
|
||||
####################
|
||||
# - SocketDef
|
||||
####################
|
||||
class MaxwellSimSocket(bpy.types.NodeSocket):
|
||||
# Fundamentals
|
||||
socket_type: ct.SocketType
|
||||
|
|
|
@ -5,6 +5,8 @@ import typing as typ
|
|||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
|
||||
SympyType = sp.Basic | sp.Expr | sp.MatrixBase | sp.Quantity
|
||||
|
||||
|
||||
####################
|
||||
# - Useful Methods
|
||||
|
|
|
@ -1,57 +1,131 @@
|
|||
"""
|
||||
|
||||
Attributes:
|
||||
NaiveEncodableType:
|
||||
See <https://jcristharif.com/msgspec/supported-types.html> for details.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import datetime as dt
|
||||
import decimal
|
||||
import enum
|
||||
import functools
|
||||
import typing as typ
|
||||
import uuid
|
||||
|
||||
import msgspec
|
||||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
|
||||
from . import extra_sympy_units as spux
|
||||
from . import logger
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
EncodableValue: typ.TypeAlias = typ.Any ## msgspec-compatible
|
||||
####################
|
||||
# - Serialization Types
|
||||
####################
|
||||
NaivelyEncodableType: typ.TypeAlias = (
|
||||
None
|
||||
| bool
|
||||
| int
|
||||
| float
|
||||
| str
|
||||
| bytes
|
||||
| bytearray
|
||||
## NO SUPPORT:
|
||||
# | memoryview
|
||||
| tuple
|
||||
| list
|
||||
| dict
|
||||
| set
|
||||
| frozenset
|
||||
## NO SUPPORT:
|
||||
# | typ.Literal
|
||||
| typ.Collection
|
||||
## NO SUPPORT:
|
||||
# | typ.Sequence ## -> list
|
||||
# | typ.MutableSequence ## -> list
|
||||
# | typ.AbstractSet ## -> set
|
||||
# | typ.MutableSet ## -> set
|
||||
# | typ.Mapping ## -> dict
|
||||
# | typ.MutableMapping ## -> dict
|
||||
| typ.TypedDict
|
||||
| typ.NamedTuple
|
||||
| dt.datetime
|
||||
| dt.date
|
||||
| dt.time
|
||||
| dt.timedelta
|
||||
| uuid.UUID
|
||||
| decimal.Decimal
|
||||
## NO SUPPORT:
|
||||
# | enum.Enum
|
||||
| enum.IntEnum
|
||||
| enum.Flag
|
||||
| enum.IntFlag
|
||||
| dataclasses.dataclass
|
||||
| typ.Optional
|
||||
| typ.Union
|
||||
| typ.NewType
|
||||
| typ.TypeAlias
|
||||
| typ.TypeAliasType
|
||||
| typ.Generic
|
||||
| typ.TypeVar
|
||||
| typ.Final
|
||||
| msgspec.Raw
|
||||
## NO SUPPORT:
|
||||
# | msgspec.UNSET
|
||||
)
|
||||
_NaivelyEncodableTypeSet = frozenset(typ.get_args(NaivelyEncodableType))
|
||||
|
||||
|
||||
class TypeID(enum.StrEnum):
|
||||
Complex: str = '!type=complex'
|
||||
SympyType: str = '!type=sympytype'
|
||||
SocketDef: str = '!type=socketdef'
|
||||
ManagedObj: str = '!type=managedobj'
|
||||
|
||||
|
||||
NaiveRepresentation: typ.TypeAlias = list[TypeID, str | None, typ.Any]
|
||||
|
||||
|
||||
def is_representation(obj: NaivelyEncodableType) -> bool:
|
||||
return isinstance(obj, list) and obj[0] in set(TypeID) and len(obj) == 3 # noqa: PLR2004
|
||||
|
||||
|
||||
####################
|
||||
# - (De)Serialization
|
||||
# - Serialization Hooks
|
||||
####################
|
||||
EncodedComplex: typ.TypeAlias = tuple[float, float] | list[float, float]
|
||||
EncodedSympy: typ.TypeAlias = str
|
||||
EncodedManagedObj: typ.TypeAlias = tuple[str, str] | list[str, str]
|
||||
EncodedPydanticModel: typ.TypeAlias = tuple[str, str] | list[str, str]
|
||||
|
||||
|
||||
def _enc_hook(obj: typ.Any) -> EncodableValue:
|
||||
def _enc_hook(obj: typ.Any) -> NaivelyEncodableType:
|
||||
"""Translates types not natively supported by `msgspec`, to an encodable form supported by `msgspec`.
|
||||
|
||||
Parameters:
|
||||
obj: The object of arbitrary type to transform into an encodable value.
|
||||
obj: The object of arbitrary type to transform into an encodable value.
|
||||
|
||||
Returns:
|
||||
A value encodable by `msgspec`.
|
||||
A value encodable by `msgspec`.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: When the type transformation hasn't been implemented.
|
||||
NotImplementedError: When the type transformation hasn't been implemented.
|
||||
"""
|
||||
if isinstance(obj, complex):
|
||||
return (obj.real, obj.imag)
|
||||
if isinstance(obj, sp.Basic | sp.MatrixBase | sp.Expr | spu.Quantity):
|
||||
return sp.srepr(obj)
|
||||
if isinstance(obj, managed_objs.ManagedObj):
|
||||
return (obj.name, obj.__class__.__name__)
|
||||
if isinstance(obj, ct.schemas.SocketDef):
|
||||
return (obj.model_dump(), obj.__class__.__name__)
|
||||
return ['!type=complex', None, (obj.real, obj.imag)]
|
||||
|
||||
if isinstance(obj, spux.SympyType):
|
||||
return ['!type=sympytype', None, sp.srepr(obj)]
|
||||
|
||||
if hasattr(obj, 'dump_as_msgspec'):
|
||||
return obj.dump_as_msgspec()
|
||||
|
||||
msg = f'Can\'t encode "{obj}" of type {type(obj)}'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def _dec_hook(_type: type, obj: EncodableValue) -> typ.Any:
|
||||
def _dec_hook(_type: type, obj: NaivelyEncodableType) -> typ.Any:
|
||||
"""Translates the `msgspec`-encoded form of an object back to its true form.
|
||||
|
||||
Parameters:
|
||||
_type: The type to transform the `msgspec`-encoded object back into.
|
||||
obj: The encoded object of to transform back into an encodable value.
|
||||
obj: The naively decoded object to transform back into its actual type.
|
||||
|
||||
Returns:
|
||||
A value encodable by `msgspec`.
|
||||
|
@ -59,72 +133,31 @@ def _dec_hook(_type: type, obj: EncodableValue) -> typ.Any:
|
|||
Raises:
|
||||
NotImplementedError: When the type transformation hasn't been implemented.
|
||||
"""
|
||||
if _type is complex and isinstance(obj, EncodedComplex):
|
||||
return complex(obj[0], obj[1])
|
||||
if (
|
||||
_type is sp.Basic
|
||||
and isinstance(obj, EncodedSympy)
|
||||
or _type is sp.Expr
|
||||
and isinstance(obj, EncodedSympy)
|
||||
or _type is sp.MatrixBase
|
||||
and isinstance(obj, EncodedSympy)
|
||||
or _type is spu.Quantity
|
||||
and isinstance(obj, EncodedSympy)
|
||||
if _type is complex or (is_representation(obj) and obj[0] == TypeID.Complex):
|
||||
obj_value = obj[2]
|
||||
return complex(obj_value[0], obj_value[1])
|
||||
|
||||
if _type in typ.get_args(spux.SympyType) or (
|
||||
is_representation(obj) and obj[0] == TypeID.SympyType
|
||||
):
|
||||
return sp.sympify(obj).subs(spux.ALL_UNIT_SYMBOLS)
|
||||
if (
|
||||
_type is managed_objs.ManagedBLMesh
|
||||
and isinstance(obj, EncodedManagedObj)
|
||||
or _type is managed_objs.ManagedBLImage
|
||||
and isinstance(obj, EncodedManagedObj)
|
||||
or _type is managed_objs.ManagedBLModifier
|
||||
and isinstance(obj, EncodedManagedObj)
|
||||
):
|
||||
return {
|
||||
'ManagedBLMesh': managed_objs.ManagedBLMesh,
|
||||
'ManagedBLImage': managed_objs.ManagedBLImage,
|
||||
'ManagedBLModifier': managed_objs.ManagedBLModifier,
|
||||
}[obj[1]](obj[0])
|
||||
if _type is ct.schemas.SocketDef:
|
||||
return getattr(sockets, obj[1])(**obj[0])
|
||||
obj_value = obj[2]
|
||||
return sp.sympify(obj_value).subs(spux.ALL_UNIT_SYMBOLS)
|
||||
|
||||
if hasattr(obj, 'parse_as_msgspec'):
|
||||
return _type.parse_as_msgspec(obj)
|
||||
|
||||
msg = f'Can\'t decode "{obj}" to type {type(obj)}'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
ENCODER = msgspec.json.Encoder(enc_hook=_enc_hook, order='deterministic')
|
||||
|
||||
_DECODERS: dict[type, msgspec.json.Decoder] = {
|
||||
complex: msgspec.json.Decoder(type=complex, dec_hook=_dec_hook),
|
||||
sp.Basic: msgspec.json.Decoder(type=sp.Basic, dec_hook=_dec_hook),
|
||||
sp.Expr: msgspec.json.Decoder(type=sp.Expr, dec_hook=_dec_hook),
|
||||
sp.MatrixBase: msgspec.json.Decoder(type=sp.MatrixBase, dec_hook=_dec_hook),
|
||||
spu.Quantity: msgspec.json.Decoder(type=spu.Quantity, dec_hook=_dec_hook),
|
||||
managed_objs.ManagedBLMesh: msgspec.json.Decoder(
|
||||
type=managed_objs.ManagedBLMesh,
|
||||
dec_hook=_dec_hook,
|
||||
),
|
||||
managed_objs.ManagedBLImage: msgspec.json.Decoder(
|
||||
type=managed_objs.ManagedBLImage,
|
||||
dec_hook=_dec_hook,
|
||||
),
|
||||
managed_objs.ManagedBLModifier: msgspec.json.Decoder(
|
||||
type=managed_objs.ManagedBLModifier,
|
||||
dec_hook=_dec_hook,
|
||||
),
|
||||
# managed_objs.ManagedObj: msgspec.json.Decoder(
|
||||
# type=managed_objs.ManagedObj, dec_hook=_dec_hook
|
||||
# ), ## Doesn't work b/c unions are not explicit
|
||||
ct.schemas.SocketDef: msgspec.json.Decoder(
|
||||
type=ct.schemas.SocketDef,
|
||||
dec_hook=_dec_hook,
|
||||
),
|
||||
}
|
||||
_DECODER_FALLBACK: msgspec.json.Decoder = msgspec.json.Decoder(dec_hook=_dec_hook)
|
||||
####################
|
||||
# - Global Encoders / Decoders
|
||||
####################
|
||||
_ENCODER = msgspec.json.Encoder(enc_hook=_enc_hook, order='deterministic')
|
||||
|
||||
|
||||
@functools.cache
|
||||
def DECODER(_type: type) -> msgspec.json.Decoder: # noqa: N802
|
||||
def _DECODER(_type: type) -> msgspec.json.Decoder: # noqa: N802
|
||||
"""Retrieve a suitable `msgspec.json.Decoder` by-type.
|
||||
|
||||
Parameters:
|
||||
|
@ -133,21 +166,15 @@ def DECODER(_type: type) -> msgspec.json.Decoder: # noqa: N802
|
|||
Returns:
|
||||
A suitable decoder.
|
||||
"""
|
||||
if (decoder := _DECODERS.get(_type)) is not None:
|
||||
return decoder
|
||||
|
||||
return _DECODER_FALLBACK
|
||||
return msgspec.json.Decoder(type=_type, dec_hook=_dec_hook)
|
||||
|
||||
|
||||
def decode_any(_type: type, obj: str) -> typ.Any:
|
||||
naive_decode = DECODER(_type).decode(obj)
|
||||
if _type == dict[str, ct.schemas.SocketDef]:
|
||||
return {
|
||||
socket_name: getattr(sockets, socket_def_list[1])(**socket_def_list[0])
|
||||
for socket_name, socket_def_list in naive_decode.items()
|
||||
}
|
||||
####################
|
||||
# - Encoder / Decoder Functions
|
||||
####################
|
||||
def encode(obj: typ.Any) -> bytes:
|
||||
return _ENCODER.encode(obj)
|
||||
|
||||
log.critical(
|
||||
'Naive Decode of "%s" to "%s" (%s)', str(obj), str(naive_decode), str(_type)
|
||||
)
|
||||
return naive_decode
|
||||
|
||||
def decode(_type: type, obj: str | bytes) -> typ.Any:
|
||||
return _DECODER(_type).decode(obj)
|
||||
|
|
Loading…
Reference in New Issue