Procházet zdrojové kódy

Support auxiliary participants in AllReduceProtocol (#260)

* DecentralizedAverager: support auxiliary peers that assist in allreduce without sending their own data
* implement a flag that disables state sharing in averager
* more natural parameterization of batch_size vs batch_size_per_step
* update test_allreduce_once for new aux peers
foksly před 4 roky
rodič
revize
e58f65db33

+ 55 - 13
hivemind/client/averaging/__init__.py

@@ -20,7 +20,7 @@ import torch
 import numpy as np
 
 from hivemind.dht import DHT, DHTID
-from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
+from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.client.averaging.group_info import GroupInfo
@@ -71,6 +71,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
     :param kwargs: extra parameters forwarded to grpc.aio.server
+    :param auxiliary: if this flag is specified, averager.step will only assist others without sending
+          local tensors for averaging
+    :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
+      with averager.allow_state_sharing = True / False
 
     Example:
 
@@ -94,6 +98,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                  allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
                  throughput: Optional[float] = None, min_vector_size: int = 0,
+                 auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
                  listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
@@ -102,10 +107,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
+        assert listen or not auxiliary, "auxiliary peers must accept incoming connections"
 
         super().__init__()
         self.dht = dht
         self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
+        if not self.listen:
+            self.mode = AveragingMode.CLIENT
+        elif auxiliary:
+            self.mode = AveragingMode.AUX
+        else:
+            self.mode = AveragingMode.NODE
+
         self.channel_options = channel_options
         self.daemon = daemon
 
@@ -129,6 +142,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
+
+        self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
+        self.allow_state_sharing = (listen and not auxiliary) if allow_state_sharing is None else allow_state_sharing
+
         self._averager_endpoint: Optional[Endpoint] = None
         if not self.listen:
             self._averager_endpoint = f'client::{uuid.uuid4()}'
@@ -146,6 +163,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     def port(self) -> Optional[Port]:
         return self._port.value if self._port.value != 0 else None
 
+    @property
+    def allow_state_sharing(self) -> bool:
+        """ if set to True, other peers can download this peer's state """
+        return bool(self._allow_state_sharing.value)
+
+    @allow_state_sharing.setter
+    def allow_state_sharing(self, value: bool):
+        if value is True and not self.listen:
+            logger.warning("Cannot allow state sharing: averager in client mode (listen=False) cannot share its state.")
+        else:
+            self._allow_state_sharing.value = value
+
     @property
     def endpoint(self) -> Optional[Endpoint]:
         if self.listen and self._averager_endpoint is None:
@@ -236,7 +265,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
-        assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
+        if self.mode == AveragingMode.AUX and weight != 1:
+            logger.warning("Averager is running in auxiliary mode, weight is unused.")
+        else:
+            assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
+
         future, _future = MPFuture.make_pair()
         gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
         self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
@@ -253,7 +286,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             while not future.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
+                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.mode.value, gather_binary]) 
                     group_info = await self._matchmaking.look_for_group(timeout=timeout,
                                                                         data_for_gather=data_for_gather)
                     if group_info is None:
@@ -263,7 +296,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     self._running_groups[group_id] = allreduce_runner
                     self._pending_group_assembled.set()
                     await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
-                    await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
+                    if self.mode != AveragingMode.AUX:
+                        await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
 
                     # averaging is finished, exit the loop
                     future.set_result(allreduce_runner.gathered)
@@ -293,19 +327,19 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
         """ Use a group description found by Matchmaking to form AllreduceRunner """
         try:
-            weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
-
             # compute optimal part sizes from peer throughputs
-            incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
+            modes = tuple(map(AveragingMode, mode_ids))
+            incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(throughputs, modes)]  # TODO: replace with proper load balancing
             part_sizes = await asyncio.get_event_loop().run_in_executor(
                 None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
             async with self.get_tensors_async() as averaged_tensors:
                 return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
                                        ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
-                                       weights=weights, gathered=user_gathered, return_deltas=True, **kwargs)
+                                       weights=weights, gathered=user_gathered, return_deltas=True, modes=modes, **kwargs)
         except Exception as e:
-            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}")
+            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {weights, throughputs, modes, user_gathered}")
 
     def update_tensors(self, allreduce_group: AllReduceRunner):
         """
@@ -366,10 +400,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     async def _declare_for_download_periodically(self):
         download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers'
         while True:
-            asyncio.create_task(asyncio.wait_for(self.dht.store(
-                download_key, subkey=self.endpoint, value=self.last_updated,
-                expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
-                timeout=self._matchmaking.averaging_expiration))
+            if self.allow_state_sharing:
+                asyncio.create_task(asyncio.wait_for(self.dht.store(
+                    download_key, subkey=self.endpoint, value=self.last_updated,
+                    expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
+                    timeout=self._matchmaking.averaging_expiration))
             await asyncio.sleep(self._matchmaking.averaging_expiration)
 
     async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
@@ -381,6 +416,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
          - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
          - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
         """
