
Support auxiliary participants in AllReduceProtocol (#260)

* DecentralizedAverager: support auxiliary peers that assist in allreduce without sending their own data (see the usage sketch below)
* implement a flag that disables state sharing in averager
* more natural parameterization of batch_size vs batch_size_per_step
* update test_allreduce_once for new aux peers
foksly committed 4 years ago
commit e58f65db33
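
For context, a minimal sketch of how the new flags are meant to be used, following the constructor arguments and the test changes below (the DHT setup and tensor values are illustrative):

    import torch
    import hivemind

    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
    tensors = [torch.zeros(16)]

    # a regular peer: contributes its tensors and helps reduce
    node = hivemind.DecentralizedAverager(tensors, dht=dht, prefix='mygroup',
                                          target_group_size=2, start=True)

    # an auxiliary peer: must listen; assists in allreduce without sending its own data
    aux = hivemind.DecentralizedAverager(tensors, dht=dht, prefix='mygroup',
                                         target_group_size=2, auxiliary=True, start=True)

    node.step(wait=False)  # regular peers average their tensors
    aux.step()             # the auxiliary peer only relays and reduces parts for others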

+ 55 - 13
hivemind/client/averaging/__init__.py

@@ -20,7 +20,7 @@ import torch
 import numpy as np

 from hivemind.dht import DHT, DHTID
-from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
+from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.client.averaging.group_info import GroupInfo
@@ -71,6 +71,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
     :param kwargs: extra parameters forwarded to grpc.aio.server
+    :param auxiliary: if this flag is specified, averager.step will only assist others without sending
+          local tensors for averaging
+    :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
+      with averager.allow_state_sharing = True / False

     Example:

@@ -94,6 +98,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
                 compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
                 throughput: Optional[float] = None, min_vector_size: int = 0,
+                 auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
                 listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
                 channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
        assert '.' not in prefix, "group prefix must be a string without trailing '.'"
@@ -102,10 +107,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
        if not is_power_of_two(target_group_size):
            logger.warning("It is recommended to set target_group_size to a power of 2.")
        assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
+        assert listen or not auxiliary, "auxiliary peers must accept incoming connections"

        super().__init__()
        self.dht = dht
        self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
+        if not self.listen:
+            self.mode = AveragingMode.CLIENT
+        elif auxiliary:
+            self.mode = AveragingMode.AUX
+        else:
+            self.mode = AveragingMode.NODE
+
        self.channel_options = channel_options
        self.daemon = daemon

@@ -129,6 +142,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin

        self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
        self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
+
+        self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
+        self.allow_state_sharing = (listen and not auxiliary) if allow_state_sharing is None else allow_state_sharing
+
        self._averager_endpoint: Optional[Endpoint] = None
        if not self.listen:
            self._averager_endpoint = f'client::{uuid.uuid4()}'
@@ -146,6 +163,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
    def port(self) -> Optional[Port]:
        return self._port.value if self._port.value != 0 else None

+    @property
+    def allow_state_sharing(self) -> bool:
+        """ if set to True, other peers can download this peer's state """
+        return bool(self._allow_state_sharing.value)
+
+    @allow_state_sharing.setter
+    def allow_state_sharing(self, value: bool):
+        if value is True and not self.listen:
+            logger.warning("Cannot allow state sharing: averager in client mode (listen=False) cannot share its state.")
+        else:
+            self._allow_state_sharing.value = value
+
    @property
    def endpoint(self) -> Optional[Endpoint]:
        if self.listen and self._averager_endpoint is None:
@@ -236,7 +265,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
        :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
        :returns: on success, update averaged_tensors and return group info; on failure, return None
        """
-        assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
+        if self.mode == AveragingMode.AUX and weight != 1:
+            logger.warning("Averager is running in auxiliary mode, weight is unused.")
+        else:
+            assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
+
        future, _future = MPFuture.make_pair()
        gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
        self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
@@ -253,7 +286,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
            while not future.done():
                try:
                    self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
+                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.mode.value, gather_binary])
                    group_info = await self._matchmaking.look_for_group(timeout=timeout,
                                                                        data_for_gather=data_for_gather)
                    if group_info is None:
@@ -263,7 +296,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                    self._running_groups[group_id] = allreduce_runner
                    self._pending_group_assembled.set()
                    await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
-                    await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
+                    if self.mode != AveragingMode.AUX:
+                        await loop.run_in_executor(None, self.update_tensors, allreduce_runner)

                    # averaging is finished, exit the loop
                    future.set_result(allreduce_runner.gathered)
@@ -293,19 +327,19 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
    async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
        """ Use a group description found by Matchmaking to form AllreduceRunner """
        try:
-            weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
            user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
-
            # compute optimal part sizes from peer throughputs
-            incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
+            modes = tuple(map(AveragingMode, mode_ids))
+            incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(throughputs, modes)]  # TODO: replace with proper load balancing
            part_sizes = await asyncio.get_event_loop().run_in_executor(
                None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
            async with self.get_tensors_async() as averaged_tensors:
                return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
                                       ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
-                                       weights=weights, gathered=user_gathered, return_deltas=True, **kwargs)
+                                       weights=weights, gathered=user_gathered, return_deltas=True, modes=modes, **kwargs)
        except Exception as e:
-            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}")
+            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {weights, throughputs, modes, user_gathered}")

    def update_tensors(self, allreduce_group: AllReduceRunner):
        """
@@ -366,10 +400,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
    async def _declare_for_download_periodically(self):
        download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers'
        while True:
-            asyncio.create_task(asyncio.wait_for(self.dht.store(
-                download_key, subkey=self.endpoint, value=self.last_updated,
-                expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
-                timeout=self._matchmaking.averaging_expiration))
+            if self.allow_state_sharing:
+                asyncio.create_task(asyncio.wait_for(self.dht.store(
+                    download_key, subkey=self.endpoint, value=self.last_updated,
+                    expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
+                    timeout=self._matchmaking.averaging_expiration))
            await asyncio.sleep(self._matchmaking.averaging_expiration)

    async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
@@ -381,6 +416,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
         - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
        """
+        if not self.allow_state_sharing:
+            return  # deny request and direct peer to the next prospective averager
        chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
        metadata, tensors = await self._get_current_state_from_host_process()

@@ -452,6 +489,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                            current_tensor_parts.append(message.tensor_part)
                        if current_tensor_parts:
                            tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
+
+                        if not metadata:
+                            logger.debug(f"Peer {peer} did not send its state.")
+                            continue
+
                        logger.info(f"Finished downloading state from {peer}")
                        future.set_result((metadata, tensors))
                        self.last_updated = get_dht_time()
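
The allow_state_sharing flag introduced above can be flipped at runtime through the property; a minimal sketch of the intended usage (the variable name `averager` is illustrative):

    averager.allow_state_sharing = False  # rpc_download_state now denies requests, and the
                                          # peer stops advertising itself under the download key
    averager.allow_state_sharing = True   # sharing is re-enabled on the next declare cycle

    # note: in client mode (listen=False), the setter only logs a warning and the flag stays False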

+ 40 - 16
hivemind/client/averaging/allreduce.py

@@ -1,5 +1,6 @@
 import asyncio
-from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any
+from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any, Optional
+from enum import Enum

 import grpc
 import torch
@@ -14,6 +15,12 @@ GroupID = bytes
 logger = get_logger(__name__)


+class AveragingMode(Enum):
+    NODE = 0
+    CLIENT = 1
+    AUX = 2
+
+
 class AllReduceProtocol:
     """
     An internal class that runs butterfly AllReduce in a predefined group of averagers
@@ -27,12 +34,16 @@ class AllReduceProtocol:
     """
     """
 
 
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False):
+                 ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False,
+                 modes: Optional[Sequence[AveragingMode]] = None):
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
         self.group_id, self.endpoint = group_id, endpoint
         self.group_id, self.endpoint = group_id, endpoint
         self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
         self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
-        self.client_mode_endpoints = {endpoint for endpoint, part_size in zip(self.ordered_group_endpoints, part_sizes)
-                                      if part_size == 0}
+        if modes is None:
+            modes = [AveragingMode.CLIENT if part_size == 0 else AveragingMode.NODE for part_size in part_sizes]
+        assert any(mode != AveragingMode.CLIENT for mode in modes), "Cannot run allreduce without reducers."
+        self.peer_modes = dict(zip(ordered_group_endpoints, modes))
+
         self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
         self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
         self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
         self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
         self.return_deltas = return_deltas
         self.return_deltas = return_deltas
@@ -43,8 +54,14 @@ class AllReduceProtocol:
        self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
        self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
        self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
-        for endpoint in self.client_mode_endpoints:
-            self.averaged_tensor_parts[endpoint] = torch.tensor([])
+
+        self.num_senders = len([mode for mode in modes if mode != AveragingMode.AUX])
+
+        if self.num_senders == 0:
+            self.future.set_result(None)
+        for endpoint, mode in self.peer_modes.items():
+            if mode == AveragingMode.CLIENT:
+                self.averaged_tensor_parts[endpoint] = torch.tensor([])

    def __repr__(self):
        return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
@@ -65,20 +82,24 @@ class AllReduceProtocol:
        assert not self.future.done(), f"already finished allreduce: {self.future}"
        assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
        assert source not in self.accumulated_from, "duplicate source, already received that part"
-        assert not self.endpoint in self.client_mode_endpoints, f"{self.endpoint} is in client mode"
+        assert self.peer_modes[self.endpoint] != AveragingMode.CLIENT, f"{self.endpoint} is in AveragingMode.client mode"
        assert isinstance(weight, (int, float)) and weight > 0, "averaging weights must be a non-negative int/float"
-        logger.debug(f"{self} - accumulating tensor part from {source}")

+        logger.debug(f"{self} - accumulating tensor part from {source}")
        self.accumulator.add_(remote_part, alpha=weight)
        self.denominator += weight
        self.accumulated_from.add(source)

-        assert len(self.accumulated_from) <= self.group_size
-        if len(self.accumulated_from) == len(self.local_tensor_parts):
+        assert len(self.accumulated_from) <= self.num_senders
+        if len(self.accumulated_from) == self.num_senders:
            average_result = self.accumulator.div_(self.denominator)
-            self.register_averaged_part(self.endpoint, average_result)
            self.averaged_part.set_result(average_result)

+            if self.peer_modes[self.endpoint] == AveragingMode.AUX:
+                self.future.set_result(None)  # auxiliary mode has finished averaging
+            else:
+                self.register_averaged_part(self.endpoint, average_result)
+
        return await self.averaged_part

    def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
@@ -87,6 +108,7 @@ class AllReduceProtocol:
        assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
        assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
        assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
+        assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers do not have local tensors for sending"
        logger.debug(f"{self} - receiving averaged tensor part from {source}")
        self.averaged_tensor_parts[source] = averaged_part
        if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
@@ -133,9 +155,9 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
                 ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
                 chunk_size_bytes: int, part_sizes: Tuple[int, ...], weights: Tuple[float, ...],
-                 gathered: Dict[Endpoint, Any], return_deltas: bool = False):
+                 gathered: Dict[Endpoint, Any], return_deltas: bool = False, **kwargs):
        super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
-                         ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
+                         ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas, **kwargs)
        self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
        self.peer_weights = dict(zip(self.ordered_group_endpoints, weights))

@@ -144,6 +166,7 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi

    async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
+        assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers are disallowed from sending tensors"
        if peer_endpoint == self.endpoint:
            return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint])
        serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
@@ -182,9 +205,10 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
        send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
        """
        try:
-            await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
-                                         for i, peer in enumerate(self.ordered_group_endpoints)
-                                         if peer not in self.client_mode_endpoints))
+            if self.peer_modes[self.endpoint] != AveragingMode.AUX:
+                await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
+                                            for i, peer in enumerate(self.ordered_group_endpoints)
+                                            if self.peer_modes[peer] != AveragingMode.CLIENT))
            return await self
        except BaseException as e:
            code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
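
In short, AveragingMode settles two separate roles per peer: who contributes data (senders: every mode except AUX) and who accumulates parts (reducers: every mode except CLIENT). A standalone sketch of that bookkeeping, mirroring the constructor logic above (the example group layout is made up):

    from hivemind.client.averaging.allreduce import AveragingMode

    modes = [AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX]
    # senders contribute tensor parts: everyone except auxiliary peers
    num_senders = sum(mode != AveragingMode.AUX for mode in modes)  # -> 2
    # reducers receive and accumulate parts: everyone except client-mode peers
    reducers = [i for i, mode in enumerate(modes) if mode != AveragingMode.CLIENT]  # -> [0, 2]
    # a reducer's part is ready once all senders have checked in, which is why
    # accumulate_part now compares against num_senders rather than group_size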

+ 1 - 1
hivemind/client/averaging/matchmaking.py

@@ -391,7 +391,7 @@ class PotentialLeaders:
            if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
                self.update_triggered.set()

-            if maybe_next_leader is None or entry.expiration_time >= self.declared_expiration_time:
+            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader) > (self.declared_expiration_time, self.endpoint):
                await asyncio.wait({self.update_finished.wait(), self.declared_expiration.wait()},
                                   return_when=asyncio.FIRST_COMPLETED)
                self.declared_expiration.clear()
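
This one-line matchmaking change replaces a plain timestamp comparison with a lexicographic tuple comparison, so ties in expiration time are broken deterministically by endpoint instead of letting two peers wait on each other. A toy illustration (endpoint names and times are made up):

    declared = (10.0, 'peer_a')   # (self.declared_expiration_time, self.endpoint)
    candidate = (10.0, 'peer_b')  # (entry.expiration_time, maybe_next_leader)
    # old check: 10.0 >= 10.0 held for both peers, so both could keep waiting
    # new check: the endpoint breaks the tie, so exactly one peer defers
    assert candidate > declared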

+ 1 - 1
hivemind/optim/collaborative.py

@@ -191,7 +191,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
        with self.lock_local_progress:
            self.local_samples_accumulated += batch_size
            self.local_steps_accumulated += 1
-            self.performance_ema.update(num_processed=self.batch_size_per_step)
+            self.performance_ema.update(num_processed=batch_size)
            self.should_report_progress.set()

        if not self.collaboration_state.ready_for_step:
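
The optimizer fix makes the performance EMA count the samples actually processed in this step rather than the configured target, which matters whenever a step carries fewer samples than batch_size_per_step. A sketch of the accounting difference (PerformanceEMA internals are not shown in this commit and are assumed here):

    batch_size_per_step = 256   # configured target
    batch_size = 64             # what this particular step actually processed
    # before: performance_ema.update(num_processed=256) -> overestimates throughput
    # after:  performance_ema.update(num_processed=64)  -> tracks real progress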

+ 36 - 15
tests/test_averaging.py

@@ -5,7 +5,7 @@ import numpy as np
 import torch
 import pytest
 import hivemind
-from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
+from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts, AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.key_manager import GroupKeyManager
 from hivemind.utils import Endpoint
@@ -41,26 +41,26 @@ async def test_key_manager():
     assert len(q5) == 0


-@pytest.mark.forked
-@pytest.mark.parametrize("n_client_mode_peers", [0, 2])
-def test_allreduce_once(n_client_mode_peers):
+def _test_allreduce_once(n_clients, n_aux):
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')

     n_peers = 4
-    should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers)
-    random.shuffle(should_listen)
-
+    modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
+    random.shuffle(modes)
+
     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
-
-    reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]
+    peer_tensors = [tensors1, tensors2, tensors3, tensors4]
+
+    reference = [sum(tensors[i] for tensors, mode in zip(peer_tensors, modes)
+                 if mode != AveragingMode.AUX) / max(1, n_peers - n_aux) for i in range(len(tensors1))]

     averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
-                                                prefix='mygroup', listen=listen, listen_on='127.0.0.1:*',
-                                                start=True)
-                 for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
+                                                prefix='mygroup', listen=mode != AveragingMode.CLIENT, listen_on='127.0.0.1:*',
+                                                auxiliary=mode == AveragingMode.AUX, start=True)
+                 for tensors, mode in zip(peer_tensors, modes)]

     futures = []
     for averager in averagers:
@@ -71,15 +71,29 @@ def test_allreduce_once(n_client_mode_peers):
             assert averager.endpoint in result

     for averager in averagers:
-        with averager.get_tensors() as averaged_tensors:
-            for ref, our in zip(reference, averaged_tensors):
-                assert torch.allclose(ref, our, atol=1e-6)
+        if averager.mode != AveragingMode.AUX:
+            with averager.get_tensors() as averaged_tensors:
+                for ref, our in zip(reference, averaged_tensors):
+                    assert torch.allclose(ref, our, atol=1e-6)

     for averager in averagers:
         averager.shutdown()
     dht.shutdown()


+@pytest.mark.forked
+@pytest.mark.parametrize("n_clients", [0, 1, 2])
+@pytest.mark.parametrize("n_aux", [0, 1, 2])
+def test_allreduce_once(n_clients, n_aux):
+    _test_allreduce_once(n_clients, n_aux)
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize("n_clients, n_aux", [(0, 4), (1, 3), (0, 3)])
+def test_allreduce_once_edge_cases(n_clients, n_aux):
+    _test_allreduce_once(n_clients, n_aux)
+
+
 @pytest.mark.forked
 def test_allreduce_weighted(n_client_mode_peers: int = 2):
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
@@ -369,6 +383,13 @@ def test_load_state_from_peers():
     assert got_metadata == super_metadata
     assert all(map(torch.allclose, got_tensors, super_tensors))

+    averager1.allow_state_sharing = False
+    assert averager2.load_state_from_peers() is None
+    averager1.allow_state_sharing = True
+    got_metadata, got_tensors = averager2.load_state_from_peers()
+    assert num_calls == 3
+    assert got_metadata == super_metadata
+

 @pytest.mark.forked
 def test_getset_bits():
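
As the updated test encodes, auxiliary peers are excluded from the average itself: with n_peers = 4 and n_aux = 1, the reference divides by the 3 sending peers. A worked check with toy tensors (values are made up):

    import torch

    sent = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([4.0])]  # 3 senders
    # the fourth, auxiliary peer contributes no data, so the reference is sum / 3
    assert torch.allclose(sum(sent) / 3, torch.tensor([7.0 / 3.0]))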