justheuristic 4 年 前
コミット
3e3b039aac
1 ファイル変更2 行追加2 行削除
  1. 2 2
      hivemind/averaging/averager.py

+ 2 - 2
hivemind/averaging/averager.py

@@ -554,7 +554,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         self._inner_pipe.send(("_TRIGGER_GET_CURRENT_STATE", future))
         return await future
 
-    def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
+    def load_state_from_peers(self, wait=True, timeout: Optional[float] = None) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
         """
         Try to download the latest optimizer state one of the existing peer.
         :returns: on success, return a 2-tuple with (metadata, tensors), where
@@ -566,7 +566,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         future = MPFuture()
         self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
-        return future.result() if wait else future
+        return future.result(timeout=timeout) if wait else future
 
     async def _load_state_from_peers(self, future: MPFuture):
         try: