|
@@ -25,7 +25,7 @@ from hivemind.dht import DHT, DHTID
|
|
|
from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
|
|
|
from hivemind.proto import averaging_pb2, runtime_pb2
|
|
|
from hivemind.utils import MPFuture, TensorDescriptor, get_logger
|
|
|
-from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop
|
|
|
+from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop, aiter_with_timeout
|
|
|
from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
|
|
|
from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
|
|
|
from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
|
|
@@ -197,6 +197,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
def peer_id(self) -> PeerID:
|
|
|
return self.dht.peer_id
|
|
|
|
|
|
+ @property
|
|
|
+ def request_timeout(self):
|
|
|
+ return self._matchmaking.request_timeout
|
|
|
+
|
|
|
def run(self):
|
|
|
"""
|
|
|
Run averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
|
|
@@ -245,7 +249,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
|
|
|
while True:
|
|
|
try:
|
|
|
- await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._matchmaking.request_timeout)
|
|
|
+ await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self.request_timeout)
|
|
|
except asyncio.TimeoutError:
|
|
|
pass
|
|
|
if not self._inner_pipe.poll():
|
|
@@ -254,7 +258,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
method, args, kwargs = self._inner_pipe.recv()
|
|
|
except (OSError, ConnectionError, RuntimeError) as e:
|
|
|
logger.exception(e)
|
|
|
- await asyncio.sleep(self._matchmaking.request_timeout)
|
|
|
+ await asyncio.sleep(self.request_timeout)
|
|
|
continue
|
|
|
task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
|
|
|
if method == "_shutdown":
|
|
@@ -588,7 +592,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
|
|
|
stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
|
|
|
current_tensor_parts, tensors = [], []
|
|
|
- async for message in stream:
|
|
|
+
|
|
|
+ async for message in aiter_with_timeout(stream, timeout=self._matchmaking.request_timeout):
|
|
|
if message.metadata:
|
|
|
metadata = self.serializer.loads(message.metadata)
|
|
|
if message.tensor_part.dtype and current_tensor_parts:
|