|
@@ -1,5 +1,4 @@
|
|
""" A unified interface for several common serialization methods """
|
|
""" A unified interface for several common serialization methods """
|
|
-import pickle
|
|
|
|
from io import BytesIO
|
|
from io import BytesIO
|
|
from typing import Dict, Any
|
|
from typing import Dict, Any
|
|
|
|
|
|
@@ -20,31 +19,10 @@ class SerializerBase:
|
|
raise NotImplementedError()
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
-class PickleSerializer(SerializerBase):
|
|
|
|
- @staticmethod
|
|
|
|
- def dumps(obj: object) -> bytes:
|
|
|
|
- return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
|
|
|
|
-
|
|
|
|
- @staticmethod
|
|
|
|
- def loads(buf: bytes) -> object:
|
|
|
|
- return pickle.loads(buf)
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-class PytorchSerializer(SerializerBase):
|
|
|
|
- @staticmethod
|
|
|
|
- def dumps(obj: object) -> bytes:
|
|
|
|
- s = BytesIO()
|
|
|
|
- torch.save(obj, s, pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
|
|
|
- return s.getvalue()
|
|
|
|
-
|
|
|
|
- @staticmethod
|
|
|
|
- def loads(buf: bytes) -> object:
|
|
|
|
- return torch.load(BytesIO(buf))
|
|
|
|
-
|
|
|
|
-
|
|
|
|
class MSGPackSerializer(SerializerBase):
|
|
class MSGPackSerializer(SerializerBase):
|
|
_ExtTypes: Dict[Any, int] = {}
|
|
_ExtTypes: Dict[Any, int] = {}
|
|
_ExtTypeCodes: Dict[int, Any] = {}
|
|
_ExtTypeCodes: Dict[int, Any] = {}
|
|
|
|
+ _MsgpackExtTypeCodeTuple = 0x40
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
def ext_serializable(cls, type_code: int):
|
|
def ext_serializable(cls, type_code: int):
|
|
@@ -64,12 +42,20 @@ class MSGPackSerializer(SerializerBase):
|
|
type_code = cls._ExtTypes.get(type(obj))
|
|
type_code = cls._ExtTypes.get(type(obj))
|
|
if type_code is not None:
|
|
if type_code is not None:
|
|
return msgpack.ExtType(type_code, obj.packb())
|
|
return msgpack.ExtType(type_code, obj.packb())
|
|
|
|
+ elif isinstance(obj, tuple):
|
|
|
|
+ # Tuples need to be handled separately to ensure that
|
|
|
|
+ # 1. tuple serialization works and 2. tuples serialized not as lists
|
|
|
|
+ data = msgpack.packb(list(obj), strict_types=True, use_bin_type=True, default=cls._encode_ext_types)
|
|
|
|
+ return msgpack.ExtType(cls._MsgpackExtTypeCodeTuple, data)
|
|
return obj
|
|
return obj
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
def _decode_ext_types(cls, type_code: int, data: bytes):
|
|
def _decode_ext_types(cls, type_code: int, data: bytes):
|
|
if type_code in cls._ExtTypeCodes:
|
|
if type_code in cls._ExtTypeCodes:
|
|
return cls._ExtTypeCodes[type_code].unpackb(data)
|
|
return cls._ExtTypeCodes[type_code].unpackb(data)
|
|
|
|
+ elif type_code == cls._MsgpackExtTypeCodeTuple:
|
|
|
|
+ return tuple(msgpack.unpackb(data, ext_hook=cls._decode_ext_types, raw=False))
|
|
|
|
+
|
|
logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is.")
|
|
logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is.")
|
|
return data
|
|
return data
|
|
|
|
|