Преглед на файлове

Serialize DHTID source with msgpack (#172)

* Change DHTID serializer

* Remove unused serializers

* Add msgpack tuple serialization
Alexey Bukhtiyarov преди 4 години
родител
ревизия
06162992fa
променени са 5 файла, в които са добавени 27 реда и са изтрити 28 реда
  1. 1 1
      hivemind/client/averaging/matchmaking.py
  2. 2 2
      hivemind/dht/routing.py
  3. 9 23
      hivemind/utils/serializer.py
  4. 0 2
      tests/test_routing.py
  5. 15 0
      tests/test_util_modules.py

+ 1 - 1
hivemind/client/averaging/matchmaking.py

@@ -487,7 +487,7 @@ def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
     schema_dicts = [{field_name: str(field_value)
                      for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
                     for tensor in tensors]
-    return DHTID.generate(source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()
+    return DHTID.generate(source=schema_dicts).to_bytes()
 
 
 class MatchmakingException(Exception):

+ 2 - 2
hivemind/dht/routing.py

@@ -8,7 +8,7 @@ import random
 from collections.abc import Iterable
 from itertools import chain
 from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
-from hivemind.utils import Endpoint, PickleSerializer, get_dht_time, DHTExpiration
+from hivemind.utils import Endpoint, MSGPackSerializer, get_dht_time, DHTExpiration
 
 DHTKey, Subkey, DHTValue, BinaryDHTID, BinaryDHTValue, = Any, Any, Any, bytes, bytes
 
@@ -255,7 +255,7 @@ class DHTID(int):
             by default, generates a random dhtid from :nbits: random bits
         """
         source = random.getrandbits(nbits).to_bytes(nbits, byteorder='big') if source is None else source
-        source = PickleSerializer.dumps(source) if not isinstance(source, bytes) else source
+        source = MSGPackSerializer.dumps(source) if not isinstance(source, bytes) else source
         raw_uid = cls.HASH_FUNC(source).digest()
         return cls(int(raw_uid.hex(), 16))
 

+ 9 - 23
hivemind/utils/serializer.py

@@ -1,5 +1,4 @@
 """ A unified interface for several common serialization methods """
-import pickle
 from io import BytesIO
 from typing import Dict, Any
 
@@ -20,31 +19,10 @@ class SerializerBase:
         raise NotImplementedError()
 
 
-class PickleSerializer(SerializerBase):
-    @staticmethod
-    def dumps(obj: object) -> bytes:
-        return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
-
-    @staticmethod
-    def loads(buf: bytes) -> object:
-        return pickle.loads(buf)
-
-
-class PytorchSerializer(SerializerBase):
-    @staticmethod
-    def dumps(obj: object) -> bytes:
-        s = BytesIO()
-        torch.save(obj, s, pickle_protocol=pickle.HIGHEST_PROTOCOL)
-        return s.getvalue()
-
-    @staticmethod
-    def loads(buf: bytes) -> object:
-        return torch.load(BytesIO(buf))
-
-
 class MSGPackSerializer(SerializerBase):
     _ExtTypes: Dict[Any, int] = {}
     _ExtTypeCodes: Dict[int, Any] = {}
+    _MsgpackExtTypeCodeTuple = 0x40
 
     @classmethod
     def ext_serializable(cls, type_code: int):
@@ -64,12 +42,20 @@ class MSGPackSerializer(SerializerBase):
         type_code = cls._ExtTypes.get(type(obj))
         if type_code is not None:
             return msgpack.ExtType(type_code, obj.packb())
+        elif isinstance(obj, tuple):
+            # Tuples need to be handled separately to ensure that
+            # 1. tuple serialization works and 2. tuples are serialized as tuples, not as lists
+            data = msgpack.packb(list(obj), strict_types=True, use_bin_type=True, default=cls._encode_ext_types)
+            return msgpack.ExtType(cls._MsgpackExtTypeCodeTuple, data)
         return obj
 
     @classmethod
     def _decode_ext_types(cls, type_code: int, data: bytes):
         if type_code in cls._ExtTypeCodes:
             return cls._ExtTypeCodes[type_code].unpackb(data)
+        elif type_code == cls._MsgpackExtTypeCodeTuple:
+            return tuple(msgpack.unpackb(data, ext_hook=cls._decode_ext_types, raw=False))
+
         logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is.")
         return data
 

+ 0 - 2
tests/test_routing.py

@@ -5,7 +5,6 @@ from itertools import chain, zip_longest
 
 from hivemind import LOCALHOST
 from hivemind.dht.routing import RoutingTable, DHTID
-from hivemind.utils.serializer import PickleSerializer
 
 
 def test_ids_basic():
@@ -15,7 +14,6 @@ def test_ids_basic():
         assert DHTID.MIN <= id1 < DHTID.MAX and DHTID.MIN <= id2 <= DHTID.MAX
         assert DHTID.xor_distance(id1, id1) == DHTID.xor_distance(id2, id2) == 0
         assert DHTID.xor_distance(id1, id2) > 0 or (id1 == id2)
-        assert len(PickleSerializer.dumps(id1)) - len(PickleSerializer.dumps(int(id1))) < 40
         assert DHTID.from_bytes(bytes(id1)) == id1 and DHTID.from_bytes(id2.to_bytes()) == id2
 
 

+ 15 - 0
tests/test_util_modules.py

@@ -6,6 +6,7 @@ import pytest
 import hivemind
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
+from hivemind.utils import MSGPackSerializer
 from concurrent.futures import CancelledError
 
 
@@ -194,6 +195,20 @@ def test_serialize_tensor():
     assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)
 
 
+def test_serialize_tuple():
+    test_pairs = (
+        ((1, 2, 3), [1, 2, 3]),
+        (('1', False, 0), ['1', False, 0]),
+        (('1', False, 0), ('1', 0, 0)),
+        (('1', b'qq', (2, 5, '0')), ['1', b'qq', (2, 5, '0')]),
+    )
+
+    for first, second in test_pairs:
+        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(first)) == first
+        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(second)) == second
+        assert MSGPackSerializer.dumps(first) != MSGPackSerializer.dumps(second)
+
+
 def test_split_parts():
     tensor = torch.randn(910, 512)
     serialized_tensor_part = hivemind.utils.serialize_torch_tensor(tensor, allow_inplace=False)