|
@@ -43,6 +43,7 @@ class StepControl(MPFuture):
|
|
|
super().__init__()
|
|
|
self._data_for_gather, self._deadline, self._allow_retries = data_for_gather, deadline, allow_retries
|
|
|
self._trigger: Optional[MPFuture] = None
|
|
|
+ self._cancel: Optional[MPFuture] = None
|
|
|
|
|
|
# Buffer contents:
|
|
|
# scheduled_time (double) | weight (double) | stage (AveragingStage, 1 byte) | began_allreduce: (bool, 1 byte)
|
|
@@ -52,12 +53,12 @@ class StepControl(MPFuture):
|
|
|
self.weight = weight
|
|
|
self.began_allreduce = False
|
|
|
|
|
|
- def attach_trigger(self, trigger: MPFuture):
|
|
|
- assert self._trigger is None, "Trigger is already attached"
|
|
|
- self._trigger = trigger
|
|
|
+ def attach(self, trigger: MPFuture, cancel: MPFuture):
|
|
|
+ assert self._trigger is None and self._cancel is None, "Futures are already attached"
|
|
|
+ self._trigger, self._cancel = trigger, cancel
|
|
|
|
|
|
def allow_allreduce(self):
|
|
|
- """Allow averager to begin allreduce when it finds a group. Meant to be triggered by user."""
|
|
|
+ """Allow averager to begin all-reduce when it finds a group. Meant to be triggered by user."""
|
|
|
assert self._trigger is not None, "StepControl does not have an attached trigger"
|
|
|
if self._trigger.done():
|
|
|
logger.warning("Trigger is already set")
|
|
@@ -133,16 +134,23 @@ class StepControl(MPFuture):
|
|
|
return dict(
|
|
|
super().__getstate__(),
|
|
|
_trigger=self._trigger,
|
|
|
+ _cancel=self._cancel,
|
|
|
_shared_buffer=self._shared_buffer,
|
|
|
immutable_params=(self._data_for_gather, self._deadline, self._allow_retries),
|
|
|
)
|
|
|
|
|
|
def __setstate__(self, state):
|
|
|
super().__setstate__(state)
|
|
|
- self._trigger, self._shared_buffer = state["_trigger"], state["_shared_buffer"]
|
|
|
+ self._trigger, self._cancel, self._shared_buffer = state["_trigger"], state["_cancel"], state["_shared_buffer"]
|
|
|
self._data_for_gather, self._deadline, self._allow_retries = state["immutable_params"]
|
|
|
|
|
|
def cancel(self) -> bool:
|
|
|
if self._trigger is not None:
|
|
|
self._trigger.cancel()
|
|
|
- return self.cancel()
|
|
|
+ if self._cancel is not None:
|
|
|
+ self._cancel.set_result(None)
|
|
|
+ return super().cancel()
|
|
|
+
|
|
|
+ async def wait_for_cancel(self):
|
|
|
+ """Await for step to be cancelled by the user. Should be called from insider the averager."""
|
|
|
+ await self._cancel
|