@@ -25,7 +25,7 @@ from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2, runtime_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop
+from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
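
The import hunk above renames `aiter` to `as_aiter` and introduces `aiter_with_timeout`. For orientation, here is a minimal sketch of what these two helpers do; this is illustrative only, the real implementations live in `hivemind/utils/asyncio.py`:

```python
import asyncio
from typing import AsyncIterator, TypeVar

T = TypeVar("T")


async def as_aiter(*args: T) -> AsyncIterator[T]:
    # Wrap plain values into an async iterator. Presumably renamed from
    # `aiter` to avoid shadowing the builtin added in Python 3.10.
    for arg in args:
        yield arg


async def aiter_with_timeout(iterable: AsyncIterator[T], timeout: float) -> AsyncIterator[T]:
    # Forward items from an async iterator, raising asyncio.TimeoutError
    # if the next item does not arrive within `timeout` seconds.
    iterator = iterable.__aiter__()
    while True:
        try:
            yield await asyncio.wait_for(iterator.__anext__(), timeout=timeout)
        except StopAsyncIteration:
            break
```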
@@ -197,6 +197,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def peer_id(self) -> PeerID:
         return self.dht.peer_id

+    @property
+    def request_timeout(self):
+        return self._matchmaking.request_timeout
+
     def run(self):
         """
         Run averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
@@ -211,48 +215,56 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """Serve DecentralizedAverager forever. This function will not return until the averager is shut down"""
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
-        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
-
-            async def _run():
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
+            try:
+                self._p2p = await self.dht.replicate_p2p()
+                if not self.client_mode:
+                    await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
+                else:
+                    logger.debug(f"The averager is running in client mode.")
+
+                self._matchmaking = Matchmaking(
+                    self._p2p,
+                    self.schema_hash,
+                    self.dht,
+                    client_mode=self.client_mode,
+                    **self.matchmaking_kwargs,
+                )
+                if not self.client_mode:
+                    asyncio.create_task(self._declare_for_download_periodically())
+
+                self._pending_group_assembled = asyncio.Event()
+                self._pending_group_assembled.set()
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
+
+            while True:
                 try:
-                    self._p2p = await self.dht.replicate_p2p()
-                    if not self.client_mode:
-                        await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
-                    else:
-                        logger.debug(f"The averager is running in client mode.")
-
-                    self._matchmaking = Matchmaking(
-                        self._p2p,
-                        self.schema_hash,
-                        self.dht,
-                        client_mode=self.client_mode,
-                        **self.matchmaking_kwargs,
-                    )
-                    if not self.client_mode:
-                        asyncio.create_task(self._declare_for_download_periodically())
-
-                    self._pending_group_assembled = asyncio.Event()
-                    self._pending_group_assembled.set()
-                except Exception as e:
-                    # Loglevel is DEBUG since normally the exception is propagated to the caller
-                    logger.debug(e, exc_info=True)
-                    self._ready.set_exception(e)
-                    return
-                self._ready.set_result(None)
-
-                while True:
-                    try:
-                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
-                    except (OSError, ConnectionError) as e:
-                        logger.exception(e)
-                        await asyncio.sleep(self._matchmaking.request_timeout)
-                        continue
-                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                    if method == "_shutdown":
-                        await task
-                        break
-
-            loop.run_until_complete(_run())
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self.request_timeout)
+                except asyncio.TimeoutError:
+                    pass
+                if not self._inner_pipe.poll():
+                    continue
+                try:
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self.request_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                if method == "_shutdown":
+                    await task
+                    break
+
+        loop.run_until_complete(_run())

     def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
@@ -484,7 +496,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
             return

-        async for message in group.rpc_aggregate_part(achain(aiter(request), stream), context):
+        async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
             yield message

     async def _declare_for_download_periodically(self):
@@ -542,7 +554,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         self._inner_pipe.send(("_TRIGGER_GET_CURRENT_STATE", future))
         return await future

-    def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
+    def load_state_from_peers(
+        self, wait: bool = True, timeout: Optional[float] = None
+    ) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
         """
         Try to download the latest optimizer state from one of the existing peers.
         :returns: on success, return a 2-tuple with (metadata, tensors), where
@@ -554,7 +568,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         future = MPFuture()
         self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
-        return future.result() if wait else future
+        return future.result(timeout=timeout) if wait else future

     async def _load_state_from_peers(self, future: MPFuture):
         try:
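
`load_state_from_peers` now accepts a `timeout` and forwards it to `MPFuture.result`. A hypothetical caller-side usage, assuming `averager` is a running `DecentralizedAverager`; note that `result(timeout=...)` raises `concurrent.futures.TimeoutError` if the state does not arrive in time:

```python
from concurrent.futures import TimeoutError as FutureTimeoutError

try:
    # Give up on the download after 60 seconds instead of blocking forever
    loaded_state = averager.load_state_from_peers(wait=True, timeout=60.0)
except FutureTimeoutError:
    loaded_state = None

if loaded_state is not None:
    metadata, tensors = loaded_state
```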
@@ -579,7 +593,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
                     stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                     current_tensor_parts, tensors = [], []
-                    async for message in stream:
+
+                    async for message in aiter_with_timeout(stream, timeout=self.request_timeout):
                         if message.metadata:
                             metadata = self.serializer.loads(message.metadata)
                         if message.tensor_part.dtype and current_tensor_parts:
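
Wrapping the download stream in `aiter_with_timeout` is what prevents a silent peer from hanging the state download forever: the raised `asyncio.TimeoutError` is caught by the surrounding per-peer `try`/`except`, and the averager falls through to the next peer. A toy demonstration, reusing the `aiter_with_timeout` sketch above:

```python
import asyncio


async def stalled_stream():
    yield "part-1"
    await asyncio.sleep(3600)  # the peer goes silent mid-transfer
    yield "part-2"


async def main():
    try:
        # A plain `async for part in stalled_stream()` would hang for an hour here
        async for part in aiter_with_timeout(stalled_stream(), timeout=1.0):
            print("received", part)
    except asyncio.TimeoutError:
        print("peer stalled, falling back to the next peer")


asyncio.run(main())
```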
@@ -603,7 +618,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):

         finally:
             if not future.done():
-                logger.warning("Averager could not load state from peers: all requests have failed.")
+                logger.warning("Averager could not load state from peers: none of the requests succeeded.")
                 future.set_result(None)

     def get_group_bits(self, wait: bool = True):