|
@@ -170,7 +170,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
allow_state_sharing = not client_mode and not auxiliary
|
|
|
self.allow_state_sharing = allow_state_sharing
|
|
|
|
|
|
- self.ready = mp.Event() # whether the averager process has started (and ready for incoming requests)
|
|
|
+ self._ready = MPFuture()
|
|
|
# note: we create a background thread weakref and with daemon=True to ensure garbage collection
|
|
|
background_fetcher = threading.Thread(
|
|
|
daemon=True,
|
|
@@ -214,25 +214,31 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
|
|
|
|
|
|
async def _run():
|
|
|
- self._p2p = await self.dht.replicate_p2p()
|
|
|
- if not self.client_mode:
|
|
|
- await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
|
|
|
- else:
|
|
|
- logger.debug(f"The averager is running in client mode.")
|
|
|
-
|
|
|
- self._matchmaking = Matchmaking(
|
|
|
- self._p2p,
|
|
|
- self.schema_hash,
|
|
|
- self.dht,
|
|
|
- client_mode=self.client_mode,
|
|
|
- **self.matchmaking_kwargs,
|
|
|
- )
|
|
|
- if not self.client_mode:
|
|
|
- asyncio.create_task(self._declare_for_download_periodically())
|
|
|
-
|
|
|
- self._pending_group_assembled = asyncio.Event()
|
|
|
- self._pending_group_assembled.set()
|
|
|
- self.ready.set()
|
|
|
+ try:
|
|
|
+ self._p2p = await self.dht.replicate_p2p()
|
|
|
+ if not self.client_mode:
|
|
|
+ await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
|
|
|
+ else:
|
|
|
+ logger.debug(f"The averager is running in client mode.")
|
|
|
+
|
|
|
+ self._matchmaking = Matchmaking(
|
|
|
+ self._p2p,
|
|
|
+ self.schema_hash,
|
|
|
+ self.dht,
|
|
|
+ client_mode=self.client_mode,
|
|
|
+ **self.matchmaking_kwargs,
|
|
|
+ )
|
|
|
+ if not self.client_mode:
|
|
|
+ asyncio.create_task(self._declare_for_download_periodically())
|
|
|
+
|
|
|
+ self._pending_group_assembled = asyncio.Event()
|
|
|
+ self._pending_group_assembled.set()
|
|
|
+ except Exception as e:
|
|
|
+ # Loglevel is DEBUG since normally the exception is propagated to the caller
|
|
|
+ logger.debug(e, exc_info=True)
|
|
|
+ self._ready.set_exception(e)
|
|
|
+ return
|
|
|
+ self._ready.set_result(None)
|
|
|
|
|
|
while True:
|
|
|
method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
|
|
@@ -243,14 +249,17 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
|
|
|
loop.run_until_complete(_run())
|
|
|
|
|
|
- def run_in_background(self, await_ready=True, timeout=None):
|
|
|
+ def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
|
|
|
"""
|
|
|
Starts averager in a background process. if await_ready, this method will wait until background dht
|
|
|
is ready to process incoming requests or for :timeout: seconds max.
|
|
|
"""
|
|
|
self.start()
|
|
|
- if await_ready and not self.ready.wait(timeout=timeout):
|
|
|
- raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
|
|
|
+ if await_ready:
|
|
|
+ self.wait_until_ready(timeout)
|
|
|
+
|
|
|
+ def wait_until_ready(self, timeout: Optional[float] = None) -> None:
|
|
|
+ self._ready.result(timeout=timeout)
|
|
|
|
|
|
def shutdown(self) -> None:
|
|
|
"""Shut down the averager process"""
|