+        if not self.allow_state_sharing:
+            return  # deny request and direct peer to the next prospective averager
         chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
         metadata, tensors = await self._get_current_state_from_host_process()
 
@@ -452,6 +489,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                             current_tensor_parts.append(message.tensor_part)
                         if current_tensor_parts:
                             tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
+
+                        if not metadata:
+                            logger.debug(f"Peer {peer} did not send its state.")
+                            continue
+
                         logger.info(f"Finished downloading state from {peer}")
                         future.set_result((metadata, tensors))
                         self.last_updated = get_dht_time()

+ 40 - 16
hivemind/client/averaging/allreduce.py

@@ -1,5 +1,6 @@
 import asyncio
-from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any
+from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any, Optional
+from enum import Enum
 
 import grpc
 import torch
@@ -14,6 +15,12 @@ GroupID = bytes
 logger = get_logger(__name__)
 
 
+class AveragingMode(Enum):
+    NODE = 0
+    CLIENT = 1
+    AUX = 2
+
+
 class AllReduceProtocol:
     """
     An internal class that runs butterfly AllReduce in a predefined group of averagers
@@ -27,12 +34,16 @@ class AllReduceProtocol:
     """
 
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False):
+                 ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False,
+                 modes: Optional[Sequence[AveragingMode]] = None):
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
         self.group_id, self.endpoint = group_id, endpoint
         self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
-        self.client_mode_endpoints = {endpoint for endpoint, part_size in zip(self.ordered_group_endpoints, part_sizes)
-                                      if part_size == 0}
+        if modes is None:
+            modes = [AveragingMode.CLIENT if part_size == 0 else AveragingMode.NODE for part_size in part_sizes]
+        assert any(mode != AveragingMode.CLIENT for mode in modes), "Cannot run allreduce without reducers."
+        self.peer_modes = dict(zip(ordered_group_endpoints, modes))
+
         self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
         self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
         self.return_deltas = return_deltas
@@ -43,8 +54,14 @@ class AllReduceProtocol:
         self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
         self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
         self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
-        for endpoint in self.client_mode_endpoints:
-            self.averaged_tensor_parts[endpoint] = torch.tensor([])
+        
+        self.num_senders = len([mode for mode in modes if mode != AveragingMode.AUX])
+
+        if self.num_senders == 0:
+            self.future.set_result(None)
+        for endpoint, mode in self.peer_modes.items():
+            if mode == AveragingMode.CLIENT:
+                self.averaged_tensor_parts[endpoint] = torch.tensor([])
 
     def __repr__(self):
         return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
@@ -65,20 +82,24 @@ class AllReduceProtocol:
         assert not self.future.done(), f"already finished allreduce: {self.future}"
         assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
         assert source not in self.accumulated_from, "duplicate source, already received that part"
-        assert not self.endpoint in self.client_mode_endpoints, f"{self.endpoint} is in client mode"
+        assert self.peer_modes[self.endpoint] != AveragingMode.CLIENT, f"{self.endpoint} is in AveragingMode.client mode"
         assert isinstance(weight, (int, float)) and weight > 0, "averaging weights must be a non-negative int/float"
-        logger.debug(f"{self} - accumulating tensor part from {source}")
 
+        logger.debug(f"{self} - accumulating tensor part from {source}")
         self.accumulator.add_(remote_part, alpha=weight)
         self.denominator += weight
         self.accumulated_from.add(source)
 
-        assert len(self.accumulated_from) <= self.group_size
-        if len(self.accumulated_from) == len(self.local_tensor_parts):
+        assert len(self.accumulated_from) <= self.num_senders
+        if len(self.accumulated_from) == self.num_senders:
             average_result = self.accumulator.div_(self.denominator)
-            self.register_averaged_part(self.endpoint, average_result)
             self.averaged_part.set_result(average_result)
 
+            if self.peer_modes[self.endpoint] == AveragingMode.AUX:
+                self.future.set_result(None)  # auxiliary mode has finished averaging
+            else:
+                self.register_averaged_part(self.endpoint, average_result)
+
         return await self.averaged_part
 
     def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
@@ -87,6 +108,7 @@ class AllReduceProtocol:
         assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
         assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
         assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
+        assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers do not have local tensors for sending"
         logger.debug(f"{self} - receiving averaged tensor part from {source}")
         self.averaged_tensor_parts[source] = averaged_part
         if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
@@ -133,9 +155,9 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
                  ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
                  chunk_size_bytes: int, part_sizes: Tuple[int, ...], weights: Tuple[float, ...],
-                 gathered: Dict[Endpoint, Any], return_deltas: bool = False):
+                 gathered: Dict[Endpoint, Any], return_deltas: bool = False, **kwargs):
         super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
-                         ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
+                         ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas, **kwargs)
         self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
         self.peer_weights = dict(zip(self.ordered_group_endpoints, weights))
 
