|
@@ -212,46 +212,55 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
"""Serve DecentralizedAverager forever. This function will not return until the averager is shut down"""
|
|
|
loop = switch_to_uvloop()
|
|
|
# initialize asyncio synchronization primitives in this event loop
|
|
|
- with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
|
|
|
|
|
|
- async def _run():
|
|
|
+ pipe_semaphore = asyncio.Semaphore(value=0)
|
|
|
+ loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
|
|
|
+
|
|
|
+ async def _run():
|
|
|
+ try:
|
|
|
+ self._p2p = await self.dht.replicate_p2p()
|
|
|
+ if not self.client_mode:
|
|
|
+ await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
|
|
|
+ else:
|
|
|
+ logger.debug(f"The averager is running in client mode.")
|
|
|
+
|
|
|
+ self._matchmaking = Matchmaking(
|
|
|
+ self._p2p,
|
|
|
+ self.schema_hash,
|
|
|
+ self.dht,
|
|
|
+ client_mode=self.client_mode,
|
|
|
+ **self.matchmaking_kwargs,
|
|
|
+ )
|
|
|
+ if not self.client_mode:
|
|
|
+ asyncio.create_task(self._declare_for_download_periodically())
|
|
|
+
|
|
|
+ self._pending_group_assembled = asyncio.Event()
|
|
|
+ self._pending_group_assembled.set()
|
|
|
+ except Exception as e:
|
|
|
+ # Loglevel is DEBUG since normally the exception is propagated to the caller
|
|
|
+ logger.debug(e, exc_info=True)
|
|
|
+ self._ready.set_exception(e)
|
|
|
+ return
|
|
|
+ self._ready.set_result(None)
|
|
|
+
|
|
|
+ while True:
|
|
|
try:
|
|
|
- self._p2p = await self.dht.replicate_p2p()
|
|
|
- if not self.client_mode:
|
|
|
- await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
|
|
|
- else:
|
|
|
- logger.debug(f"The averager is running in client mode.")
|
|
|
-
|
|
|
- self._matchmaking = Matchmaking(
|
|
|
- self._p2p,
|
|
|
- self.schema_hash,
|
|
|
- self.dht,
|
|
|
- client_mode=self.client_mode,
|
|
|
- **self.matchmaking_kwargs,
|
|
|
- )
|
|
|
- if not self.client_mode:
|
|
|
- asyncio.create_task(self._declare_for_download_periodically())
|
|
|
-
|
|
|
- self._pending_group_assembled = asyncio.Event()
|
|
|
- self._pending_group_assembled.set()
|
|
|
- except Exception as e:
|
|
|
- # Loglevel is DEBUG since normally the exception is propagated to the caller
|
|
|
- logger.debug(e, exc_info=True)
|
|
|
- self._ready.set_exception(e)
|
|
|
- return
|
|
|
- self._ready.set_result(None)
|
|
|
-
|
|
|
- while True:
|
|
|
- try:
|
|
|
- method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
|
|
|
- except (OSError, ConnectionError) as e:
|
|
|
- logger.exception(e)
|
|
|
- await asyncio.sleep(self._matchmaking.request_timeout)
|
|
|
+ await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._matchmaking.request_timeout)
|
|
|
+ except asyncio.TimeoutError:
|
|
|
+ pass
|
|
|
+
|
|
|
+ try:
|
|
|
+ if not self._inner_pipe.poll():
|
|
|
continue
|
|
|
- task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
|
|
|
- if method == "_shutdown":
|
|
|
- await task
|
|
|
- break
|
|
|
+ method, args, kwargs = self._inner_pipe.recv()
|
|
|
+ except (OSError, ConnectionError, RuntimeError) as e:
|
|
|
+ logger.exception(e)
|
|
|
+ await asyncio.sleep(self._matchmaking.request_timeout)
|
|
|
+ continue
|
|
|
+ task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
|
|
|
+ if method == "_shutdown":
|
|
|
+ await task
|
|
|
+ break
|
|
|
|
|
|
loop.run_until_complete(_run())
|
|
|
|