|
@@ -12,6 +12,7 @@ from concurrent.futures.thread import ThreadPoolExecutor
|
|
|
from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
|
|
|
|
|
|
import grpc
|
|
|
+from grpc._cython.cygrpc import InternalError
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
|
|
@@ -239,10 +240,13 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
|
|
|
gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
|
|
|
future.set_result(gathered_data_by_peer)
|
|
|
|
|
|
- except (AllreduceException, MatchmakingException):
|
|
|
+ except (AllreduceException, MatchmakingException, asyncio.exceptions.InvalidStateError,
|
|
|
+ grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
|
|
|
time_elapsed = get_dht_time() - start_time
|
|
|
if not allow_retries or (timeout is not None and timeout < time_elapsed):
|
|
|
future.set_result(None)
|
|
|
+ else:
|
|
|
+ logger.debug(f"caught {e}, retrying")
|
|
|
|
|
|
except Exception as e:
|
|
|
future.set_exception(e)
|
|
@@ -311,9 +315,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
|
|
|
) -> AsyncIterator[averaging_pb2.DownloadData]:
|
|
|
"""
|
|
|
Get the up-to-date trainer state from a peer.
|
|
|
- The state consists of two parts: (metadata, tensors)
|
|
|
+ The state consists of two parts: (serialized_metadata, tensors)
|
|
|
|
|
|
- - metadata is a small serialized bytestring meant to store scalars and hyperparameters
|
|
|
+ - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
|
|
|
- tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
|
|
|
"""
|
|
|
chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
|
|
@@ -342,15 +346,15 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
|
|
|
self._pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
|
|
|
return await future
|
|
|
|
|
|
- def load_state_from_peers(self, wait=True) -> Optional[Tuple[bytes, Sequence[torch.Tensor]]]:
|
|
|
+ def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
|
|
|
"""
|
|
|
Try to download the latest optimizer state one of the existing peer.
|
|
|
- :returns: on success, return a 2-tuple with (serialized_metadata, tensors), where
|
|
|
+ :returns: on success, return a 2-tuple with (metadata, tensors), where
|
|
|
|
|
|
- - serialized_metadata is a small bytestring containing **serialized** metadata (e.g. hyperparameters)
|
|
|
+ - metadata is a small object containing metadata (e.g. hyperparameters, scalars, etc)
|
|
|
- tensors is a sequence of pytorch tensors meant to contain peer's model weights and optimizer statistics
|
|
|
|
|
|
- The exact contents of both serialized_metadata and tensors are determined by get_current_state method
|
|
|
+ The exact contents of both metadata and tensors are determined by get_current_state method
|
|
|
"""
|
|
|
future, _future = MPFuture.make_pair()
|
|
|
self.pipe.send(('_load_state_from_peers', [], dict(future=_future)))
|
|
@@ -441,7 +445,7 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
|
|
|
Executed in the host process as a background thread. Fetches the averager state when asked by peers.
|
|
|
:param serializer: a serializer with which to convert metadata into bytes
|
|
|
:param pipe: DecentralizedAverager's control pipe (from host process side)
|
|
|
- :param get_current_state_ref: a WeakMethod wrapped around DecentraliedAverager.get_current_state (instance-bound)
|
|
|
+ :param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
|
|
|
"""
|
|
|
while True:
|
|
|
trigger, future = pipe.recv()
|