|
@@ -32,7 +32,7 @@ from hivemind.dht import DHT, DHTID
|
|
from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
|
|
from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
|
|
from hivemind.proto import averaging_pb2
|
|
from hivemind.proto import averaging_pb2
|
|
from hivemind.utils import MPFuture, TensorDescriptor, get_logger
|
|
from hivemind.utils import MPFuture, TensorDescriptor, get_logger
|
|
-from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop, azip
|
|
|
|
|
|
+from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
|
|
from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
|
|
from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
|
|
from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
|
|
from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
|
|
from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
|
|
from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
|
|
@@ -451,11 +451,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
)
|
|
)
|
|
|
|
|
|
with self.register_allreduce_group(group_info.group_id, allreduce):
|
|
with self.register_allreduce_group(group_info.group_id, allreduce):
|
|
- assert len(local_tensors) == len(self._averaged_tensors)
|
|
|
|
- async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
|
|
|
|
- # note: all-reduce is performed asynchronously when iterating
|
|
|
|
- tensor.add_(update, alpha=self._averaging_alpha)
|
|
|
|
- self.last_updateod = get_dht_time()
|
|
|
|
|
|
+ if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
|
|
|
|
+ async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
|
|
|
|
+ # all-reduce is performed asynchronously while iterating
|
|
|
|
+ tensor.add_(update, alpha=self._averaging_alpha)
|
|
|
|
+ self.last_updated = get_dht_time()
|
|
|
|
+ else:
|
|
|
|
+ async for _ in allreduce: # trigger all-reduce by iterating
|
|
|
|
+ raise ValueError("aux peers should not receive averaged tensors")
|
|
|
|
|
|
return allreduce.gathered
|
|
return allreduce.gathered
|
|
except BaseException as e:
|
|
except BaseException as e:
|