@@ -144,6 +166,7 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
 
     async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
         """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
+        assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers are disallowed from sending tensors"
         if peer_endpoint == self.endpoint:
             return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint])
         serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
@@ -182,9 +205,10 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
         send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
         """
         try:
-            await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
-                                         for i, peer in enumerate(self.ordered_group_endpoints)
-                                         if peer not in self.client_mode_endpoints))
+            if self.peer_modes[self.endpoint] != AveragingMode.AUX:
+                await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
+                                            for i, peer in enumerate(self.ordered_group_endpoints)
+                                            if self.peer_modes[peer] != AveragingMode.CLIENT))
             return await self
         except BaseException as e:
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR

+ 1 - 1
hivemind/client/averaging/matchmaking.py

@@ -391,7 +391,7 @@ class PotentialLeaders:
             if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
                 self.update_triggered.set()
 
-            if maybe_next_leader is None or entry.expiration_time >= self.declared_expiration_time:
+            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader) > (self.declared_expiration_time, self.endpoint):
                 await asyncio.wait({self.update_finished.wait(), self.declared_expiration.wait()},
                                    return_when=asyncio.FIRST_COMPLETED)
                 self.declared_expiration.clear()

+ 1 - 1
hivemind/optim/collaborative.py

@@ -191,7 +191,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_steps_accumulated += 1
-            self.performance_ema.update(num_processed=self.batch_size_per_step)
+            self.performance_ema.update(num_processed=batch_size)
             self.should_report_progress.set()
 
         if not self.collaboration_state.ready_for_step:

+ 36 - 15
tests/test_averaging.py

@@ -5,7 +5,7 @@ import numpy as np
 import torch
 import pytest
 import hivemind
-from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
+from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts, AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.key_manager import GroupKeyManager
 from hivemind.utils import Endpoint
@@ -41,26 +41,26 @@ async def test_key_manager():
     assert len(q5) == 0
 
 
-@pytest.mark.forked
-@pytest.mark.parametrize("n_client_mode_peers", [0, 2])
-def test_allreduce_once(n_client_mode_peers):
+def _test_allreduce_once(n_clients, n_aux):
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
 
     n_peers = 4
-    should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers)
-    random.shuffle(should_listen)
-
+    modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
+    random.shuffle(modes)
+    
     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
-
-    reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]
+    peer_tensors = [tensors1, tensors2, tensors3, tensors4]
+    
+    reference = [sum(tensors[i] for tensors, mode in zip(peer_tensors, modes)
+                 if mode != AveragingMode.AUX) / max(1, n_peers - n_aux) for i in range(len(tensors1))]
 
     averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
-                                                prefix='mygroup', listen=listen, listen_on='127.0.0.1:*',
-                                                start=True)
-                 for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
+                                                prefix='mygroup', listen=mode != AveragingMode.CLIENT, listen_on='127.0.0.1:*',
+                                                auxiliary=mode == AveragingMode.AUX, start=True)
+                 for tensors, mode in zip(peer_tensors, modes)]
 
     futures = []
     for averager in averagers:
@@ -71,15 +71,29 @@ def test_allreduce_once(n_client_mode_peers):
             assert averager.endpoint in result
 
     for averager in averagers:
-        with averager.get_tensors() as averaged_tensors:
-            for ref, our in zip(reference, averaged_tensors):
-                assert torch.allclose(ref, our, atol=1e-6)
+        if averager.mode != AveragingMode.AUX:
+            with averager.get_tensors() as averaged_tensors:
+                for ref, our in zip(reference, averaged_tensors):
+                    assert torch.allclose(ref, our, atol=1e-6)
 
     for averager in averagers:
         averager.shutdown()
     dht.shutdown()
 
 
+@pytest.mark.forked
+@pytest.mark.parametrize("n_clients", [0, 1, 2])
+@pytest.mark.parametrize("n_aux", [0, 1, 2])
+def test_allreduce_once(n_clients, n_aux):
+    _test_allreduce_once(n_clients, n_aux)
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize("n_clients, n_aux", [(0, 4), (1, 3), (0, 3)])
+def test_allreduce_once_edge_cases(n_clients, n_aux):
+    _test_allreduce_once(n_clients, n_aux)
+
+
 @pytest.mark.forked
 def test_allreduce_weighted(n_client_mode_peers: int = 2):
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
@@ -369,6 +383,13 @@ def test_load_state_from_peers():
     assert got_metadata == super_metadata
     assert all(map(torch.allclose, got_tensors, super_tensors))
 
+    averager1.allow_state_sharing = False
+    assert averager2.load_state_from_peers() is None
+    averager1.allow_state_sharing = True
+    got_metadata, got_tensors = averager2.load_state_from_peers()
+    assert num_calls == 3
+    assert got_metadata == super_metadata
+
 
 @pytest.mark.forked
 def test_getset_bits():