serializer.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. """ A unified interface for several common serialization methods """
  2. from abc import ABC, abstractmethod
  3. from typing import Any, Dict
  4. import msgpack
  5. from hivemind.utils.logging import get_logger
  6. logger = get_logger(__name__)
  7. class SerializerBase(ABC):
  8. @staticmethod
  9. @abstractmethod
  10. def dumps(obj: object) -> bytes:
  11. pass
  12. @staticmethod
  13. @abstractmethod
  14. def loads(buf: bytes) -> object:
  15. pass
  16. class MSGPackSerializer(SerializerBase):
  17. _ext_types: Dict[Any, int] = {}
  18. _ext_type_codes: Dict[int, Any] = {}
  19. _TUPLE_EXT_TYPE_CODE = 0x40
  20. @classmethod
  21. def ext_serializable(cls, type_code: int):
  22. assert isinstance(type_code, int), "Please specify a (unique) int type code"
  23. def wrap(wrapped_type: type):
  24. assert callable(getattr(wrapped_type, "packb", None)) and callable(
  25. getattr(wrapped_type, "unpackb", None)
  26. ), f"Every ext_type must have 2 methods: packb(self) -> bytes and classmethod unpackb(cls, bytes)"
  27. if type_code in cls._ext_type_codes:
  28. logger.warning(f"{cls.__name__}: type {type_code} is already registered, overwriting")
  29. cls._ext_type_codes[type_code], cls._ext_types[wrapped_type] = wrapped_type, type_code
  30. return wrapped_type
  31. return wrap
  32. @classmethod
  33. def _encode_ext_types(cls, obj):
  34. type_code = cls._ext_types.get(type(obj))
  35. if type_code is not None:
  36. return msgpack.ExtType(type_code, obj.packb())
  37. elif isinstance(obj, tuple):
  38. # Tuples need to be handled separately to ensure that
  39. # 1. tuple serialization works and 2. tuples serialized not as lists
  40. data = msgpack.packb(list(obj), strict_types=True, use_bin_type=True, default=cls._encode_ext_types)
  41. return msgpack.ExtType(cls._TUPLE_EXT_TYPE_CODE, data)
  42. return obj
  43. @classmethod
  44. def _decode_ext_types(cls, type_code: int, data: bytes):
  45. if type_code in cls._ext_type_codes:
  46. return cls._ext_type_codes[type_code].unpackb(data)
  47. elif type_code == cls._TUPLE_EXT_TYPE_CODE:
  48. return tuple(msgpack.unpackb(data, ext_hook=cls._decode_ext_types, raw=False))
  49. logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is")
  50. return data
  51. @classmethod
  52. def dumps(cls, obj: object) -> bytes:
  53. return msgpack.dumps(obj, use_bin_type=True, default=cls._encode_ext_types, strict_types=True)
  54. @classmethod
  55. def loads(cls, buf: bytes) -> object:
  56. return msgpack.loads(buf, ext_hook=cls._decode_ext_types, raw=False)