
Apply averager updates asynchronously (#395)

The baseline averager would first compute all updates and hold them in memory, then apply them to self.averaged_tensors.

The new version applies updates immediately upon receiving them.

This does not affect the rest of the code because asynchronous updates can only be received after you've sent the corresponding part to the averager.

This saves ~400 MB of RAM in sahajbert2.
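For illustration, the two behaviours can be contrasted as follows. This is a simplified sketch with assumed names (`allreduce`, `local_tensors`, `alpha`), not the exact hivemind code; `allreduce` is assumed to be an async iterator yielding one averaged update per local tensor, in order.

```python
async def apply_updates_buffered(allreduce, local_tensors, alpha):
    # Old behaviour: collect every update first, then apply them.
    averaging_outputs = [output async for output in allreduce]  # all updates held in memory at once
    for tensor, update in zip(local_tensors, averaging_outputs):
        tensor.add_(update, alpha=alpha)

async def apply_updates_streaming(allreduce, local_tensors, alpha):
    # New behaviour: apply each update as soon as it is received,
    # so at most one update is held in memory at a time.
    tensors = iter(local_tensors)
    async for update in allreduce:
        next(tensors).add_(update, alpha=alpha)
```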

Co-authored-by: Aleksandr Borzunov <hxrussia@gmail.com>
justheuristic 3 years ago
parent
commit
54cdc3925a
1 changed file with 6 additions and 7 deletions
  1. hivemind/averaging/averager.py  +6 −7

hivemind/averaging/averager.py  +6 −7

@@ -31,7 +31,7 @@ from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop
+from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -443,15 +443,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 )
 
                 with self.register_allreduce_group(group_info.group_id, allreduce):
-
-                    # actually run all-reduce
-                    averaging_outputs = [output async for output in allreduce]
-
                     if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        assert len(local_tensors) == len(self._averaged_tensors)
-                        for tensor, update in zip(local_tensors, averaging_outputs):
+                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
+                            # all-reduce is performed asynchronously while iterating
                             tensor.add_(update, alpha=self._averaging_alpha)
                         self.last_updated = get_dht_time()
+                    else:
+                        async for _ in allreduce:  # trigger all-reduce by iterating
+                            raise ValueError("aux peers should not receive averaged tensors")
 
                 return allreduce.gathered
         except BaseException as e:
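The new code relies on `as_aiter` and `azip` from `hivemind.utils.asyncio`. Roughly, and only as an assumption about their behaviour rather than the actual implementations, they can be thought of as:

```python
from typing import AsyncIterator, TypeVar

T = TypeVar("T")

async def as_aiter(*items: T) -> AsyncIterator[T]:
    # Wrap plain values into an async iterator.
    for item in items:
        yield item

async def azip(*iterables: AsyncIterator) -> AsyncIterator[tuple]:
    # Zip async iterators, yielding a tuple as soon as every iterator
    # has produced its next item; stop when any of them is exhausted.
    iterators = [iterable.__aiter__() for iterable in iterables]
    while True:
        try:
            yield tuple([await iterator.__anext__() for iterator in iterators])
        except StopAsyncIteration:
            return
```

With such helpers, `async for tensor, update in azip(as_aiter(*local_tensors), allreduce)` pairs each local tensor with its averaged update as the all-reduce stream produces it, instead of buffering the full list of outputs first.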