Просмотр исходного кода

Handle edge cases in DecentralizedAverager (#171)

* move metadata serialization outside user scope
* retry averager.step on network errors
* raise AllreduceException on partial tensor
* test split/combine tensors, combine corrupted stream

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 лет назад
Родитель
Сommit
23e655640c

+ 12 - 8
hivemind/client/averaging/__init__.py

@@ -12,6 +12,7 @@ from concurrent.futures.thread import ThreadPoolExecutor
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
 import grpc
+from grpc._cython.cygrpc import InternalError
 import torch
 import numpy as np
 
@@ -239,10 +240,13 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
                 future.set_result(gathered_data_by_peer)
 
-            except (AllreduceException, MatchmakingException):
+            except (AllreduceException, MatchmakingException, asyncio.exceptions.InvalidStateError,
+                    grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
                 time_elapsed = get_dht_time() - start_time
                 if not allow_retries or (timeout is not None and timeout < time_elapsed):
                     future.set_result(None)
+                else:
+                    logger.debug(f"caught {e}, retrying")
 
             except Exception as e:
                 future.set_exception(e)
@@ -311,9 +315,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                                  ) -> AsyncIterator[averaging_pb2.DownloadData]:
         """
         Get the up-to-date trainer state from a peer.
-        The state consists of two parts: (metadata, tensors)
+        The state consists of two parts: (serialized_metadata, tensors)
 
-         - metadata is a small serialized bytestring meant to store scalars and hyperparameters
+         - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
          - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
         """
         chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
@@ -342,15 +346,15 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
         return await future
 
-    def load_state_from_peers(self, wait=True) -> Optional[Tuple[bytes, Sequence[torch.Tensor]]]:
+    def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
         """
         Try to download the latest optimizer state one of the existing peer.
-        :returns: on success, return a 2-tuple with (serialized_metadata, tensors), where
+        :returns: on success, return a 2-tuple with (metadata, tensors), where
 
-        - serialized_metadata is a small bytestring containing **serialized** metadata (e.g. hyperparameters)
+        - metadata is a small object containing metadata (e.g. hyperparameters, scalars, etc)
         - tensors is a sequence of pytorch tensors meant to contain peer's model weights and optimizer statistics
 
-        The exact contents of both serialized_metadata and tensors are determined by get_current_state method
+        The exact contents of both metadata and tensors are determined by get_current_state method
         """
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_load_state_from_peers', [], dict(future=_future)))
@@ -441,7 +445,7 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
     Executed in the host process as a background thread. Fetches the averager state when asked by peers.
     :param serializer: a serializer with which to convert metadata into bytes
     :param pipe: DecentralizedAverager's control pipe (from host process side)
-    :param get_current_state_ref: a WeakMethod wrapped around DecentraliedAverager.get_current_state (instance-bound)
+    :param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
     """
     while True:
         trigger, future = pipe.recv()

+ 11 - 2
hivemind/client/averaging/allreduce.py

@@ -153,7 +153,12 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
                                      f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                                      f" allreduce failed")
 
-        averaged_part = deserialize_torch_tensor(combine_from_streaming([message.tensor_part for message in outputs]))
+        try:
+            averaged_part = deserialize_torch_tensor(combine_from_streaming(
+                [message.tensor_part for message in outputs]))
+        except RuntimeError as e:
+            raise AllreduceException(f"Could not deserialize averaged part from {peer_endpoint}: {e}")
+
         self.register_averaged_part(peer_endpoint, averaged_part)
         return averaged_part
 
@@ -182,7 +187,11 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
     async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
                                         ) -> Iterable[runtime_pb2.Tensor]:
         """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
-        tensor_part: torch.Tensor = deserialize_torch_tensor(combine_from_streaming(stream_messages))
+        try:
+            tensor_part = deserialize_torch_tensor(combine_from_streaming(stream_messages))
+        except RuntimeError as e:
+            raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming {e}")
+
         averaged_part = await self.accumulate_part(source, tensor_part)
         if not self.averaged_part_stream.done():
             serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)

+ 38 - 1
tests/test_util_modules.py

@@ -1,5 +1,6 @@
 import asyncio
 import torch
+import numpy as np
 
 import pytest
 import hivemind
@@ -57,7 +58,7 @@ def test_mpfuture_cancel():
         with pytest.raises(RuntimeError):
             future.set_result(123)
         with pytest.raises(RuntimeError):
-            future.set_exception(NotImplementedError)
+            future.set_exception(NotImplementedError())
         assert future.cancelled() and future.done() and not future.running()
 
 
@@ -192,6 +193,42 @@ def test_serialize_tensor():
     restored = hivemind.combine_from_streaming(chunks)
     assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)
 
+
+def test_split_parts():
+    tensor = torch.randn(910, 512)
+    serialized_tensor_part = hivemind.utils.serialize_torch_tensor(tensor, allow_inplace=False)
+    chunks1 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 16384))
+    assert len(chunks1) == int(np.ceil(tensor.numel() * tensor.element_size() / 16384))
+
+    chunks2 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10_000))
+    assert len(chunks2) == int(np.ceil(tensor.numel() * tensor.element_size() / 10_000))
+
+    chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10 ** 9))
+    assert len(chunks3) == 1
+
+    compressed_tensor_part = hivemind.utils.serialize_torch_tensor(tensor, hivemind.CompressionType.FLOAT16,
+                                                                   allow_inplace=False)
+    chunks4 = list(hivemind.utils.split_for_streaming(compressed_tensor_part, 16384))
+    assert len(chunks4) == int(np.ceil(tensor.numel() * 2 / 16384))
+
+    combined1 = hivemind.utils.combine_from_streaming(chunks1)
+    combined2 = hivemind.utils.combine_from_streaming(iter(chunks2))
+    combined3 = hivemind.utils.combine_from_streaming(chunks3)
+    combined4 = hivemind.utils.combine_from_streaming(chunks4)
+    for combined in combined1, combined2, combined3:
+        assert torch.allclose(tensor, hivemind.deserialize_torch_tensor(combined), rtol=1e-5, atol=1e-8)
+
+    assert torch.allclose(tensor, hivemind.deserialize_torch_tensor(combined4), rtol=1e-3, atol=1e-3)
+
+    combined_incomplete = hivemind.utils.combine_from_streaming(chunks4[:5])
+    combined_incomplete2 = hivemind.utils.combine_from_streaming(chunks4[:1])
+    combined_incomplete3 = hivemind.utils.combine_from_streaming(chunks4[:-1])
+    for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
+        with pytest.raises(RuntimeError):
+            hivemind.deserialize_torch_tensor(combined)
+            # note: we rely on this being RuntimeError in hivemind.client.averager.allreduce.AllreduceProtocol
+
+
 def test_generic_data_classes():
     from hivemind.utils import ValueWithExpiration, HeapEntry, DHTExpiration