
Move Averager metadata serialization out of user scope (#168)

* move metadata serialization outside user scope
* test_overcrowded: reduce the default number of peers
justheuristic 4 years ago
Commit
25fdf0d94f
2 files changed, 15 insertions, 11 deletions
  1. hivemind/client/averaging/__init__.py  + 12 - 8
  2. tests/test_averaging.py  + 3 - 3
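
In short: before this commit, get_current_state had to return pre-serialized bytes, and the background fetcher asserted as much; after it, the method returns plain Python metadata and serialization happens once, inside _background_thread_fetch_current_state. A minimal before/after sketch of the method body (simplified, omitting the rest of the class):

    # before #168: the override itself had to produce bytes
    def get_current_state(self):
        with self.get_tensors() as tensors:
            return self.serializer.dumps(dict(group_key=self.get_group_bits())), tensors

    # after #168: return plain metadata; the background thread calls
    # self.serializer.dumps(metadata) before piping it to the averager process
    def get_current_state(self):
        with self.get_tensors() as tensors:
            return dict(group_key=self.get_group_bits()), tensors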

+ 12 - 8
hivemind/client/averaging/__init__.py

@@ -23,7 +23,7 @@ from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, \
     serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
-from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils import Endpoint, Port, MPFuture, get_logger
 
 # flavour types
@@ -126,8 +126,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._averager_endpoint: Optional[Endpoint] = None
         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
         # note: we create the background thread with a weakref and daemon=True to ensure garbage collection
-        background_fetcher = threading.Thread(daemon=True, target=_background_thread_fetch_current_state,
-                                              args=[self.pipe, weakref.WeakMethod(self.get_current_state)])
+        background_fetcher = threading.Thread(
+            daemon=True, target=_background_thread_fetch_current_state,
+            args=[self.serializer, self.pipe, weakref.WeakMethod(self.get_current_state)])
         background_fetcher.start()
         if start:
             self.run_in_background(await_ready=True)
@@ -326,13 +327,14 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 else:
                     yield averaging_pb2.DownloadData(tensor_part=part)
 
-    def get_current_state(self) -> Tuple[bytes, Sequence[torch.Tensor]]:
+    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
         """
         Get current state and send it to a peer. Executed in the host process. Meant to be overridden.
-        :returns: a tuple of (serialized_metadata, sequence of torch tensors)
+        :returns: a tuple of (small metadata, sequence of torch tensors)
+        :note: metadata must be serializable with self.serializer (default = MSGPackSerializer)
         """
         with self.get_tensors() as tensors:
-            return self.serializer.dumps(dict(group_key=self.get_group_bits())), tensors
+            return dict(group_key=self.get_group_bits()), tensors
 
     async def _get_current_state_from_host_process(self):
         """ Executed in the averager process inside rpc_download_state """
@@ -433,9 +435,11 @@ def is_power_of_two(n):
     return (n != 0) and (n & (n - 1) == 0)
 
 
-def _background_thread_fetch_current_state(pipe: mp.connection.Connection, get_current_state_ref: weakref.WeakMethod):
+def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.connection.Connection,
+                                           get_current_state_ref: weakref.WeakMethod):
     """
     Executed in the host process as a background thread. Fetches the averager state when asked by peers.
+    :param serializer: a serializer with which to convert metadata into bytes
     :param pipe: DecentralizedAverager's control pipe (from host process side)
     :param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
     """
@@ -452,7 +456,7 @@ def _background_thread_fetch_current_state(pipe: mp.connection.Connection, get_c
             state_metadata, state_tensors = get_current_state()
             del get_current_state
 
-            assert isinstance(state_metadata, bytes)
+            state_metadata = serializer.dumps(state_metadata)
             state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
                                   for tensor in state_tensors)
             # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
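
For subclass authors, the new contract looks like this. A hypothetical override, assuming only the DecentralizedAverager API visible in this diff (the TrainerAverager name and the extra step field are made up for illustration):

    import hivemind

    class TrainerAverager(hivemind.DecentralizedAverager):
        def get_current_state(self):
            # return plain metadata; it must be serializable with
            # self.serializer (MSGPackSerializer by default)
            with self.get_tensors() as tensors:
                return dict(group_key=self.get_group_bits(), step=1234), tensors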

+ 3 - 3
tests/test_averaging.py

@@ -267,13 +267,13 @@ def test_too_few_peers():
 
 
 @pytest.mark.forked
-def test_overcrowded():
+def test_overcrowded(num_peers=16):
     dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
     averagers = [hivemind.DecentralizedAverager(
         averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
         averaging_expiration=1, request_timeout=0.5,
         prefix='mygroup', initial_group_bits='', start=True)
-        for _ in range(32)]
+        for _ in range(num_peers)]
     for t in range(5):
         step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
         assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1
@@ -297,7 +297,7 @@ def test_load_state_from_peers():
             """
             nonlocal num_calls, super_metadata, super_tensors
             num_calls += 1
-            return self.serializer.dumps(super_metadata), super_tensors
+            return super_metadata, super_tensors
 
     dht_root = hivemind.DHT(start=True)
     initial_peers = [f'{hivemind.LOCALHOST}:{dht_root.port}']
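
Since metadata now has to round-trip through self.serializer, a quick sanity check of the default path (a sketch assuming the MSGPackSerializer.dumps/loads API imported at the top of this diff):

    from hivemind.utils.serializer import MSGPackSerializer

    metadata = dict(group_key='0110', step=42)
    blob = MSGPackSerializer.dumps(metadata)  # what _background_thread_fetch_current_state now does
    assert isinstance(blob, bytes)
    assert MSGPackSerializer.loads(blob) == metadata  # receiving side recovers the dict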