
Load state from peers in DecentralizedAverager (#154)

* Implemented DecentralizedAverager.load_state_from_peers that attempts to load the training state from another averager
  * The donor averager is chosen in order from the most recently updated to the least recently updated
  * The definition of state can be extended by the user (by inheriting from DecentralizedAverager)
* Calling __del__ on a partially created MPFuture no longer raises an error (edge case from albert)
* DecentralizedAverager now supports manually getting/setting the current group bits (used as the DHT key); see the usage sketch below

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic, 4 years ago
Commit 1d1252c30d
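
For reference, a minimal usage sketch of the new API, modeled on the tests added in this PR (the prefix, tensor shapes, and group bits are illustrative; a real setup would connect several peers via `initial_peers`):

```python
import torch
import hivemind

dht = hivemind.DHT(start=True)
averager = hivemind.DecentralizedAverager([torch.randn(3)], dht=dht, start=True,
                                          prefix='demo-run', target_group_size=2)

# download (metadata, tensors) from the most recently updated peer, or None if no donor is available
loaded_state = averager.load_state_from_peers()
if loaded_state is not None:
    metadata, tensors = loaded_state

# group bits can now be read and set manually; they form part of the DHT group key
averager.set_group_bits('0110')
assert averager.get_group_bits() == '0110'
```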

+ 0 - 2
docs/user/quickstart.md

@@ -16,8 +16,6 @@ pip install .
 
 You can also install it in the editable mode with `pip install -e .`.
 
-__Note:__ we currently recommend installing hivemind from github (i.e. not pip) as it can run RemoteMixtureOfExperts faster by an order of magnitude. These changes will only reach PyPI in v0.9.0 release.
-
 * __Dependencies:__ Hivemind requires python 3.7+ (3.8 is recommended), it will install [requirements](https://github.com/learning-at-home/hivemind/blob/master/requirements.txt) automatically; 
 * __OS support:__ Linux and macOS should [just work](https://github.com/learning-at-home/hivemind/issues).
 We do not officially support Windows, but you are welcome to contribute your windows build :)

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.28'
+__version__ = '0.8.29'

+ 147 - 7
hivemind/client/averaging/__init__.py

@@ -14,22 +14,25 @@ import torch
 import numpy as np
 
 import hivemind
-from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID
+from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
 from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
-from hivemind.utils import get_logger, Endpoint, Port, MPFuture, GRPC_KEEPALIVE_OPTIONS, get_dht_time, MSGPackSerializer
+from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, \
+    serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
+from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
+from hivemind.utils.serializer import PickleSerializer, MSGPackSerializer
+from hivemind.utils import Endpoint, Port, MPFuture, get_logger
 
 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
-
 DataForGather = Any
 logger = get_logger(__name__)
+DEFAULT_CHUNK_SIZE_BYTES = 2 ** 16
 
 
 class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
     """
-    **Warning!** Decentralized averager is in active development, some critical functionality is still underway
 
     Parameter averaging service. A trainer can run this service in background to periodically average his parameters
     with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
@@ -103,6 +106,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         self._averaged_tensors = tuple(averaged_tensors)
         self.lock_averaged_tensors = mp.Lock()
+        self.last_updated: DHTExpiration = -float('inf')
         for tensor in self._averaged_tensors:
             assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
             tensor.share_memory_()
@@ -122,6 +126,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         if start:
             self.run_in_background(await_ready=True)
+            hivemind.run_in_background(self._background_thread_fetch_current_state_if_asked)
 
     @property
     def port(self) -> Optional[Port]:
@@ -157,6 +162,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             self._pending_group_assembled.set()
             await server.start()
             self.ready.set()
+            asyncio.create_task(self._declare_for_download_periodically())
 
             while True:
                 method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
@@ -195,10 +201,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_step', [], dict(future=_future, gather=gather, allow_retries=allow_retries, timeout=timeout)))
+        gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
+        self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary,
+                                          allow_retries=allow_retries, timeout=timeout)))
         return future.result() if wait else future
 
-    async def _step(self, *, future: MPFuture, gather: DataForGather, allow_retries: bool, timeout: Optional[float]):
+    async def _step(self, *, future: MPFuture, gather_binary: bytes, allow_retries: bool, timeout: Optional[float]):
         loop = asyncio.get_event_loop()
         start_time = get_dht_time()
         group_id = None
@@ -206,7 +214,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         while not future.done():
             try:
                 self._pending_group_assembled.clear()
-                gather_binary = self.serializer.dumps(gather)
                 allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary)
                 if allreduce_group is None:
                     raise AllreduceException("Averaging step failed: could not find a group.")
@@ -245,6 +252,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             assert len(local_tensors) == len(self._averaged_tensors)
             for tensor, update in zip(local_tensors, averaging_deltas):
                 tensor.add_(update, alpha=self._averaging_alpha)
+        self.last_updated = get_dht_time()
 
     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
@@ -255,6 +263,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
+        self.last_updated = get_dht_time()
 
     async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
                              ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
@@ -279,6 +288,137 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         async for message in group.rpc_aggregate_part(achain(aiter(request), stream), context):
             yield message
 
+    async def _declare_for_download_periodically(self):
+        download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers'
+        while True:
+            asyncio.create_task(asyncio.wait_for(self.dht.store(
+                download_key, subkey=self.endpoint, value=self.last_updated,
+                expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
+                timeout=self._matchmaking.averaging_expiration))
+            await asyncio.sleep(self._matchmaking.averaging_expiration)
+
+    async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
+                                 ) -> AsyncIterator[averaging_pb2.DownloadData]:
+        """
+        Get the up-to-date trainer state from a peer.
+        The state consists of two parts: (metadata, tensors)
+
+         - metadata is a small pickle-serialized entry meant to store scalars and hyperparameters
+         - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
+        """
+        chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
+        metadata, tensors = await self._get_current_state_from_host_process()
+
+        for tensor in tensors:
+            for part in split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes):
+                if metadata is not None:
+                    yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
+                    metadata = None
+                else:
+                    yield averaging_pb2.DownloadData(tensor_part=part)
+
+    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
+        """
+        Get the current state to send to a peer. Executed in the host process. Meant to be overridden.
+        :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
+        """
+        with self.get_tensors() as tensors:
+            return dict(group_key=self.get_group_bits()), tensors
+
+    async def _get_current_state_from_host_process(self):
+        """ Executed in the averager process inside rpc_download_state """
+        future, _future = MPFuture.make_pair()
+        self._pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
+        return await future
+
+    def _background_thread_fetch_current_state_if_asked(self):
+        """ Executed in the host process as a background thread. """
+        while True:
+            trigger, future = self.pipe.recv()
+            assert trigger == '_TRIGGER_GET_CURRENT_STATE'
+            try:
+                state_metadata, state_tensors = self.get_current_state()
+                # note: serialize here to avoid initializing cuda in the guest process
+                state_metadata = PickleSerializer.dumps(state_metadata)
+                state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
+                                      for tensor in state_tensors)
+                future.set_result((state_metadata, state_tensors))
+            except BaseException as e:
+                future.set_exception(e)
+                logger.warning(e)
+                continue
+
+    def load_state_from_peers(self, wait=True) -> Optional[Any]:
+        """ Try to download the latest optimizer state one of the existing peer """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_load_state_from_peers', [], dict(future=_future)))
+        return future.result() if wait else future
+
+    async def _load_state_from_peers(self, future: MPFuture):
+        key_manager = self._matchmaking.group_key_manager
+        peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
+        peer_priority = {peer: float(info.value) for peer, info in peer_priority.items()
+                         if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))}
+
+        if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
+            logger.info(f"Averager could not load state from peers: peer dict is absent or corrupted {peer_priority}.")
+            future.set_result(None)
+            return
+
+        metadata = None
+        for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
+            if peer != self.endpoint:
+                logger.info(f"Downloading parameters from peer {peer}")
+                stream = None
+                try:
+                    leader_stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+                    stream = leader_stub.rpc_download_state(averaging_pb2.DownloadRequest())
+                    current_tensor_parts, tensors = [], []
+                    async for message in stream:
+                        if message.metadata:
+                            metadata = PickleSerializer.loads(message.metadata)
+                        if message.tensor_part.dtype and current_tensor_parts:
+                            # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
+                            tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
+                            current_tensor_parts = []
+                        current_tensor_parts.append(message.tensor_part)
+                    if current_tensor_parts:
+                        tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
+                    future.set_result((metadata, tensors))
+                    self.last_updated = get_dht_time()
+                    return
+                except grpc.aio.AioRpcError as e:
+                    logger.info(f"Failed to download state from {peer} - {e}")
+                finally:
+                    if stream is not None:
+                        await stream.code()
+
+        else:
+            logger.warning("Averager could not load state from peers: found no active peers.")
+            future.set_result(None)
+
+    def get_group_bits(self, wait: bool = True):
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get_group_bits', [], dict(future=_future)))
+        return future.result() if wait else future
+
+    async def _get_group_bits(self, future: MPFuture):
+        future.set_result(self._matchmaking.group_key_manager.group_bits)
+
+    def set_group_bits(self, group_bits: str, wait: bool = True):
+        future, _future = MPFuture.make_pair()
+        assert all(bit in '01' for bit in group_bits)
+        self.pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=_future)))
+        return future.result() if wait else future
+
+    async def _set_group_bits(self, group_bits: str, future: MPFuture):
+        try:
+            self._matchmaking.group_key_manager.group_bits = group_bits
+            return future.set_result(None)
+        except Exception as e:
+            if not future.done():
+                future.set_exception(e)
+
 
 def is_power_of_two(n):
     """ Check whether n is a power of 2 """

+ 15 - 11
hivemind/client/averaging/key_manager.py

@@ -27,16 +27,17 @@ class GroupKeyManager:
 
     def __init__(self, dht: DHT, endpoint: Endpoint, prefix: str, initial_group_bits: Optional[str],
                  target_group_size: int, insufficient_size: Optional[int] = None, excessive_size: Optional[int] = None,
-                 nbits_expiration: float = 60):
+                 nbits_expiration: float = 60, nbits_rewrite_grace_period: float = 15):
         assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
         if initial_group_bits is None:
             search_result = dht.get(f"{prefix}.0b", latest=True)
-            initial_group_bits = self.get_suggested_nbits(search_result) or ''
+            initial_group_nbits = self.get_suggested_nbits(search_result) or 0
+            initial_group_bits = ''.join(random.choice('01') for _ in range(initial_group_nbits))
         self.dht, self.endpoint, self.prefix, self.group_bits = dht, endpoint, prefix, initial_group_bits
         self.target_group_size = target_group_size
         self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
         self.excessive_size = excessive_size or target_group_size * 3
-        self.nbits_expiration = nbits_expiration
+        self.nbits_expiration, self.nbits_grace_period = nbits_expiration, nbits_rewrite_grace_period
         self.suggested_nbits: Optional[int] = None
 
     @property
@@ -80,7 +81,8 @@ class GroupKeyManager:
         num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
 
         suggested_nbits = self.get_suggested_nbits(result)
-        if suggested_nbits is not None and suggested_nbits != self.suggested_nbits:
+        if suggested_nbits is not None and suggested_nbits != len(self.group_bits) and \
+                suggested_nbits != self.suggested_nbits:
             self.suggested_nbits = suggested_nbits
             logger.warning(f"{self.endpoint} - another averager suggested {self.suggested_nbits}-bit keys")
         elif num_active_averagers >= self.excessive_size:
@@ -108,11 +110,11 @@ class GroupKeyManager:
         generalized_index = rng.sample(range(self.target_group_size), allreduce_group.group_size)[index]
         nbits = int(np.ceil(np.log2(self.target_group_size)))
         new_bits = bin(generalized_index)[2:].rjust(nbits, '0')
-        self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits):]
+        self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits):] if self.group_bits else ''
         logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}")
 
         if is_leader and self.insufficient_size < allreduce_group.group_size < self.excessive_size:
-            asyncio.create_task(self.notify_stragglers_on_success())
+            asyncio.create_task(self.notify_stragglers())
         if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
             num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
             self.group_bits = ''.join((random.choice('01') for _ in range(num_extra_bits))) + self.group_bits
@@ -122,12 +124,12 @@ class GroupKeyManager:
     async def update_key_on_not_enough_peers(self):
         """ this function is triggered whenever averager fails to assemble group within timeout """
         new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
-        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:]
+        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ''
         if self.group_bits != prev_nbits:
             logger.warning(f'{self.endpoint} - switching to {len(self.group_bits)}-bit keys')
         self.suggested_nbits = None
 
-    async def notify_stragglers_on_success(self):
+    async def notify_stragglers(self):
         """ Find averagers that have fewer nbits and redirect them to your current nbits """
         for nbits in reversed(range(1, len(self.group_bits) - 1)):
             preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
@@ -137,6 +139,8 @@ class GroupKeyManager:
                 await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
                 break
 
-        root_data = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True)
-        if root_data is None or self.RESERVED_KEY_FOR_NBITS not in root_data.value:
-            await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
+        root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
+        if isinstance(root_data, dict) and root_data.get(
+                self.RESERVED_KEY_FOR_NBITS, (None, -float('inf')))[1] > get_dht_time() + self.nbits_grace_period:
+            return
+        await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
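
For orientation, the DHT keys that these group bits feed into follow the pattern used in the f-strings above; a small illustration (the key format is taken from the diff, the values are made up):

```python
prefix, group_bits = 'demo-run', '0110'
group_key = f'{prefix}.0b{group_bits}'  # averagers whose bits match meet under this key
root_key = f'{prefix}.0b'               # zero-bit root key where a suggested nbits is declared for newcomers
assert group_key == 'demo-run.0b0110' and root_key == 'demo-run.0b'
```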

+ 1 - 0
hivemind/dht/__init__.py

@@ -48,6 +48,7 @@ def is_valid_prefix(maybe_prefix: str) -> bool:
     """ An uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period """
     return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
 
+
 def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
     """ Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate """
     uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)

+ 2 - 2
hivemind/dht/node.py

@@ -75,7 +75,7 @@ class DHTNode:
     async def create(
             cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
             bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
-            wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
+            wait_timeout: float = 3, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
             cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
             cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
             blacklist_time: float = 5.0, backoff_rate: float = 2.0,
@@ -155,7 +155,7 @@ class DHTNode:
                     straggler.cancel()
                 finished_pings |= finished_in_time
 
-            if not finished_pings:
+            if not finished_pings or all(ping.result() is None for ping in finished_pings):
                 logger.warning("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
 
             if strict:

+ 2 - 1
hivemind/dht/protocol.py

@@ -110,7 +110,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
         if responded and validate:
             try:
                 if self.server is not None and not response.available:
-                    raise ValidationError(f"peer {peer} couldn't access this node at {response.sender_endpoint} .")
+                    raise ValidationError(f"Peer {peer} couldn't access this node at {response.sender_endpoint} . "
+                                          f"Make sure that this port is open for incoming requests.")
 
                 if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
                     if response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS or \

+ 8 - 0
hivemind/proto/averaging.proto

@@ -6,6 +6,7 @@ import "runtime.proto";
 service DecentralizedAveraging {
   rpc rpc_join_group(JoinRequest) returns (stream MessageFromLeader);  // assemble a group for allreduce
   rpc rpc_aggregate_part(stream AveragingData) returns (stream AveragingData);  // send local part => get average part
+  rpc rpc_download_state(DownloadRequest) returns (stream DownloadData);
 }
 
 enum MessageCode {
@@ -53,3 +54,10 @@ message AveragingData {
   string endpoint = 3;      // sender's rpc endpoint, used for coordination
   Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
 }
+
+message DownloadRequest {}
+
+message DownloadData {
+  bytes metadata = 1;
+  Tensor tensor_part = 2;
+}

+ 2 - 1
hivemind/utils/mpfuture.py

@@ -161,4 +161,5 @@ class MPFuture(base.Future):
 
     def __del__(self):
         self._shutdown_trigger.set_result(True)
-        self.connection.close()
+        if hasattr(self, 'connection'):
+            self.connection.close()

+ 60 - 0
tests/test_averaging.py

@@ -261,3 +261,63 @@ def test_overcrowded():
     for t in range(5):
         step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
         assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1
+
+
+@pytest.mark.forked
+def test_load_state_from_peers():
+    num_calls = 0
+    super_metadata = dict(x=123)
+    super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))
+
+    class TestAverager(hivemind.DecentralizedAverager):
+        def get_current_state(self):
+            """
+            Get current state and send it to a peer. executed in the host process. Meant to be overriden.
+            :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
+            """
+            nonlocal num_calls, super_metadata, super_tensors
+            num_calls += 1
+            return super_metadata, super_tensors
+
+    dht_root = hivemind.DHT(start=True)
+    initial_peers = [f'{hivemind.LOCALHOST}:{dht_root.port}']
+    dht1 = hivemind.DHT(initial_peers=initial_peers, start=True)
+    averager1 = TestAverager([torch.randn(3), torch.rand(5)],
+                             dht=dht1, start=True,
+                             prefix='demo-run', target_group_size=2)
+
+    dht2 = hivemind.DHT(initial_peers=initial_peers, start=True)
+    dht2.get('demo-run.all_averagers')
+    averager2 = TestAverager([torch.randn(3), torch.rand(5)],
+                             dht=dht2, start=True,
+                             prefix='demo-run', target_group_size=2)
+
+    assert num_calls == 0
+    got_metadata, got_tensors = averager2.load_state_from_peers()
+    assert num_calls == 1
+    assert got_metadata == super_metadata
+    assert all(map(torch.allclose, got_tensors, super_tensors))
+
+    super_metadata['y'] = 123
+    super_tensors[1][2] = 9
+    assert num_calls == 1
+    assert got_metadata != super_metadata
+    assert not all(map(torch.allclose, got_tensors, super_tensors))
+    got_metadata, got_tensors = averager2.load_state_from_peers()
+    assert num_calls == 2
+    assert got_metadata == super_metadata
+    assert all(map(torch.allclose, got_tensors, super_tensors))
+
+    # check that normal averaging still works
+    futures = [averager.step(wait=False) for averager in [averager1, averager2]]
+    for future in futures:
+        future.result()
+
+
+@pytest.mark.forked
+def test_getset_bits():
+    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
+    averager = hivemind.DecentralizedAverager([torch.randn(3)], dht=dht, start=True,
+                                              prefix='test_prefix', target_group_size=2)
+    averager.set_group_bits('00101011101010')
+    assert averager.get_group_bits() == '00101011101010'