@@ -31,7 +31,7 @@ from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop
+from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
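For reference, `azip` (the newly imported helper above) is an async analogue of the built-in `zip`, and `as_aiter` wraps plain positional arguments in an async iterator. A minimal sketch of the semantics these helpers are expected to have — not the actual `hivemind.utils.asyncio` implementations:

```python
from typing import AsyncIterator, TypeVar

T = TypeVar("T")

async def as_aiter(*args: T) -> AsyncIterator[T]:
    # Yield the given positional arguments one by one as an async iterator.
    for arg in args:
        yield arg

async def azip(*iterables: AsyncIterator) -> AsyncIterator[tuple]:
    # Async analogue of zip(): yield tuples until any iterator is exhausted.
    iterators = [iterable.__aiter__() for iterable in iterables]
    while True:
        items = []
        for iterator in iterators:
            try:
                items.append(await iterator.__anext__())
            except StopAsyncIteration:
                return
        yield tuple(items)
```

The sketch calls the `__aiter__`/`__anext__` dunders directly rather than the `aiter`/`anext` builtins, which only exist in Python 3.10+ (hivemind ships its own `anext` shim, imported above, for the same reason).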
@@ -443,15 +443,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             )
 
             with self.register_allreduce_group(group_info.group_id, allreduce):
-
-                # actually run all-reduce
-                averaging_outputs = [output async for output in allreduce]
-
                 if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                    assert len(local_tensors) == len(self._averaged_tensors)
-                    for tensor, update in zip(local_tensors, averaging_outputs):
+                    async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
+                        # all-reduce is performed asynchronously while iterating
                         tensor.add_(update, alpha=self._averaging_alpha)
                         self.last_updated = get_dht_time()
+                else:
+                    async for _ in allreduce:  # trigger all-reduce by iterating
+                        raise ValueError("aux peers should not receive averaged tensors")
 
             return allreduce.gathered
         except BaseException as e:
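The old code materialized every averaged tensor into `averaging_outputs` before applying any update; the new loop consumes `allreduce` as a stream, adding each update in place as soon as its portion of the all-reduce completes, so communication overlaps with the in-place updates and only one output needs to be held at a time. Aux peers, which contribute bandwidth but average no tensors of their own, still iterate over `allreduce` to drive the protocol; since they should never be sent outputs, any yielded value now raises `ValueError`.

A toy illustration of the pattern in plain asyncio — `produce_updates` and `alpha=0.5` are illustrative stand-ins, not hivemind APIs:

```python
import asyncio
import torch

async def produce_updates(tensors):
    # Hypothetical stand-in for iterating over an AllReduceRunner:
    # yields one averaged update per local tensor as it becomes ready.
    for tensor in tensors:
        await asyncio.sleep(0.01)  # pretend a chunk of the all-reduce just finished
        yield torch.zeros_like(tensor)

async def main():
    local_tensors = [torch.randn(1024) for _ in range(4)]

    # Old pattern: materialize all outputs before applying any of them.
    # updates = [update async for update in produce_updates(local_tensors)]

    # New pattern: apply each update the moment it arrives. Peak memory is a
    # single update instead of all of them, and the in-place add_() overlaps
    # with whatever communication produces the next update.
    index = 0
    async for update in produce_updates(local_tensors):
        local_tensors[index].add_(update, alpha=0.5)
        index += 1

asyncio.run(main())
```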