Browse Source

Load state from peers in DecentralizedAverager (#154)

* Implemented DecentralizedAverager.load_state_from_peers that attempts to load the training state from another averager
  * The donor averager is chosen in order from the most recently (successfully) updated peer to the least recently updated one
  * The definition of state can be extended by the user (by inheriting from DecentralizedAverager)
* Calling __del__ on a partially created MPFuture no longer causes an error (edge case from albert)
* DecentralizedAverager now supports manually getting/setting the current group bits (used as a DHT key); see the usage sketch below
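
A minimal usage sketch of the new API, modeled on the tests added in this commit (tests/test_averaging.py); the prefix, group bits, and tensor shapes are placeholders:

```python
import torch
import hivemind

# Start a DHT node and an averager attached to it, as in the new tests.
dht = hivemind.DHT(start=True)
averager = hivemind.DecentralizedAverager(
    [torch.randn(3)], dht=dht, start=True, prefix='demo-run', target_group_size=2)

# Manually read or override the group bits used as the DHT matchmaking key.
averager.set_group_bits('0110')
assert averager.get_group_bits() == '0110'

# Try to download the latest training state from the most recently updated peer.
loaded = averager.load_state_from_peers()
if loaded is not None:
    metadata, tensors = loaded  # (deserialized metadata, list of torch tensors)
```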

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
commit
1d1252c30d

+ 0 - 2
docs/user/quickstart.md

@@ -16,8 +16,6 @@ pip install .
 
 You can also install it in the editable mode with `pip install -e .`.
 
-__Note:__ we currently recommend installing hivemind from github (i.e. not pip) as it can run RemoteMixtureOfExperts faster by an order of magnitude. These changes will only reach PyPI in v0.9.0 release.
-
 * __Dependencies:__ Hivemind requires python 3.7+ (3.8 is recommended), it will install [requirements](https://github.com/learning-at-home/hivemind/blob/master/requirements.txt) automatically; 
 * __OS support:__ Linux and macOS should [just work](https://github.com/learning-at-home/hivemind/issues).
 We do not officially support Windows, but you are welcome to contribute your windows build :)

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.28'
+__version__ = '0.8.29'

+ 147 - 7
hivemind/client/averaging/__init__.py

@@ -14,22 +14,25 @@ import torch
 import numpy as np
 
 import hivemind
-from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID
+from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
 from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
-from hivemind.utils import get_logger, Endpoint, Port, MPFuture, GRPC_KEEPALIVE_OPTIONS, get_dht_time, MSGPackSerializer
+from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, \
+    serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
+from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
+from hivemind.utils.serializer import PickleSerializer, MSGPackSerializer
+from hivemind.utils import Endpoint, Port, MPFuture, get_logger
 
 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
-
 DataForGather = Any
 logger = get_logger(__name__)
+DEFAULT_CHUNK_SIZE_BYTES = 2 ** 16
 
 
 class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
     """
-    **Warning!** Decentralized averager is in active development, some critical functionality is still underway
 
     Parameter averaging service. A trainer can run this service in background to periodically average his parameters
     with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
@@ -103,6 +106,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         self._averaged_tensors = tuple(averaged_tensors)
         self.lock_averaged_tensors = mp.Lock()
+        self.last_updated: DHTExpiration = -float('inf')
         for tensor in self._averaged_tensors:
             assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
             tensor.share_memory_()
@@ -122,6 +126,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         if start:
             self.run_in_background(await_ready=True)
+            hivemind.run_in_background(self._background_thread_fetch_current_state_if_asked)
 
     @property
     def port(self) -> Optional[Port]:
@@ -157,6 +162,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             self._pending_group_assembled.set()
             await server.start()
             self.ready.set()
+            asyncio.create_task(self._declare_for_download_periodically())
 
             while True:
                 method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
@@ -195,10 +201,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_step', [], dict(future=_future, gather=gather, allow_retries=allow_retries, timeout=timeout)))
+        gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
+        self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary,
+                                          allow_retries=allow_retries, timeout=timeout)))
         return future.result() if wait else future
 
-    async def _step(self, *, future: MPFuture, gather: DataForGather, allow_retries: bool, timeout: Optional[float]):
+    async def _step(self, *, future: MPFuture, gather_binary: bytes, allow_retries: bool, timeout: Optional[float]):
         loop = asyncio.get_event_loop()
         start_time = get_dht_time()
         group_id = None
@@ -206,7 +214,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         while not future.done():
             try:
                 self._pending_group_assembled.clear()
-                gather_binary = self.serializer.dumps(gather)
                 allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary)
                 if allreduce_group is None:
                     raise AllreduceException("Averaging step failed: could not find a group.")
@@ -245,6 +252,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             assert len(local_tensors) == len(self._averaged_tensors)
             for tensor, update in zip(local_tensors, averaging_deltas):
                 tensor.add_(update, alpha=self._averaging_alpha)
+        self.last_updated = get_dht_time()
 
     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
@@ -255,6 +263,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
+        self.last_updated = get_dht_time()
 
     async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
                              ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
@@ -279,6 +288,137 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         async for message in group.rpc_aggregate_part(achain(aiter(request), stream), context):
             yield message
 
+    async def _declare_for_download_periodically(self):
+        download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers'
+        while True:
+            asyncio.create_task(asyncio.wait_for(self.dht.store(
+                download_key, subkey=self.endpoint, value=self.last_updated,
+                expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
+                timeout=self._matchmaking.averaging_expiration))
+            await asyncio.sleep(self._matchmaking.averaging_expiration)
+
+    async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
+                                 ) -> AsyncIterator[averaging_pb2.DownloadData]:
+        """
+        Get the up-to-date trainer state from a peer.
+        The state consists of two parts: (metadata, tensors)
+
+         - metadata is a small pickle-serialized entry meant to store scalars and hyperparameters
+         - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
+        """
+        chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
+        metadata, tensors = await self._get_current_state_from_host_process()
+
+        for tensor in tensors:
+            for part in split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes):
+                if metadata is not None:
+                    yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
+                    metadata = None
+                else:
+                    yield averaging_pb2.DownloadData(tensor_part=part)
+
+    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
+        """
+        Get current state and send it to a peer. Executed in the host process. Meant to be overridden.
+        :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
+        """
+        with self.get_tensors() as tensors:
+            return dict(group_key=self.get_group_bits()), tensors
+
+    async def _get_current_state_from_host_process(self):
+        """ Executed in the averager process inside rpc_download_state """
+        future, _future = MPFuture.make_pair()
+        self._pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
+        return await future
+
+    def _background_thread_fetch_current_state_if_asked(self):
+        """ Executed in the host process as a background thread. """
+        while True:
+            trigger, future = self.pipe.recv()
+            assert trigger == '_TRIGGER_GET_CURRENT_STATE'
+            try:
+                state_metadata, state_tensors = self.get_current_state()
+                # note: serialize here to avoid initializing cuda in the guest process
+                state_metadata = PickleSerializer.dumps(state_metadata)
+                state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
+                                      for tensor in state_tensors)
+                future.set_result((state_metadata, state_tensors))
+            except BaseException as e:
+                future.set_exception(e)
+                logger.warning(e)
+                continue
+
+    def load_state_from_peers(self, wait=True) -> Optional[Any]:
+        """ Try to download the latest optimizer state one of the existing peer """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_load_state_from_peers', [], dict(future=_future)))
+        return future.result() if wait else future
+
+    async def _load_state_from_peers(self, future: MPFuture):
+        key_manager = self._matchmaking.group_key_manager
+        peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
+        peer_priority = {peer: float(info.value) for peer, info in peer_priority.items()
+                         if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))}
+
+        if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
+            logger.info(f"Averager could not load state from peers: peer dict is absent or corrupted {peer_priority}.")
+            future.set_result(None)
+            return
+
+        metadata = None
+        for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
+            if peer != self.endpoint:
+                logger.info(f"Downloading parameters from peer {peer}")
+                stream = None
+                try:
+                    leader_stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+                    stream = leader_stub.rpc_download_state(averaging_pb2.DownloadRequest())
+                    current_tensor_parts, tensors = [], []
+                    async for message in stream:
+                        if message.metadata:
+                            metadata = PickleSerializer.loads(message.metadata)
+                        if message.tensor_part.dtype and current_tensor_parts:
+                            # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
+                            tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
+                            current_tensor_parts = []
+                        current_tensor_parts.append(message.tensor_part)
+                    if current_tensor_parts:
+                        tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
+                    future.set_result((metadata, tensors))
+                    self.last_updated = get_dht_time()
+                    return
+                except grpc.aio.AioRpcError as e:
+                    logger.info(f"Failed to download state from {peer} - {e}")
+                finally:
+                    if stream is not None:
+                        await stream.code()
+
+        else:
+            logger.warning("Averager could not load state from peers: found no active peers.")
+            future.set_result(None)
+
+    def get_group_bits(self, wait: bool = True):
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get_group_bits', [], dict(future=_future)))
+        return future.result() if wait else future
+
+    async def _get_group_bits(self, future: MPFuture):
+        future.set_result(self._matchmaking.group_key_manager.group_bits)
+
+    def set_group_bits(self, group_bits: str, wait: bool = True):
+        future, _future = MPFuture.make_pair()
+        assert all(bit in '01' for bit in group_bits)
+        self.pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=_future)))
+        return future.result() if wait else future
+
+    async def _set_group_bits(self, group_bits: str, future: MPFuture):
+        try:
+            self._matchmaking.group_key_manager.group_bits = group_bits
+            return future.set_result(None)
+        except Exception as e:
+            if not future.done():
+                future.set_exception(e)
+
 
 def is_power_of_two(n):
     """ Check whether n is a power of 2 """

+ 15 - 11
hivemind/client/averaging/key_manager.py

@@ -27,16 +27,17 @@ class GroupKeyManager:
 
     def __init__(self, dht: DHT, endpoint: Endpoint, prefix: str, initial_group_bits: Optional[str],
                  target_group_size: int, insufficient_size: Optional[int] = None, excessive_size: Optional[int] = None,
-                 nbits_expiration: float = 60):
+                 nbits_expiration: float = 60, nbits_rewrite_grace_period: float = 15):
         assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
         if initial_group_bits is None:
             search_result = dht.get(f"{prefix}.0b", latest=True)
-            initial_group_bits = self.get_suggested_nbits(search_result) or ''
+            initial_group_nbits = self.get_suggested_nbits(search_result) or 0
+            initial_group_bits = ''.join(random.choice('01') for _ in range(initial_group_nbits))
         self.dht, self.endpoint, self.prefix, self.group_bits = dht, endpoint, prefix, initial_group_bits
         self.target_group_size = target_group_size
         self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
         self.excessive_size = excessive_size or target_group_size * 3
-        self.nbits_expiration = nbits_expiration
+        self.nbits_expiration, self.nbits_grace_period = nbits_expiration, nbits_rewrite_grace_period
         self.suggested_nbits: Optional[int] = None
 
     @property
@@ -80,7 +81,8 @@ class GroupKeyManager:
         num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
 
         suggested_nbits = self.get_suggested_nbits(result)
-        if suggested_nbits is not None and suggested_nbits != self.suggested_nbits:
+        if suggested_nbits is not None and suggested_nbits != len(self.group_bits) and \
+                suggested_nbits != self.suggested_nbits:
             self.suggested_nbits = suggested_nbits
             logger.warning(f"{self.endpoint} - another averager suggested {self.suggested_nbits}-bit keys")
         elif num_active_averagers >= self.excessive_size:
@@ -108,11 +110,11 @@ class GroupKeyManager:
         generalized_index = rng.sample(range(self.target_group_size), allreduce_group.group_size)[index]
         nbits = int(np.ceil(np.log2(self.target_group_size)))
         new_bits = bin(generalized_index)[2:].rjust(nbits, '0')
-        self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits):]
+        self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits):] if self.group_bits else ''
         logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}")
 
         if is_leader and self.insufficient_size < allreduce_group.group_size < self.excessive_size:
-            asyncio.create_task(self.notify_stragglers_on_success())
+            asyncio.create_task(self.notify_stragglers())
         if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
             num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
             self.group_bits = ''.join((random.choice('01') for _ in range(num_extra_bits))) + self.group_bits
@@ -122,12 +124,12 @@ class GroupKeyManager:
     async def update_key_on_not_enough_peers(self):
         """ this function is triggered whenever averager fails to assemble group within timeout """
         new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
-        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:]
+        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ''
         if self.group_bits != prev_nbits:
             logger.warning(f'{self.endpoint} - switching to {len(self.group_bits)}-bit keys')
         self.suggested_nbits = None
 
-    async def notify_stragglers_on_success(self):
+    async def notify_stragglers(self):
         """ Find averagers that have fewer nbits and redirect them to your current nbits """
         for nbits in reversed(range(1, len(self.group_bits) - 1)):
             preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
@@ -137,6 +139,8 @@ class GroupKeyManager:
                 await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
                 break
 
-        root_data = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True)
-        if root_data is None or self.RESERVED_KEY_FOR_NBITS not in root_data.value:
-            await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
+        root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
+        if isinstance(root_data, dict) and root_data.get(
+                self.RESERVED_KEY_FOR_NBITS, (None, -float('inf')))[1] > get_dht_time() + self.nbits_grace_period:
+            return
+        await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)

+ 1 - 0
hivemind/dht/__init__.py

@@ -48,6 +48,7 @@ def is_valid_prefix(maybe_prefix: str) -> bool:
     """ An uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period """
     return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
 
+
 def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
     """ Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate """
     uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)

+ 2 - 2
hivemind/dht/node.py

@@ -75,7 +75,7 @@ class DHTNode:
     async def create(
             cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
             bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
-            wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
+            wait_timeout: float = 3, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
             cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
             cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
             blacklist_time: float = 5.0, backoff_rate: float = 2.0,
@@ -155,7 +155,7 @@ class DHTNode:
                     straggler.cancel()
                 finished_pings |= finished_in_time
 
-            if not finished_pings:
+            if not finished_pings or all(ping.result() is None for ping in finished_pings):
                 logger.warning("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
 
             if strict:

+ 2 - 1
hivemind/dht/protocol.py

@@ -110,7 +110,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
         if responded and validate:
             try:
                 if self.server is not None and not response.available:
-                    raise ValidationError(f"peer {peer} couldn't access this node at {response.sender_endpoint} .")
+                    raise ValidationError(f"Peer {peer} couldn't access this node at {response.sender_endpoint} . "
+                                          f"Make sure that this port is open for incoming requests.")
 
                 if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
                     if response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS or \

+ 8 - 0
hivemind/proto/averaging.proto

@@ -6,6 +6,7 @@ import "runtime.proto";
 service DecentralizedAveraging {
   rpc rpc_join_group(JoinRequest) returns (stream MessageFromLeader);  // assemble a group for allreduce
   rpc rpc_aggregate_part(stream AveragingData) returns (stream AveragingData);  // send local part => get average part
+  rpc rpc_download_state(DownloadRequest) returns (stream DownloadData);
 }
 
 enum MessageCode {
@@ -53,3 +54,10 @@ message AveragingData {
   string endpoint = 3;      // sender's rpc endpoint, used for coordination
   Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
 }
+
+message DownloadRequest {}
+
+message DownloadData {
+  bytes metadata = 1;
+  Tensor tensor_part = 2;
+}
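
The stream produced by rpc_download_state attaches the pickled metadata only to the first DownloadData chunk, and a non-empty tensor_part.dtype marks the first chunk of each new tensor. A client-side reassembly sketch that mirrors _load_state_from_peers above (error handling omitted):

```python
from hivemind.proto import averaging_pb2, averaging_pb2_grpc
from hivemind.utils.grpc import ChannelCache, combine_from_streaming, deserialize_torch_tensor
from hivemind.utils.serializer import PickleSerializer

async def download_state(peer_endpoint: str):
    """ Reassemble (metadata, tensors) from a single peer's DownloadData stream (sketch) """
    stub = ChannelCache.get_stub(peer_endpoint, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
    metadata, tensors, parts = None, [], []
    async for message in stub.rpc_download_state(averaging_pb2.DownloadRequest()):
        if message.metadata:
            metadata = PickleSerializer.loads(message.metadata)  # sent with the first chunk only
        if message.tensor_part.dtype and parts:
            # a set dtype marks the start of the next tensor: finalize the previous one
            tensors.append(deserialize_torch_tensor(combine_from_streaming(parts)))
            parts = []
        parts.append(message.tensor_part)
    if parts:
        tensors.append(deserialize_torch_tensor(combine_from_streaming(parts)))
    return metadata, tensors
```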

+ 2 - 1
hivemind/utils/mpfuture.py

@@ -161,4 +161,5 @@ class MPFuture(base.Future):
 
     def __del__(self):
         self._shutdown_trigger.set_result(True)
-        self.connection.close()
+        if hasattr(self, 'connection'):
+            self.connection.close()

+ 60 - 0
tests/test_averaging.py

@@ -261,3 +261,63 @@ def test_overcrowded():
     for t in range(5):
         step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
         assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1
+
+
+@pytest.mark.forked
+def test_load_state_from_peers():
+    num_calls = 0
+    super_metadata = dict(x=123)
+    super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))
+
+    class TestAverager(hivemind.DecentralizedAverager):
+        def get_current_state(self):
+            """
+            Get current state and send it to a peer. Executed in the host process. Meant to be overridden.
+            :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
+            """
+            nonlocal num_calls, super_metadata, super_tensors
+            num_calls += 1
+            return super_metadata, super_tensors
+
+    dht_root = hivemind.DHT(start=True)
+    initial_peers = [f'{hivemind.LOCALHOST}:{dht_root.port}']
+    dht1 = hivemind.DHT(initial_peers=initial_peers, start=True)
+    averager1 = TestAverager([torch.randn(3), torch.rand(5)],
+                             dht=dht1, start=True,
+                             prefix='demo-run', target_group_size=2)
+
+    dht2 = hivemind.DHT(initial_peers=initial_peers, start=True)
+    dht2.get('demo-run.all_averagers')
+    averager2 = TestAverager([torch.randn(3), torch.rand(5)],
+                             dht=dht2, start=True,
+                             prefix='demo-run', target_group_size=2)
+
+    assert num_calls == 0
+    got_metadata, got_tensors = averager2.load_state_from_peers()
+    assert num_calls == 1
+    assert got_metadata == super_metadata
+    assert all(map(torch.allclose, got_tensors, super_tensors))
+
+    super_metadata['y'] = 123
+    super_tensors[1][2] = 9
+    assert num_calls == 1
+    assert got_metadata != super_metadata
+    assert not all(map(torch.allclose, got_tensors, super_tensors))
+    got_metadata, got_tensors = averager2.load_state_from_peers()
+    assert num_calls == 2
+    assert got_metadata == super_metadata
+    assert all(map(torch.allclose, got_tensors, super_tensors))
+
+    # check that normal averaging still works
+    futures = [averager.step(wait=False) for averager in [averager1, averager2]]
+    for future in futures:
+        future.result()
+
+
+@pytest.mark.forked
+def test_getset_bits():
+    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
+    averager = hivemind.DecentralizedAverager([torch.randn(3)], dht=dht, start=True,
+                                              prefix='test_prefix', target_group_size=2)
+    averager.set_group_bits('00101011101010')
+    assert averager.get_group_bits() == '00101011101010'