
Allow RemoteExpertWorker to run coroutines concurrently (#561)

Previously, `RemoteExpertWorker` ran one coroutine at a time, so hivemind.moe/Petals clients were very slow for concurrent calls.

(cherry picked from commit 589cb2c8b7c9d655a1250405672a0f4ab1d24f59)
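
For reference, a minimal usage sketch of the behavior this change enables (the `remote_call` coroutine is hypothetical, standing in for an actual remote expert call): several coroutines submitted with `return_future=True` now overlap instead of running one after another.

```python
import asyncio

from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker


# Hypothetical coroutine standing in for a real remote expert call
async def remote_call(i: int) -> int:
    await asyncio.sleep(0.1)
    return i


# Submit several coroutines without blocking; they run concurrently
# on the worker's shared event loop instead of one at a time.
futures = [RemoteExpertWorker.run_coroutine(remote_call(i), return_future=True) for i in range(10)]
results = [fut.result() for fut in futures]  # ~0.1 s total rather than ~1 s
```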
Alexander Borzunov 2 years ago
parent
commit
542f5c3142

+ 1 - 0
.github/workflows/check-style.yml

@@ -32,3 +32,4 @@ jobs:
       - uses: codespell-project/actions-codespell@v1
         with:
           only_warn: 1
+          ignore_words_list: ibrary,nd

+ 16 - 28
hivemind/moe/client/remote_expert_worker.py

@@ -1,6 +1,6 @@
+import asyncio
 import os
 from concurrent.futures import Future
-from queue import Queue
 from threading import Thread
 from typing import Awaitable, Optional
 
@@ -10,39 +10,27 @@ from hivemind.utils import switch_to_uvloop
 class RemoteExpertWorker:
     """Local thread for managing async tasks related to RemoteExpert"""
 
-    _task_queue: Queue = Queue()
-    _event_thread: Optional[Thread] = None
-    _pid: int = -1
+    _event_thread = None
+    _event_loop_fut = None
+    _pid = None
 
     @classmethod
-    def _run(cls):
-        loop = switch_to_uvloop()
-
-        async def receive_tasks():
-            while True:
-                cor, future = cls._task_queue.get()
-                try:
-                    result = await cor
-                except Exception as e:
-                    future.set_exception(e)
-                    continue
-                if not future.cancelled():
-                    future.set_result(result)
-
-        loop.run_until_complete(receive_tasks())
+    def _run_event_loop(cls):
+        try:
+            loop = switch_to_uvloop()
+            cls._event_loop_fut.set_result(loop)
+        except Exception as e:
+            cls._event_loop_fut.set_exception(e)
+        loop.run_forever()
 
     @classmethod
     def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
         if cls._event_thread is None or cls._pid != os.getpid():
             cls._pid = os.getpid()
-            cls._event_thread = Thread(target=cls._run, daemon=True)
+            cls._event_loop_fut = Future()
+            cls._event_thread = Thread(target=cls._run_event_loop, daemon=True)
             cls._event_thread.start()
 
-        future = Future()
-        cls._task_queue.put((coro, future))
-
-        if return_future:
-            return future
-
-        result = future.result()
-        return result
+        loop = cls._event_loop_fut.result()
+        future = asyncio.run_coroutine_threadsafe(coro, loop)
+        return future if return_future else future.result()
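
The new implementation keeps a single long-lived event loop on a background thread and schedules work onto it with `asyncio.run_coroutine_threadsafe`. A standalone sketch of the same pattern, using plain asyncio instead of `switch_to_uvloop` (names are illustrative, not hivemind API):

```python
import asyncio
from concurrent.futures import Future
from threading import Thread

_loop_fut: Future = Future()


def _run_event_loop() -> None:
    # Create the loop inside the background thread and hand it to callers
    loop = asyncio.new_event_loop()
    _loop_fut.set_result(loop)
    loop.run_forever()


Thread(target=_run_event_loop, daemon=True).start()
loop = _loop_fut.result()  # wait until the background loop is ready


async def task(i: int) -> int:
    await asyncio.sleep(0.1)
    return i


# Coroutines submitted from any thread are scheduled onto the shared loop
# and run concurrently; each call returns a concurrent.futures.Future.
futures = [asyncio.run_coroutine_threadsafe(task(i), loop) for i in range(5)]
print([f.result() for f in futures])  # [0, 1, 2, 3, 4], after ~0.1 s
```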

+ 2 - 2
hivemind/p2p/servicer.py

@@ -18,7 +18,7 @@ class RPCHandler:
 
 class StubBase:
     """
-    Base class for P2P RPC stubs. The interface mimicks gRPC stubs.
+    Base class for P2P RPC stubs. The interface mimics gRPC stubs.
 
     Servicer derives stub classes for particular services (e.g. DHT, averager, etc.) from StubBase,
     adding the necessary rpc_* methods. Calls to these methods are translated to calls to the remote peer.
@@ -32,7 +32,7 @@ class StubBase:
 
 class ServicerBase:
     """
-    Base class for P2P RPC servicers (e.g. DHT, averager, MoE server). The interface mimicks gRPC servicers.
+    Base class for P2P RPC servicers (e.g. DHT, averager, MoE server). The interface mimics gRPC servicers.
 
     - ``add_p2p_handlers(self, p2p)`` registers all rpc_* methods of the derived class as P2P handlers, allowing
       other peers to call them. It uses type annotations for the ``request`` parameter and the return value

+ 48 - 1
tests/test_moe.py

@@ -1,3 +1,9 @@
+import asyncio
+import ctypes
+import multiprocessing as mp
+import threading
+import time
+
 import numpy as np
 import pytest
 import torch
@@ -5,12 +11,13 @@ import torch
 from hivemind.dht import DHT
 from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
 from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import ModuleBackend, Server, background_server, declare_experts
 from hivemind.moe.server.layers import name_to_block
 from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
-from hivemind.utils import BatchTensorDescriptor, get_dht_time
+from hivemind.utils import BatchTensorDescriptor, MPFuture, get_dht_time
 
 
 @pytest.mark.forked
@@ -306,3 +313,43 @@ def test_client_anomaly_detection():
 
     finally:
         server.shutdown()
+
+
+def _measure_coro_running_time(n_coros, elapsed_fut, counter):
+    async def coro():
+        await asyncio.sleep(0.1)
+        counter.value += 1
+
+    try:
+        start_time = time.perf_counter()
+
+        futures = [
+            RemoteExpertWorker.run_coroutine(coro(), return_future=True) for _ in range(n_coros - 1)
+        ]  # Non-blocking calls
+        RemoteExpertWorker.run_coroutine(coro(), return_future=False)  # A blocking call
+        for fut in futures:
+            fut.result()
+
+        elapsed_fut.set_result(time.perf_counter() - start_time)
+    except Exception as e:
+        elapsed_fut.set_exception(e)
+
+
+@pytest.mark.forked
+def test_remote_expert_worker_runs_coros_concurrently(n_processes=4, n_coros=10):
+    processes = []
+    counter = mp.Value(ctypes.c_int64)
+    for i in range(n_processes):
+        elapsed_fut = MPFuture()
+        factory = threading.Thread if i % 2 == 0 else mp.Process  # Test both threads and processes
+
+        proc = factory(target=_measure_coro_running_time, args=(n_coros, elapsed_fut, counter))
+        proc.start()
+        processes.append((proc, elapsed_fut))
+
+    for proc, elapsed_fut in processes:
+        # Ensure that the coroutines were run concurrently, not sequentially
+        assert elapsed_fut.result() < 0.2
+        proc.join()
+
+    assert counter.value == n_processes * n_coros  # Ensure all coroutines have finished