ソースを参照

Serialize DHTID source with msgpack (#172)

* Change DHTID serializer

* Remove unused serializers

* Add msgpack tuple serialization
Alexey Bukhtiyarov 4 年 前
コミット
06162992fa

+ 1 - 1
hivemind/client/averaging/matchmaking.py

@@ -487,7 +487,7 @@ def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
     schema_dicts = [{field_name: str(field_value)
                      for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
                     for tensor in tensors]
-    return DHTID.generate(source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()
+    return DHTID.generate(source=schema_dicts).to_bytes()
 
 
 
 
 class MatchmakingException(Exception):

+ 2 - 2
hivemind/dht/routing.py

@@ -8,7 +8,7 @@ import random
 from collections.abc import Iterable
 from itertools import chain
 from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
-from hivemind.utils import Endpoint, PickleSerializer, get_dht_time, DHTExpiration
+from hivemind.utils import Endpoint, MSGPackSerializer, get_dht_time, DHTExpiration
 
 
 DHTKey, Subkey, DHTValue, BinaryDHTID, BinaryDHTValue, = Any, Any, Any, bytes, bytes
 
 
@@ -255,7 +255,7 @@ class DHTID(int):
             by default, generates a random dhtid from :nbits: random bits
         """
         source = random.getrandbits(nbits).to_bytes(nbits, byteorder='big') if source is None else source
-        source = PickleSerializer.dumps(source) if not isinstance(source, bytes) else source
+        source = MSGPackSerializer.dumps(source) if not isinstance(source, bytes) else source
         raw_uid = cls.HASH_FUNC(source).digest()
         return cls(int(raw_uid.hex(), 16))
 
 

+ 9 - 23
hivemind/utils/serializer.py

@@ -1,5 +1,4 @@
 """ A unified interface for several common serialization methods """
-import pickle
 from io import BytesIO
 from typing import Dict, Any
 
 
@@ -20,31 +19,10 @@ class SerializerBase:
         raise NotImplementedError()
 
 
 
 
-class PickleSerializer(SerializerBase):
-    @staticmethod
-    def dumps(obj: object) -> bytes:
-        return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
-
-    @staticmethod
-    def loads(buf: bytes) -> object:
-        return pickle.loads(buf)
-
-
-class PytorchSerializer(SerializerBase):
-    @staticmethod
-    def dumps(obj: object) -> bytes:
-        s = BytesIO()
-        torch.save(obj, s, pickle_protocol=pickle.HIGHEST_PROTOCOL)
-        return s.getvalue()
-
-    @staticmethod
-    def loads(buf: bytes) -> object:
-        return torch.load(BytesIO(buf))
-
-
 class MSGPackSerializer(SerializerBase):
     _ExtTypes: Dict[Any, int] = {}
     _ExtTypeCodes: Dict[int, Any] = {}
+    _MsgpackExtTypeCodeTuple = 0x40
 
 
     @classmethod
     def ext_serializable(cls, type_code: int):
@@ -64,12 +42,20 @@ class MSGPackSerializer(SerializerBase):
         type_code = cls._ExtTypes.get(type(obj))
         if type_code is not None:
             return msgpack.ExtType(type_code, obj.packb())
+        elif isinstance(obj, tuple):
+            # Tuples need to be handled separately to ensure that
+            # 1. tuple serialization works and 2. tuples serialized not as lists
+            data = msgpack.packb(list(obj), strict_types=True, use_bin_type=True, default=cls._encode_ext_types)
+            return msgpack.ExtType(cls._MsgpackExtTypeCodeTuple, data)
         return obj
 
 
     @classmethod
     def _decode_ext_types(cls, type_code: int, data: bytes):
         if type_code in cls._ExtTypeCodes:
             return cls._ExtTypeCodes[type_code].unpackb(data)
+        elif type_code == cls._MsgpackExtTypeCodeTuple:
+            return tuple(msgpack.unpackb(data, ext_hook=cls._decode_ext_types, raw=False))
+
         logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is.")
         return data
 
 

+ 0 - 2
tests/test_routing.py

@@ -5,7 +5,6 @@ from itertools import chain, zip_longest
 
 
 from hivemind import LOCALHOST
 from hivemind.dht.routing import RoutingTable, DHTID
-from hivemind.utils.serializer import PickleSerializer
 
 
 
 
 def test_ids_basic():
@@ -15,7 +14,6 @@ def test_ids_basic():
         assert DHTID.MIN <= id1 < DHTID.MAX and DHTID.MIN <= id2 <= DHTID.MAX
         assert DHTID.xor_distance(id1, id1) == DHTID.xor_distance(id2, id2) == 0
         assert DHTID.xor_distance(id1, id2) > 0 or (id1 == id2)
-        assert len(PickleSerializer.dumps(id1)) - len(PickleSerializer.dumps(int(id1))) < 40
         assert DHTID.from_bytes(bytes(id1)) == id1 and DHTID.from_bytes(id2.to_bytes()) == id2
 
 
 
 

+ 15 - 0
tests/test_util_modules.py

@@ -6,6 +6,7 @@ import pytest
 import hivemind
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
+from hivemind.utils import MSGPackSerializer
 from concurrent.futures import CancelledError
 
 
 
 
@@ -194,6 +195,20 @@ def test_serialize_tensor():
     assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)
 
 
 
 
+def test_serialize_tuple():
+    test_pairs = (
+        ((1, 2, 3), [1, 2, 3]),
+        (('1', False, 0), ['1', False, 0]),
+        (('1', False, 0), ('1', 0, 0)),
+        (('1', b'qq', (2, 5, '0')), ['1', b'qq', (2, 5, '0')]),
+    )
+
+    for first, second in test_pairs:
+        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(first)) == first
+        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(second)) == second
+        assert MSGPackSerializer.dumps(first) != MSGPackSerializer.dumps(second)
+
+
 def test_split_parts():
     tensor = torch.randn(910, 512)
     serialized_tensor_part = hivemind.utils.serialize_torch_tensor(tensor, allow_inplace=False)