Przeglądaj źródła

Merge remote-tracking branch 'origin/master' into hivemind_optimizer_thirdtimesthecharm

justheuristic 3 lat temu
rodzic
commit
f69b32a775

+ 23 - 11
hivemind/averaging/averager.py

@@ -377,25 +377,35 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             data_for_gather=data_for_gather,
         )
 
-        future_for_trigger = MPFuture()
-        self._outer_pipe.send(("_step", [], dict(step=step, future_for_trigger=future_for_trigger)))
-        step.attach_trigger(future_for_trigger.result())
+        future_for_init = MPFuture()
+        self._outer_pipe.send(("_step", [], dict(step=step, future_for_init=future_for_init)))
+        step.attach(*future_for_init.result())
 
         if not require_trigger:
             step.allow_allreduce()
         return step.result() if wait else step
 
-    async def _step(self, *, step: StepControl, future_for_trigger: MPFuture):
+    async def _step(self, *, step: StepControl, future_for_init: MPFuture):
         try:
-            trigger = MPFuture()
-            step.attach_trigger(trigger)
-            future_for_trigger.set_result(trigger)
+            trigger, cancel = MPFuture(), MPFuture()
+            step.attach(trigger, cancel)
+            future_for_init.set_result((trigger, cancel))
 
             while not step.done():
                 try:
                     self._pending_group_assembled.clear()
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
-                    group_info = await self._matchmaking.look_for_group(step)
+                    matchmaking_task = asyncio.create_task(self._matchmaking.look_for_group(step))
+                    check_cancel_task = asyncio.create_task(step.wait_for_cancel())
+
+                    await asyncio.wait({matchmaking_task, check_cancel_task}, return_when=asyncio.FIRST_COMPLETED)
+                    if step.cancelled():
+                        matchmaking_task.cancel()
+                        raise asyncio.CancelledError()
+                    else:
+                        check_cancel_task.cancel()
+
+                    group_info = await matchmaking_task
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
 
@@ -424,9 +434,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.InvalidStateError,
                     P2PHandlerError,
                 ) as e:
-                    if not step.allow_retries or get_dht_time() >= step.deadline:
-                        logger.exception(e)
-                        step.set_exception(e)
+                    if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
+                        if not step.cancelled():
+                            logger.exception(e)
+                        if not step.done():
+                            step.set_exception(e)
                     else:
                         logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")
 

+ 14 - 6
hivemind/averaging/control.py

@@ -43,6 +43,7 @@ class StepControl(MPFuture):
         super().__init__()
         self._data_for_gather, self._deadline, self._allow_retries = data_for_gather, deadline, allow_retries
         self._trigger: Optional[MPFuture] = None
+        self._cancel: Optional[MPFuture] = None
 
         # Buffer contents:
         # scheduled_time (double) | weight (double) | stage (AveragingStage, 1 byte) | began_allreduce: (bool, 1 byte)
@@ -52,12 +53,12 @@ class StepControl(MPFuture):
         self.weight = weight
         self.began_allreduce = False
 
-    def attach_trigger(self, trigger: MPFuture):
-        assert self._trigger is None, "Trigger is already attached"
-        self._trigger = trigger
+    def attach(self, trigger: MPFuture, cancel: MPFuture):
+        assert self._trigger is None and self._cancel is None, "Futures are already attached"
+        self._trigger, self._cancel = trigger, cancel
 
     def allow_allreduce(self):
-        """Allow averager to begin allreduce when it finds a group. Meant to be triggered by user."""
+        """Allow averager to begin all-reduce when it finds a group. Meant to be triggered by user."""
         assert self._trigger is not None, "StepControl does not have an attached trigger"
         if self._trigger.done():
             logger.warning("Trigger is already set")
@@ -133,16 +134,23 @@ class StepControl(MPFuture):
         return dict(
             super().__getstate__(),
             _trigger=self._trigger,
+            _cancel=self._cancel,
             _shared_buffer=self._shared_buffer,
             immutable_params=(self._data_for_gather, self._deadline, self._allow_retries),
         )
 
     def __setstate__(self, state):
         super().__setstate__(state)
-        self._trigger, self._shared_buffer = state["_trigger"], state["_shared_buffer"]
+        self._trigger, self._cancel, self._shared_buffer = state["_trigger"], state["_cancel"], state["_shared_buffer"]
         self._data_for_gather, self._deadline, self._allow_retries = state["immutable_params"]
 
     def cancel(self) -> bool:
         if self._trigger is not None:
             self._trigger.cancel()
-        return self.cancel()
+        if self._cancel is not None:
+            self._cancel.set_result(None)
+        return super().cancel()
+
+    async def wait_for_cancel(self):
+        """Wait for the step to be cancelled by the user. Should be called from inside the averager."""
+        await self._cancel

+ 31 - 0
tests/test_averaging.py

@@ -467,6 +467,37 @@ def test_averaging_trigger():
     c0.allow_allreduce()
 
 
+@pytest.mark.forked
+def test_averaging_cancel():
+    averagers = tuple(
+        hivemind.averaging.DecentralizedAverager(
+            averaged_tensors=[torch.randn(3)],
+            dht=dht,
+            min_matchmaking_time=0.5,
+            request_timeout=0.3,
+            client_mode=(i % 2 == 0),
+            prefix="mygroup",
+            start=True,
+        )
+        for i, dht in enumerate(launch_dht_instances(4))
+    )
+
+    step_controls = [averager.step(wait=False, scheduled_time=hivemind.get_dht_time() + 1) for averager in averagers]
+
+    time.sleep(0.2)
+    step_controls[0].cancel()
+    step_controls[1].cancel()
+
+    for i, control in enumerate(step_controls):
+        if i in (0, 1):
+            assert control.cancelled()
+        else:
+            assert control.result() is not None and len(control.result()) == 2
+
+    for averager in averagers:
+        averager.shutdown()
+
+
 @pytest.mark.forked
 def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     torch.manual_seed(42)