123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- """ A unified interface for several common serialization methods """
- from abc import ABC, abstractmethod
- from typing import Any, Dict
- import msgpack
- from hivemind.utils.logging import get_logger
# Module-level logger; MSGPackSerializer uses it to warn about overwritten ext type registrations
# and about unknown ext type codes encountered while decoding.
logger = get_logger(__name__)
class SerializerBase(ABC):
    """Abstract interface for serializers: concrete subclasses must provide ``dumps`` and ``loads``."""

    @staticmethod
    @abstractmethod
    def dumps(obj: object) -> bytes:
        """Encode ``obj`` into bytes; must be overridden by a concrete serializer."""

    @staticmethod
    @abstractmethod
    def loads(buf: bytes) -> object:
        """Decode an object from ``buf``; must be overridden by a concrete serializer."""
class MSGPackSerializer(SerializerBase):
    """
    Serializer based on msgpack, with support for tuples and user-registered extension types.

    Types registered via :meth:`ext_serializable` are encoded as ``msgpack.ExtType(type_code, obj.packb())``
    and restored with ``registered_type.unpackb(data)``. Tuples are encoded as a dedicated ext type so that
    they round-trip as tuples instead of being turned into lists.
    """

    # registered python type -> its msgpack ext type code (consulted when encoding)
    _ext_types: Dict[Any, int] = {}
    # msgpack ext type code -> registered python type (consulted when decoding)
    _ext_type_codes: Dict[int, Any] = {}
    # ext type code reserved for tuples; ext_serializable refuses to hand it out
    _TUPLE_EXT_TYPE_CODE = 0x40

    @classmethod
    def ext_serializable(cls, type_code: int):
        """
        Class decorator that registers a type for serialization as a msgpack ext type.

        The decorated type must define ``packb(self) -> bytes`` and a classmethod ``unpackb(cls, data: bytes)``.
        Registering a code that is already taken logs a warning and replaces the previous type for decoding.

        :param type_code: ext type code for the decorated type; an int in [0, 127] (the range accepted by
          ``msgpack.ExtType``) that differs from the code reserved for tuples
        :returns: a decorator that registers the type and returns it unchanged
        :raises ValueError: if ``type_code`` is outside [0, 127] or is the code reserved for tuples
        """
        assert isinstance(type_code, int), "Please specify a (unique) int type code"
        if not 0 <= type_code <= 127:
            # msgpack.ExtType rejects such codes, so fail at registration time instead of on the first dumps()
            raise ValueError(f"{cls.__name__}: ext type code must be in range [0, 127], got {type_code}")
        if type_code == cls._TUPLE_EXT_TYPE_CODE:
            # _decode_ext_types checks registered codes before the tuple code, so claiming this code
            # would silently break deserialization of every tuple
            raise ValueError(f"{cls.__name__}: ext type code {type_code} is reserved for tuples")

        def wrap(wrapped_type: type):
            assert callable(getattr(wrapped_type, "packb", None)) and callable(
                getattr(wrapped_type, "unpackb", None)
            ), "Every ext_type must have 2 methods: packb(self) -> bytes and classmethod unpackb(cls, bytes)"
            if type_code in cls._ext_type_codes:
                logger.warning(f"{cls.__name__}: type {type_code} is already registered, overwriting")
            cls._ext_type_codes[type_code] = wrapped_type
            cls._ext_types[wrapped_type] = type_code
            return wrapped_type

        return wrap

    @classmethod
    def _encode_ext_types(cls, obj):
        """
        ``default`` hook for msgpack packing: convert registered types and tuples into ``msgpack.ExtType``.

        Registered types are matched by exact type (subclasses are not matched). Any other object is
        returned unchanged, which makes msgpack reject it as unserializable.
        """
        type_code = cls._ext_types.get(type(obj))
        if type_code is not None:
            return msgpack.ExtType(type_code, obj.packb())
        if isinstance(obj, tuple):
            # With strict_types=True msgpack passes tuples to this hook instead of packing them as arrays.
            # Pack the items as a list (recursively applying this hook to nested values) and wrap the result
            # in a dedicated ext type so that it is restored as a tuple rather than a list.
            data = msgpack.packb(list(obj), strict_types=True, use_bin_type=True, default=cls._encode_ext_types)
            return msgpack.ExtType(cls._TUPLE_EXT_TYPE_CODE, data)
        return obj

    @classmethod
    def _decode_ext_types(cls, type_code: int, data: bytes):
        """
        ``ext_hook`` for msgpack unpacking: restore registered types and tuples from ext payloads.

        For an unknown ``type_code`` a warning is logged and the raw payload bytes are returned.
        """
        registered_type = cls._ext_type_codes.get(type_code)
        if registered_type is not None:
            return registered_type.unpackb(data)
        if type_code == cls._TUPLE_EXT_TYPE_CODE:
            return tuple(msgpack.unpackb(data, ext_hook=cls._decode_ext_types, raw=False))
        logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is")
        return data

    @classmethod
    def dumps(cls, obj: object) -> bytes:
        """Serialize ``obj`` to msgpack bytes, encoding registered types and tuples as ext types."""
        return msgpack.packb(obj, use_bin_type=True, default=cls._encode_ext_types, strict_types=True)

    @classmethod
    def loads(cls, buf: bytes) -> object:
        """Deserialize msgpack bytes (as produced by ``dumps``), restoring registered types and tuples."""
        return msgpack.unpackb(buf, ext_hook=cls._decode_ext_types, raw=False)
|