
Allow RemoteExpertWorker to run coroutines concurrently (#561)

Previously, `RemoteExpertWorker` ran one coroutine at a time, so hivemind.moe/Petals clients were very slow for concurrent calls.

(cherry picked from commit 589cb2c8b7c9d655a1250405672a0f4ab1d24f59)
Alexander Borzunov · 2 years ago
commit 542f5c3142
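To make the claim above concrete: with the old one-coroutine-at-a-time queue, ten concurrent calls that each wait 0.1 s on the network would take roughly 1 s in total; after this change they overlap on a shared event loop and finish in roughly 0.1 s. A minimal usage sketch (not part of the commit; fake_remote_call is an illustrative stand-in for a real RPC to a remote expert):

import asyncio
import time

from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker


async def fake_remote_call():
    await asyncio.sleep(0.1)  # stand-in for a network round-trip to a remote expert
    return "ok"


start = time.perf_counter()
# Submit ten coroutines without blocking, then wait for all of their results
futures = [RemoteExpertWorker.run_coroutine(fake_remote_call(), return_future=True) for _ in range(10)]
results = [fut.result() for fut in futures]
print(f"{len(results)} calls finished in {time.perf_counter() - start:.2f} s")  # ~0.1 s after this change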

+ 1 - 0
.github/workflows/check-style.yml

@@ -32,3 +32,4 @@ jobs:
       - uses: codespell-project/actions-codespell@v1
         with:
           only_warn: 1
+          ignore_words_list: ibrary,nd

+ 16 - 28
hivemind/moe/client/remote_expert_worker.py

@@ -1,6 +1,6 @@
+import asyncio
 import os
 from concurrent.futures import Future
-from queue import Queue
 from threading import Thread
 from typing import Awaitable, Optional

@@ -10,39 +10,27 @@ from hivemind.utils import switch_to_uvloop
 class RemoteExpertWorker:
     """Local thread for managing async tasks related to RemoteExpert"""

-    _task_queue: Queue = Queue()
-    _event_thread: Optional[Thread] = None
-    _pid: int = -1
+    _event_thread = None
+    _event_loop_fut = None
+    _pid = None

     @classmethod
-    def _run(cls):
-        loop = switch_to_uvloop()
-
-        async def receive_tasks():
-            while True:
-                cor, future = cls._task_queue.get()
-                try:
-                    result = await cor
-                except Exception as e:
-                    future.set_exception(e)
-                    continue
-                if not future.cancelled():
-                    future.set_result(result)
-
-        loop.run_until_complete(receive_tasks())
+    def _run_event_loop(cls):
+        try:
+            loop = switch_to_uvloop()
+            cls._event_loop_fut.set_result(loop)
+        except Exception as e:
+            cls._event_loop_fut.set_exception(e)
+        loop.run_forever()

     @classmethod
     def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
         if cls._event_thread is None or cls._pid != os.getpid():
             cls._pid = os.getpid()
-            cls._event_thread = Thread(target=cls._run, daemon=True)
+            cls._event_loop_fut = Future()
+            cls._event_thread = Thread(target=cls._run_event_loop, daemon=True)
             cls._event_thread.start()

-        future = Future()
-        cls._task_queue.put((coro, future))
-
-        if return_future:
-            return future
-
-        result = future.result()
-        return result
+        loop = cls._event_loop_fut.result()
+        future = asyncio.run_coroutine_threadsafe(coro, loop)
+        return future if return_future else future.result()
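The new implementation hosts a single event loop in a daemon thread (recreated per process, hence the os.getpid() check) and submits every coroutine to it with asyncio.run_coroutine_threadsafe, so independent calls make progress concurrently instead of being drained one at a time from a queue. A standalone sketch of that pattern, using illustrative names and plain asyncio in place of hivemind's switch_to_uvloop helper:

import asyncio
from concurrent.futures import Future
from threading import Thread

loop_fut: Future = Future()  # lets other threads wait until the loop exists

def _host_loop():
    loop = asyncio.new_event_loop()
    loop_fut.set_result(loop)
    loop.run_forever()  # keep serving submitted coroutines until the process exits

Thread(target=_host_loop, daemon=True).start()
loop = loop_fut.result()

async def task(i):
    await asyncio.sleep(0.1)
    return i

# run_coroutine_threadsafe schedules each coroutine on the background loop and
# returns a concurrent.futures.Future; all five sleeps overlap on that loop.
futures = [asyncio.run_coroutine_threadsafe(task(i), loop) for i in range(5)]
print([f.result() for f in futures])  # all results after ~0.1 s total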

+ 2 - 2
hivemind/p2p/servicer.py

@@ -18,7 +18,7 @@ class RPCHandler:

 class StubBase:
     """
-    Base class for P2P RPC stubs. The interface mimicks gRPC stubs.
+    Base class for P2P RPC stubs. The interface mimics gRPC stubs.

     Servicer derives stub classes for particular services (e.g. DHT, averager, etc.) from StubBase,
     adding the necessary rpc_* methods. Calls to these methods are translated to calls to the remote peer.
@@ -32,7 +32,7 @@ class StubBase:

 class ServicerBase:
     """
-    Base class for P2P RPC servicers (e.g. DHT, averager, MoE server). The interface mimicks gRPC servicers.
+    Base class for P2P RPC servicers (e.g. DHT, averager, MoE server). The interface mimics gRPC servicers.

     - ``add_p2p_handlers(self, p2p)`` registers all rpc_* methods of the derived class as P2P handlers, allowing
       other peers to call them. It uses type annotations for the ``request`` parameter and the return value

+ 48 - 1
tests/test_moe.py

@@ -1,3 +1,9 @@
+import asyncio
+import ctypes
+import multiprocessing as mp
+import threading
+import time
+
 import numpy as np
 import pytest
 import torch
@@ -5,12 +11,13 @@ import torch
 from hivemind.dht import DHT
 from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
 from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import ModuleBackend, Server, background_server, declare_experts
 from hivemind.moe.server.layers import name_to_block
 from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
-from hivemind.utils import BatchTensorDescriptor, get_dht_time
+from hivemind.utils import BatchTensorDescriptor, MPFuture, get_dht_time


 @pytest.mark.forked
@@ -306,3 +313,43 @@ def test_client_anomaly_detection():

     finally:
         server.shutdown()
+
+
+def _measure_coro_running_time(n_coros, elapsed_fut, counter):
+    async def coro():
+        await asyncio.sleep(0.1)
+        counter.value += 1
+
+    try:
+        start_time = time.perf_counter()
+
+        futures = [
+            RemoteExpertWorker.run_coroutine(coro(), return_future=True) for _ in range(n_coros - 1)
+        ]  # Non-blocking calls
+        RemoteExpertWorker.run_coroutine(coro(), return_future=False)  # A blocking call
+        for fut in futures:
+            fut.result()
+
+        elapsed_fut.set_result(time.perf_counter() - start_time)
+    except Exception as e:
+        elapsed_fut.set_exception(e)
+
+
+@pytest.mark.forked
+def test_remote_expert_worker_runs_coros_concurrently(n_processes=4, n_coros=10):
+    processes = []
+    counter = mp.Value(ctypes.c_int64)
+    for i in range(n_processes):
+        elapsed_fut = MPFuture()
+        factory = threading.Thread if i % 2 == 0 else mp.Process  # Test both threads and processes
+
+        proc = factory(target=_measure_coro_running_time, args=(n_coros, elapsed_fut, counter))
+        proc.start()
+        processes.append((proc, elapsed_fut))
+
+    for proc, elapsed_fut in processes:
+        # Ensure that the coroutines were run concurrently, not sequentially
+        assert elapsed_fut.result() < 0.2
+        proc.join()
+
+    assert counter.value == n_processes * n_coros  # Ensure all coroutines have finished