ソースを参照

Convert SerializerBase to an abstract class (#212)

* Make SerializerBase an abstract class

* Remove redundant imports in connection_handler.py
Max Ryabinin 4 年 前
コミット
1d364b7c32
2 ファイル変更18 行追加17 行削除
  1. 0 2
      hivemind/server/connection_handler.py
  2. 18 15
      hivemind/utils/serializer.py

+ 0 - 2
hivemind/server/connection_handler.py

@@ -1,4 +1,3 @@
-import asyncio
 import multiprocessing as mp
 import os
 import pickle
@@ -6,7 +5,6 @@ from typing import Dict
 
 import grpc
 import torch
-import uvloop
 
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.server.expert_backend import ExpertBackend

+ 18 - 15
hivemind/utils/serializer.py

@@ -1,5 +1,6 @@
 """ A unified interface for several common serialization methods """
 from typing import Dict, Any
+from abc import ABC, abstractmethod
 
 import msgpack
 
@@ -8,51 +9,54 @@ from hivemind.utils.logging import get_logger
 logger = get_logger(__name__)
 
 
-class SerializerBase:
+class SerializerBase(ABC):
     @staticmethod
+    @abstractmethod
     def dumps(obj: object) -> bytes:
-        raise NotImplementedError()
+        pass
 
     @staticmethod
+    @abstractmethod
     def loads(buf: bytes) -> object:
-        raise NotImplementedError()
+        pass
 
 
 class MSGPackSerializer(SerializerBase):
-    _ExtTypes: Dict[Any, int] = {}
-    _ExtTypeCodes: Dict[int, Any] = {}
-    _MsgpackExtTypeCodeTuple = 0x40
+    _ext_types: Dict[Any, int] = {}
+    _ext_type_codes: Dict[int, Any] = {}
+    _TUPLE_EXT_TYPE_CODE = 0x40
 
     @classmethod
     def ext_serializable(cls, type_code: int):
         assert isinstance(type_code, int), "Please specify a (unique) int type code"
 
         def wrap(wrapped_type: type):
-            assert callable(getattr(wrapped_type, 'packb', None)) and callable(getattr(wrapped_type, 'unpackb', None)),\
+            assert callable(getattr(wrapped_type, 'packb', None)) and callable(getattr(wrapped_type, 'unpackb', None)), \
                 f"Every ext_type must have 2 methods: packb(self) -> bytes and classmethod unpackb(cls, bytes)"
-            if type_code in cls._ExtTypeCodes:
+            if type_code in cls._ext_type_codes:
                 logger.warning(f"{cls.__name__}: type {type_code} is already registered, overwriting.")
-            cls._ExtTypeCodes[type_code], cls._ExtTypes[wrapped_type] = wrapped_type, type_code
+            cls._ext_type_codes[type_code], cls._ext_types[wrapped_type] = wrapped_type, type_code
             return wrapped_type
+
         return wrap
 
     @classmethod
     def _encode_ext_types(cls, obj):
-        type_code = cls._ExtTypes.get(type(obj))
+        type_code = cls._ext_types.get(type(obj))
         if type_code is not None:
             return msgpack.ExtType(type_code, obj.packb())
         elif isinstance(obj, tuple):
             # Tuples need to be handled separately to ensure that
             # 1. tuple serialization works and 2. tuples serialized not as lists
             data = msgpack.packb(list(obj), strict_types=True, use_bin_type=True, default=cls._encode_ext_types)
-            return msgpack.ExtType(cls._MsgpackExtTypeCodeTuple, data)
+            return msgpack.ExtType(cls._TUPLE_EXT_TYPE_CODE, data)
         return obj
 
     @classmethod
     def _decode_ext_types(cls, type_code: int, data: bytes):
-        if type_code in cls._ExtTypeCodes:
-            return cls._ExtTypeCodes[type_code].unpackb(data)
-        elif type_code == cls._MsgpackExtTypeCodeTuple:
+        if type_code in cls._ext_type_codes:
+            return cls._ext_type_codes[type_code].unpackb(data)
+        elif type_code == cls._TUPLE_EXT_TYPE_CODE:
             return tuple(msgpack.unpackb(data, ext_hook=cls._decode_ext_types, raw=False))
 
         logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is.")
@@ -65,4 +69,3 @@ class MSGPackSerializer(SerializerBase):
     @classmethod
     def loads(cls, buf: bytes) -> object:
         return msgpack.loads(buf, ext_hook=cls._decode_ext_types, raw=False)
-