justheuristic 3 years ago
parent
commit
ba05129992
1 file changed with 28 additions and 26 deletions

+ 28 - 26
hivemind/averaging/averager.py

@@ -32,7 +32,7 @@ from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop
+from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop, azip
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -344,26 +344,33 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             scheduled_time = get_dht_time() + self.matchmaking_kwargs["averaging_expiration"]
         if weight is None:
             weight = float(self.mode != AveragingMode.AUX)
+        deadline = get_dht_time() + timeout if timeout is not None else float("inf")
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
         assert not (wait_for_trigger and wait), "Non-asynchronous step cannot wait for trigger (use wait=False)"
-        gather_binary = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
+        assert scheduled_time < deadline, "Scheduled start time does not fit within timeout"
 
-        step = StepControl(scheduled_time, weight, wait_for_trigger=wait_for_trigger,
-                           gather_binary=gather_binary, timeout=timeout, allow_retries=allow_retries)
-        self._outer_pipe.send(("_step", [], dict(step=step)))
+        user_gather_bytes = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
+        gather_binary = self.serializer.dumps([self.bandwidth, self.mode.value, user_gather_bytes])
+        step = StepControl(scheduled_time=scheduled_time, deadline=deadline, allow_retries=allow_retries,
+                           weight=weight, gather_binary=gather_binary)
+
+        future_for_trigger = MPFuture()
+        self._outer_pipe.send(("_step", [], dict(step=step, future_for_trigger=future_for_trigger)))
+        step.attach_trigger(future_for_trigger.result())
+
+        if not wait_for_trigger:
+            step.allow_allreduce()
         return step.result() if wait else step
 
-    async def _step(
-        self, *, step: StepControl, gather_binary: bytes, allow_retries: bool, timeout: Optional[float]
-    ):
-        start_time = get_dht_time()
+    async def _step(self, *, step: StepControl, future_for_trigger: MPFuture):
+        trigger = MPFuture()
+        step.attach_trigger(trigger)
+        future_for_trigger.set_result(trigger)
 
         try:
             while not step.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, gather_binary])
-
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
                     group_info = await self._matchmaking.look_for_group(step)
                     if group_info is None:
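
For context, the reworked two-phase step() can be driven from the caller roughly as follows. This is a minimal sketch: only step(), wait_for_trigger, allow_allreduce() and result() appear in this diff; the averager object itself is assumed.

    # caller-side sketch (hypothetical `averager`; method names taken from the diff above)
    control = averager.step(wait_for_trigger=True, wait=False)  # returns a StepControl
    # ... do other work while matchmaking proceeds in the background ...
    control.allow_allreduce()    # release the trigger attached inside step()
    gathered = control.result()  # block until averaging succeeds or fails
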
@@ -394,12 +401,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.InvalidStateError,
                     P2PHandlerError,
                 ) as e:
-                    time_elapsed = get_dht_time() - start_time
-                    if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                        logger.exception(f"Averager caught {repr(e)}")
+                    if not step.allow_retries or get_dht_time() >= step.deadline:
+                        logger.exception(f"{self.__class__.__name__} caught {repr(e)}")
                         step.set_exception(e)
                     else:
-                        logger.warning(f"Averager caught {repr(e)}, retrying")
+                        logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")
 
         except BaseException as e:
             if not step.done():
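
The retry branch now checks an absolute deadline stored on the StepControl instead of recomputing elapsed time from a local start_time. The same pattern in isolation (a sketch: only get_dht_time comes from hivemind, the helper is illustrative):

    from hivemind.utils.timed_storage import get_dht_time

    def call_with_retries(attempt, deadline, allow_retries=True):
        """Keep calling `attempt` until it succeeds or the absolute deadline passes."""
        while True:
            try:
                return attempt()
            except Exception:
                if not allow_retries or get_dht_time() >= deadline:
                    raise  # out of time or retries disabled: surface the error
                # otherwise loop and retry against the same absolute deadline
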
@@ -418,8 +424,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
             modes = tuple(map(AveragingMode, mode_ids))
 
             # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
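
The gathered metadata is now packed in two layers: step() wraps the user payload together with bandwidth and mode id, and _run_allreduce() unwraps one entry per peer. A round-trip sketch (MSGPackSerializer is the serializer this module imports; the payload values are made up):

    from hivemind.utils.serializer import MSGPackSerializer as serializer

    # sender side, as in step(): wrap user metadata with bandwidth and mode id
    user_gather_bytes = serializer.dumps({"epoch": 7})
    gather_binary = serializer.dumps([1.0e8, 0, user_gather_bytes])

    # receiver side, as in _run_allreduce(): one gather_binary entry per peer
    bandwidths, mode_ids, user_gathered_bytes = zip(*map(serializer.loads, [gather_binary]))
    user_metadata = list(map(serializer.loads, user_gathered_bytes))  # [{'epoch': 7}]
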
@@ -445,15 +451,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 )
 
                 with self.register_allreduce_group(group_info.group_id, allreduce):
-
-                    # actually run all-reduce
-                    averaging_outputs = [output async for output in allreduce]
-
-                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        assert len(local_tensors) == len(self._averaged_tensors)
-                        for tensor, update in zip(local_tensors, averaging_outputs):
-                            tensor.add_(update, alpha=self._averaging_alpha)
-                        self.last_updated = get_dht_time()
+                    assert len(local_tensors) == len(self._averaged_tensors)
+                    async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
+                        # note: all-reduce is performed asynchronously when iterating
+                        tensor.add_(update, alpha=self._averaging_alpha)
+                        self.last_updated = get_dht_time()
 
                 return allreduce.gathered
         except BaseException as e:
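
Finally, the azip(as_aiter(*local_tensors), allreduce) loop applies each averaged tensor as soon as it is produced instead of collecting all outputs into a list first. A self-contained toy showing the same iteration pattern (the two helpers below are minimal stand-ins, not hivemind's implementations):

    import asyncio

    async def as_aiter(*items):                # minimal stand-in for hivemind's as_aiter
        for item in items:
            yield item

    async def azip(first, second):             # minimal stand-in for hivemind's azip
        second_it = second.__aiter__()
        async for a in first:
            try:
                b = await second_it.__anext__()
            except StopAsyncIteration:
                return
            yield a, b

    async def updates():                       # pretend these arrive one by one from all-reduce
        for value in (0.5, 0.25, 0.125):
            await asyncio.sleep(0)
            yield value

    async def main():
        tensors = [1.0, 2.0, 3.0]              # stand-ins for local tensors
        async for index, update in azip(as_aiter(0, 1, 2), updates()):
            tensors[index] += update           # applied as soon as each update is ready
        print(tensors)                         # [1.5, 2.25, 3.125]

    asyncio.run(main())
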