|
@@ -1,3 +1,9 @@
|
|
|
+import asyncio
|
|
|
+import ctypes
|
|
|
+import multiprocessing as mp
|
|
|
+import threading
|
|
|
+import time
|
|
|
+
|
|
|
import numpy as np
|
|
|
import pytest
|
|
|
import torch
|
|
@@ -5,12 +11,13 @@ import torch
|
|
|
from hivemind.dht import DHT
|
|
|
from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
|
|
|
from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
|
|
|
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
|
|
|
from hivemind.moe.expert_uid import ExpertInfo
|
|
|
from hivemind.moe.server import ModuleBackend, Server, background_server, declare_experts
|
|
|
from hivemind.moe.server.layers import name_to_block
|
|
|
from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
|
|
|
-from hivemind.utils import BatchTensorDescriptor, get_dht_time
|
|
|
+from hivemind.utils import BatchTensorDescriptor, MPFuture, get_dht_time
|
|
|
|
|
|
|
|
|
@pytest.mark.forked
|
|
@@ -306,3 +313,43 @@ def test_client_anomaly_detection():
|
|
|
|
|
|
finally:
|
|
|
server.shutdown()
|
|
|
+
|
|
|
+
|
|
|
+def _measure_coro_running_time(n_coros, elapsed_fut, counter):
|
|
|
+ async def coro():
|
|
|
+ await asyncio.sleep(0.1)
|
|
|
+ counter.value += 1
|
|
|
+
|
|
|
+ try:
|
|
|
+ start_time = time.perf_counter()
|
|
|
+
|
|
|
+ futures = [
|
|
|
+ RemoteExpertWorker.run_coroutine(coro(), return_future=True) for _ in range(n_coros - 1)
|
|
|
+ ] # Non-blocking calls
|
|
|
+ RemoteExpertWorker.run_coroutine(coro(), return_future=False) # A blocking call
|
|
|
+ for fut in futures:
|
|
|
+ fut.result()
|
|
|
+
|
|
|
+ elapsed_fut.set_result(time.perf_counter() - start_time)
|
|
|
+ except Exception as e:
|
|
|
+ elapsed_fut.set_exception(e)
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.forked
|
|
|
+def test_remote_expert_worker_runs_coros_concurrently(n_processes=4, n_coros=10):
|
|
|
+ processes = []
|
|
|
+ counter = mp.Value(ctypes.c_int64)
|
|
|
+ for i in range(n_processes):
|
|
|
+ elapsed_fut = MPFuture()
|
|
|
+ factory = threading.Thread if i % 2 == 0 else mp.Process # Test both threads and processes
|
|
|
+
|
|
|
+ proc = factory(target=_measure_coro_running_time, args=(n_coros, elapsed_fut, counter))
|
|
|
+ proc.start()
|
|
|
+ processes.append((proc, elapsed_fut))
|
|
|
+
|
|
|
+ for proc, elapsed_fut in processes:
|
|
|
+ # Ensure that the coroutines were run concurrently, not sequentially
|
|
|
+ assert elapsed_fut.result() < 0.2
|
|
|
+ proc.join()
|
|
|
+
|
|
|
+ assert counter.value == n_processes * n_coros # Ensure all couroutines have finished
|