浏览代码

pipe semaphore

justheuristic 4 年之前
父节点
当前提交
b2decce266
共有 2 个文件被更改,包括 89 次插入和 73 次删除
  1. +46 -37  hivemind/averaging/averager.py
  2. +43 -36  hivemind/dht/__init__.py

+ 46 - 37
hivemind/averaging/averager.py

@@ -212,46 +212,55 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """Serve DecentralizedAverager forever. This function will not return until the averager is shut down"""
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
-        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
-            async def _run():
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
+            try:
+                self._p2p = await self.dht.replicate_p2p()
+                if not self.client_mode:
+                    await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
+                else:
+                    logger.debug(f"The averager is running in client mode.")
+
+                self._matchmaking = Matchmaking(
+                    self._p2p,
+                    self.schema_hash,
+                    self.dht,
+                    client_mode=self.client_mode,
+                    **self.matchmaking_kwargs,
+                )
+                if not self.client_mode:
+                    asyncio.create_task(self._declare_for_download_periodically())
+
+                self._pending_group_assembled = asyncio.Event()
+                self._pending_group_assembled.set()
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
+
+            while True:
                 try:
-                    self._p2p = await self.dht.replicate_p2p()
-                    if not self.client_mode:
-                        await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
-                    else:
-                        logger.debug(f"The averager is running in client mode.")
-
-                    self._matchmaking = Matchmaking(
-                        self._p2p,
-                        self.schema_hash,
-                        self.dht,
-                        client_mode=self.client_mode,
-                        **self.matchmaking_kwargs,
-                    )
-                    if not self.client_mode:
-                        asyncio.create_task(self._declare_for_download_periodically())
-
-                    self._pending_group_assembled = asyncio.Event()
-                    self._pending_group_assembled.set()
-                except Exception as e:
-                    # Loglevel is DEBUG since normally the exception is propagated to the caller
-                    logger.debug(e, exc_info=True)
-                    self._ready.set_exception(e)
-                    return
-                self._ready.set_result(None)
-
-                while True:
-                    try:
-                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
-                    except (OSError, ConnectionError) as e:
-                        logger.exception(e)
-                        await asyncio.sleep(self._matchmaking.request_timeout)
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._matchmaking.request_timeout)
+                except asyncio.TimeoutError:
+                    pass
+
+                try:
+                    if not self._inner_pipe.poll():
                         continue
-                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                    if method == "_shutdown":
-                        await task
-                        break
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self._matchmaking.request_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                if method == "_shutdown":
+                    await task
+                    break
 
             loop.run_until_complete(_run())
 

+ 43 - 36
hivemind/dht/__init__.py

@@ -103,45 +103,52 @@ class DHT(mp.Process):
     def run(self) -> None:
         """Serve DHT forever. This function will not return until DHT node is shut down"""
         # NOTE(review): this print looks like a leftover debugging aid — confirm before merging.
         print(f"RUNNING DHT WITH PID={self.pid}")
-        loop = switch_to_uvloop()
-
-        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
-            async def _run():
+        loop = switch_to_uvloop()
         # pipe_semaphore.release is registered as the reader callback for the inner pipe's file
         # descriptor; this replaces the removed single-worker executor that blocked on recv().
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
             # Create the DHT node, report success or failure through self._ready, then serve pipe commands.
+            try:
+                if self._daemon_listen_maddr is not None:
+                    replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
+                else:
+                    replicated_p2p = None
+
+                self._node = await DHTNode.create(
+                    initial_peers=self.initial_peers,
+                    num_workers=self.num_workers,
+                    record_validator=self._record_validator,
+                    p2p=replicated_p2p,
+                    **self.kwargs,
+                )
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
+
             # Command loop: every pipe message is unpacked as (method, args, kwargs).
+            while True:
                 # Wait for a pipe notification, but wake up after wait_timeout at the latest.
+                try:
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._node.protocol.wait_timeout)
+                except asyncio.TimeoutError:
+                    pass
                 try:
-                    if self._daemon_listen_maddr is not None:
-                        replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
-                    else:
-                        replicated_p2p = None
-
-                    self._node = await DHTNode.create(
-                        initial_peers=self.initial_peers,
-                        num_workers=self.num_workers,
-                        record_validator=self._record_validator,
-                        p2p=replicated_p2p,
-                        **self.kwargs,
-                    )
-                except Exception as e:
-                    # Loglevel is DEBUG since normally the exception is propagated to the caller
-                    logger.debug(e, exc_info=True)
-                    self._ready.set_exception(e)
-                    return
-                self._ready.set_result(None)
-
-                while True:
-                    try:
-                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
-                    except (OSError, ConnectionError) as e:
-                        logger.exception(e)
-                        await asyncio.sleep(self._node.protocol.wait_timeout)
-                        continue
-                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                    if method == "_shutdown":
-                        await task
-                        break
 
-            coro = _run()
-            loop.run_until_complete(coro)
                     # After a timeout there may be nothing to read, so poll() is checked before recv().
+                    if not self._inner_pipe.poll():
+                        continue
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self._node.protocol.wait_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                 # _shutdown is awaited and ends the loop; every other method runs as a background task.
+                if method == "_shutdown":
+                    await task
+                    break
+
+        loop.run_until_complete(_run())
 
     def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """