justheuristic 3 tahun lalu
induk
melakukan
f92d6c325b
2 mengubah file dengan 12 tambahan dan 3 penghapusan
  1. 3 3
      hivemind/averaging/averager.py
  2. 9 0
      hivemind/averaging/control.py

+ 3 - 3
hivemind/averaging/averager.py

@@ -306,7 +306,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def shutdown(self) -> None:
         """Shut down the averager process"""
         if self.is_alive():
-            self._outer_pipe.send(("_shutdown", [None], {}))  # shut down the daemon process
+            self._outer_pipe.send(("_shutdown", [self.shutdown_timeout], {}))  # shut down the daemon process
             self._inner_pipe.send(("_SHUTDOWN", None))  # shut down background thread in master
             self.join(self.shutdown_timeout)
             if self.is_alive():
@@ -315,11 +315,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         else:
             logger.exception("Averager shutdown has no effect: the process is already not alive")
 
-    async def _shutdown(self) -> None:
+    async def _shutdown(self, timeout: Optional[DHTExpiration]) -> None:
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
-        await asyncio.gather(*remaining_tasks)
+        await asyncio.wait_for(asyncio.gather(*remaining_tasks), timeout)
 
     def __del__(self):
         if self._parent_pid == os.getpid() and self.is_alive():

+ 9 - 0
hivemind/averaging/control.py

@@ -113,6 +113,15 @@ class StepControl(MPFuture):
     def allow_retries(self) -> bool:
         return self._allow_retries
 
+    def __getstate__(self):
+        return dict(super().__getstate__(), _trigger=self._trigger, _shared_buffer=self._shared_buffer,
+                    immutable_params=(self._gather_binary, self._deadline, self._allow_retries))
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        self._trigger, self._shared_buffer = state["_trigger"], state["_shared_buffer"]
+        self._gather_binary, self._deadline, self._allow_retries = state["immutable_params"]
+
     def cancel(self) -> bool:
         if self._trigger is not None:
             self._trigger.cancel()