|
@@ -554,7 +554,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
self._inner_pipe.send(("_TRIGGER_GET_CURRENT_STATE", future))
|
|
|
return await future
|
|
|
|
|
|
- def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
|
|
|
+ def load_state_from_peers(self, wait=True, timeout: Optional[float] = None) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
|
|
|
"""
|
|
|
Try to download the latest optimizer state one of the existing peer.
|
|
|
:returns: on success, return a 2-tuple with (metadata, tensors), where
|
|
@@ -566,7 +566,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
"""
|
|
|
future = MPFuture()
|
|
|
self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
|
|
|
- return future.result() if wait else future
|
|
|
+ return future.result(timeout=timeout) if wait else future
|
|
|
|
|
|
async def _load_state_from_peers(self, future: MPFuture):
|
|
|
try:
|