Просмотр исходного кода

Merge remote-tracking branch 'origin/master' into pre_scheduling_again

# Conflicts:
#	hivemind/averaging/averager.py
justheuristic 3 лет назад
Родитель
Commit
36e5b02a0d
1 измененный файл: 9 добавлений и 6 удалений
  1. 9 6
      hivemind/averaging/averager.py

+ 9 - 6
hivemind/averaging/averager.py

@@ -32,7 +32,7 @@ from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop, azip
+from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -451,11 +451,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 )
 
                 with self.register_allreduce_group(group_info.group_id, allreduce):
-                    assert len(local_tensors) == len(self._averaged_tensors)
-                    async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
-                        # note: all-reduce is performed asynchronously when iterating
-                        tensor.add_(update, alpha=self._averaging_alpha)
-                        self.last_updateod = get_dht_time()
+                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
+                            # all-reduce is performed asynchronously while iterating
+                            tensor.add_(update, alpha=self._averaging_alpha)
+                        self.last_updated = get_dht_time()
+                    else:
+                        async for _ in allreduce:  # trigger all-reduce by iterating
+                            raise ValueError("aux peers should not receive averaged tensors")
 
                 return allreduce.gathered
         except BaseException as e: