Browse Source

pipe semaphore

justheuristic 4 years ago
parent
commit
b2decce266
2 changed files with 89 additions and 73 deletions
  1. 46 37
      hivemind/averaging/averager.py
  2. 43 36
      hivemind/dht/__init__.py

+ 46 - 37
hivemind/averaging/averager.py

@@ -212,46 +212,55 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """Serve DecentralizedAverager forever. This function will not return until the averager is shut down"""
         """Serve DecentralizedAverager forever. This function will not return until the averager is shut down"""
         loop = switch_to_uvloop()
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
         # initialize asyncio synchronization primitives in this event loop
-        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
 
-            async def _run():
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
+            try:
+                self._p2p = await self.dht.replicate_p2p()
+                if not self.client_mode:
+                    await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
+                else:
+                    logger.debug(f"The averager is running in client mode.")
+
+                self._matchmaking = Matchmaking(
+                    self._p2p,
+                    self.schema_hash,
+                    self.dht,
+                    client_mode=self.client_mode,
+                    **self.matchmaking_kwargs,
+                )
+                if not self.client_mode:
+                    asyncio.create_task(self._declare_for_download_periodically())
+
+                self._pending_group_assembled = asyncio.Event()
+                self._pending_group_assembled.set()
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
+
+            while True:
                 try:
                 try:
-                    self._p2p = await self.dht.replicate_p2p()
-                    if not self.client_mode:
-                        await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
-                    else:
-                        logger.debug(f"The averager is running in client mode.")
-
-                    self._matchmaking = Matchmaking(
-                        self._p2p,
-                        self.schema_hash,
-                        self.dht,
-                        client_mode=self.client_mode,
-                        **self.matchmaking_kwargs,
-                    )
-                    if not self.client_mode:
-                        asyncio.create_task(self._declare_for_download_periodically())
-
-                    self._pending_group_assembled = asyncio.Event()
-                    self._pending_group_assembled.set()
-                except Exception as e:
-                    # Loglevel is DEBUG since normally the exception is propagated to the caller
-                    logger.debug(e, exc_info=True)
-                    self._ready.set_exception(e)
-                    return
-                self._ready.set_result(None)
-
-                while True:
-                    try:
-                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
-                    except (OSError, ConnectionError) as e:
-                        logger.exception(e)
-                        await asyncio.sleep(self._matchmaking.request_timeout)
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._matchmaking.request_timeout)
+                except asyncio.TimeoutError:
+                    pass
+
+                try:
+                    if not self._inner_pipe.poll():
                         continue
                         continue
-                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                    if method == "_shutdown":
-                        await task
-                        break
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self._matchmaking.request_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                if method == "_shutdown":
+                    await task
+                    break
 
 
             loop.run_until_complete(_run())
             loop.run_until_complete(_run())
 
 

+ 43 - 36
hivemind/dht/__init__.py

@@ -103,45 +103,52 @@ class DHT(mp.Process):
     def run(self) -> None:
     def run(self) -> None:
         """Serve DHT forever. This function will not return until DHT node is shut down"""
         """Serve DHT forever. This function will not return until DHT node is shut down"""
         print(f"RUNNING DHT WITH PID={self.pid}")
         print(f"RUNNING DHT WITH PID={self.pid}")
-        loop = switch_to_uvloop()
-
-        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
 
-            async def _run():
+        loop = switch_to_uvloop()
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
+            try:
+                if self._daemon_listen_maddr is not None:
+                    replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
+                else:
+                    replicated_p2p = None
+
+                self._node = await DHTNode.create(
+                    initial_peers=self.initial_peers,
+                    num_workers=self.num_workers,
+                    record_validator=self._record_validator,
+                    p2p=replicated_p2p,
+                    **self.kwargs,
+                )
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
+
+            while True:
+                try:
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._node.protocol.wait_timeout)
+                except asyncio.TimeoutError:
+                    pass
                 try:
                 try:
-                    if self._daemon_listen_maddr is not None:
-                        replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
-                    else:
-                        replicated_p2p = None
-
-                    self._node = await DHTNode.create(
-                        initial_peers=self.initial_peers,
-                        num_workers=self.num_workers,
-                        record_validator=self._record_validator,
-                        p2p=replicated_p2p,
-                        **self.kwargs,
-                    )
-                except Exception as e:
-                    # Loglevel is DEBUG since normally the exception is propagated to the caller
-                    logger.debug(e, exc_info=True)
-                    self._ready.set_exception(e)
-                    return
-                self._ready.set_result(None)
-
-                while True:
-                    try:
-                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
-                    except (OSError, ConnectionError) as e:
-                        logger.exception(e)
-                        await asyncio.sleep(self._node.protocol.wait_timeout)
-                        continue
-                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                    if method == "_shutdown":
-                        await task
-                        break
 
 
-            coro = _run()
-            loop.run_until_complete(coro)
+                    if not self._inner_pipe.poll():
+                        continue
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self._node.protocol.wait_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                if method == "_shutdown":
+                    await task
+                    break
+
+        loop.run_until_complete(_run())
 
 
     def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
     def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
         """