
Move RemoteExpertWorker into a separate file and make it private

Pavel Samygin, 3 years ago
commit 3df384abc2
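
In short: the public classmethods RemoteExpertWorker.spawn_experts and RemoteExpertWorker.batch_spawn_experts become the module-level helpers create_remote_experts and batch_create_remote_experts in hivemind.moe.client.expert, and the event-loop worker itself moves to hivemind/moe/client/remote_expert_worker.py as the private _RemoteExpertWorker. A minimal caller-side migration sketch, where the expert uid, server_peer_info, and dht are placeholders taken from the test changes below:

    # before this commit
    from hivemind.moe.client.expert import RemoteExpertInfo, RemoteExpertWorker
    experts = RemoteExpertWorker.spawn_experts(
        [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)], dht
    )

    # after this commit
    from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
    experts = create_remote_experts(
        [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)], dht
    )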

+ 5 - 5
benchmarks/benchmark_throughput.py

@@ -7,7 +7,8 @@ import time
 import torch
 
 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo
+from hivemind.moe.client.remote_expert_worker import _RemoteExpertWorker
 from hivemind.moe.server import ExpertBackend, Server
 from hivemind.moe.server.layers import name_to_block
 from hivemind.p2p import P2P, PeerInfo
@@ -46,11 +47,10 @@ def client_process(
     torch.set_num_threads(1)
     can_start.wait()
 
-    p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
+    p2p = _RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
+    peer_info = PeerInfo(server_peer_id, server_maddrs)
     experts = [
-        RemoteExpert(
-            expert_info=RemoteExpertInfo(uid=f"expert.{i}", peer_info=PeerInfo(server_peer_id, server_maddrs)), p2p=p2p
-        )
+        RemoteExpert(expert_info=RemoteExpertInfo(uid=f"expert.{i}", peer_info=peer_info), p2p=p2p)
         for i in range(num_experts)
     ]
 

+ 1 - 1
hivemind/moe/client/__init__.py

@@ -1,3 +1,3 @@
-from hivemind.moe.client.expert import RemoteExpert
+from hivemind.moe.client.expert import RemoteExpert, batch_create_remote_experts, create_remote_experts
 from hivemind.moe.client.moe import RemoteMixtureOfExperts
 from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
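
Since hivemind/moe/client/__init__.py now re-exports the new helpers, the shorter import path also works:

    from hivemind.moe.client import batch_create_remote_experts, create_remote_experts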

+ 8 - 3
hivemind/moe/client/beam_search.py

@@ -5,7 +5,12 @@ from functools import partial
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
+from hivemind.moe.client.expert import (
+    RemoteExpert,
+    RemoteExpertInfo,
+    batch_create_remote_experts,
+    create_remote_experts,
+)
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     PREFIX_PATTERN,
@@ -259,7 +264,7 @@ class MoEBeamSearcher:
             return_future,
         )
 
-        return RemoteExpertWorker.spawn_experts(result, self.dht, return_future)
+        return create_remote_experts(result, self.dht, return_future)
 
     @classmethod
     async def _find_best_experts(
@@ -382,7 +387,7 @@ class MoEBeamSearcher:
             return_future,
         )
 
-        return RemoteExpertWorker.batch_spawn_experts(result, self.dht, return_future)
+        return batch_create_remote_experts(result, self.dht, return_future)
 
     @classmethod
     async def _batch_find_best_experts(

+ 34 - 79
hivemind/moe/client/expert.py

@@ -1,11 +1,8 @@
 from __future__ import annotations
 
-import os
 from concurrent.futures import Future
 from dataclasses import dataclass
-from queue import Queue
-from threading import Thread
-from typing import Any, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -14,6 +11,7 @@ from torch.autograd.function import once_differentiable
 from hivemind import moe
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
+from hivemind.moe.client.remote_expert_worker import _RemoteExpertWorker
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
@@ -24,7 +22,6 @@ from hivemind.utils import (
     nested_compare,
     nested_flatten,
     nested_pack,
-    switch_to_uvloop,
 )
 from hivemind.utils.mpfuture import MPFuture
 from hivemind.utils.streaming import gather_from_streaming, split_for_streaming
@@ -88,7 +85,7 @@ class RemoteExpert(nn.Module):
     @property
     def info(self):
         if self._rpc_info is None:
-            outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
+            outputs = _RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
             self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
         return self._rpc_info
 
@@ -96,87 +93,45 @@ class RemoteExpert(nn.Module):
         return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
 
 
-class RemoteExpertWorker:
-    """Local thread for managing async tasks related to RemoteExpert"""
+def _create_remote_experts(infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
+    experts: List[Optional[RemoteExpert]] = []
+    for i in infos:
+        if i is not None:
+            experts.append(RemoteExpert(i, p2p))
+        else:
+            experts.append(None)
+    return experts
 
-    _task_queue: Queue = Queue()
-    _event_thread: Optional[Thread] = None
-    _pid: int = -1
 
-    @classmethod
-    def _run(cls):
-        loop = switch_to_uvloop()
+def create_remote_experts(
+    infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
+) -> Union[List[Optional[RemoteExpert]], Future]:
+    if return_future:
 
-        async def receive_tasks():
-            while True:
-                cor, future = cls._task_queue.get()
-                try:
-                    result = await cor
-                except Exception as e:
-                    future.set_exception(e)
-                    continue
-                if not future.cancelled():
-                    future.set_result(result)
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return _create_remote_experts(await infos_future, p2p)
 
-        loop.run_until_complete(receive_tasks())
+        return _RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
 
-    @classmethod
-    def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
-        if cls._event_thread is None or cls._pid != os.getpid():
-            cls._pid = os.getpid()
-            cls._event_thread = Thread(target=cls._run, daemon=True)
-            cls._event_thread.start()
+    p2p = _RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+    return _create_remote_experts(infos, p2p)
 
-        future = Future()
-        cls._task_queue.put((coro, future))
 
-        if return_future:
-            return future
+def batch_create_remote_experts(
+    infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
+    dht: DHT,
+    return_future: bool = False,
+) -> Union[List[List[Optional[RemoteExpert]]], Future]:
+    if return_future:
 
-        result = future.result()
-        return result
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return [_create_remote_experts(i, p2p) for i in await infos_future]
 
-    @classmethod
-    def _spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
-        experts: List[Optional[RemoteExpert]] = []
-        for i in infos:
-            if i is not None:
-                experts.append(RemoteExpert(i, p2p))
-            else:
-                experts.append(None)
-        return experts
+        return _RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
 
-    @classmethod
-    def spawn_experts(
-        cls, infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
-    ) -> Union[List[Optional[RemoteExpert]], Future]:
-        if return_future:
-
-            async def _unpack(infos_future: MPFuture, dht: DHT):
-                p2p = await dht.replicate_p2p()
-                return cls._spawn_experts(await infos_future, p2p)
-
-            return cls.run_coroutine(_unpack(infos, dht), return_future)
-
-        p2p = cls.run_coroutine(dht.replicate_p2p())
-        return cls._spawn_experts(infos, p2p)
-
-    @classmethod
-    def batch_spawn_experts(
-        cls,
-        infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
-        dht: DHT,
-        return_future: bool = False,
-    ) -> Union[List[List[Optional[RemoteExpert]]], Future]:
-        if return_future:
-
-            async def _unpack(infos_future: MPFuture, dht: DHT):
-                p2p = await dht.replicate_p2p()
-                return [cls._spawn_experts(i, p2p) for i in await infos_future]
-
-            return cls.run_coroutine(_unpack(infos, dht), return_future)
-
-        return [cls.spawn_experts(exps, dht) for exps in infos]
+    return [create_remote_experts(exps, dht) for exps in infos]
 
 
 async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
@@ -266,7 +221,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
 
-        deserialized_outputs = RemoteExpertWorker.run_coroutine(
+        deserialized_outputs = _RemoteExpertWorker.run_coroutine(
             expert_forward(uid, inputs, (p.compression for p in nested_flatten(info["forward_schema"])), stub)
         )
 
@@ -279,7 +234,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
 
-        deserialized_grad_inputs = RemoteExpertWorker.run_coroutine(
+        deserialized_grad_inputs = _RemoteExpertWorker.run_coroutine(
             expert_backward(ctx.uid, inputs_and_grad_outputs, (p.compression for p in backward_schema), ctx.stub)
         )
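
Both replacement helpers keep the old calling convention: they take a sequence of RemoteExpertInfo (or an MPFuture resolving to one) plus a DHT, and with return_future=True they hand back a concurrent.futures.Future instead of blocking. A sketch of the three call shapes, where infos and infos_future are placeholders:

    # synchronous: returns List[Optional[RemoteExpert]]
    experts = create_remote_experts(infos, dht)

    # deferred: returns a concurrent.futures.Future that resolves to the same list
    experts_future = create_remote_experts(infos_future, dht, return_future=True)
    experts = experts_future.result()

    # batched counterpart of the old batch_spawn_experts
    expert_grid = batch_create_remote_experts([infos, infos], dht)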
 

+ 4 - 3
hivemind/moe/client/moe.py

@@ -14,11 +14,12 @@ from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.expert import (
     DUMMY,
     RemoteExpert,
-    RemoteExpertWorker,
     _get_expert_stub,
     expert_backward,
     expert_forward,
 )
+
+from hivemind.moe.client.remote_expert_worker import _RemoteExpertWorker
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_map, nested_pack
@@ -234,7 +235,7 @@ class _RemoteCallMany(torch.autograd.Function):
             for j, expert in enumerate(experts_per_sample[i]):
                 compressions = (p.compression for p in nested_flatten(info["forward_schema"]))
                 stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
-                new_task = RemoteExpertWorker.run_coroutine(
+                new_task = _RemoteExpertWorker.run_coroutine(
                     expert_forward(expert.uid, flat_inputs_per_sample[i], compressions, stub),
                     return_future=True,
                 )
@@ -326,7 +327,7 @@ class _RemoteCallMany(torch.autograd.Function):
             stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
             compressions = (p.compression for p in backward_schema)
-            new_task = RemoteExpertWorker.run_coroutine(
+            new_task = _RemoteExpertWorker.run_coroutine(
                 expert_backward(expert.uid, inputs_and_grad_outputs, compressions, stub), return_future=True
             )
             pending_tasks[new_task] = (i, j)

+ 48 - 0
hivemind/moe/client/remote_expert_worker.py

@@ -0,0 +1,48 @@
+import os
+from concurrent.futures import Future
+from queue import Queue
+from threading import Thread
+from typing import Awaitable, Optional
+
+from hivemind.utils import switch_to_uvloop
+
+
+class _RemoteExpertWorker:
+    """Local thread for managing async tasks related to RemoteExpert"""
+
+    _task_queue: Queue = Queue()
+    _event_thread: Optional[Thread] = None
+    _pid: int = -1
+
+    @classmethod
+    def _run(cls):
+        loop = switch_to_uvloop()
+
+        async def receive_tasks():
+            while True:
+                cor, future = cls._task_queue.get()
+                try:
+                    result = await cor
+                except Exception as e:
+                    future.set_exception(e)
+                    continue
+                if not future.cancelled():
+                    future.set_result(result)
+
+        loop.run_until_complete(receive_tasks())
+
+    @classmethod
+    def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
+        if cls._event_thread is None or cls._pid != os.getpid():
+            cls._pid = os.getpid()
+            cls._event_thread = Thread(target=cls._run, daemon=True)
+            cls._event_thread.start()
+
+        future = Future()
+        cls._task_queue.put((coro, future))
+
+        if return_future:
+            return future
+
+        result = future.result()
+        return result
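
The new module is a near-verbatim move of the old worker: a lazily started, process-local daemon thread that owns a uvloop event loop and receives coroutines through a queue; run_coroutine either blocks on the result or returns a concurrent.futures.Future. A small usage sketch (the coroutine is a placeholder; external code is expected to go through create_remote_experts rather than this private class):

    import asyncio

    from hivemind.moe.client.remote_expert_worker import _RemoteExpertWorker

    async def _ping():
        await asyncio.sleep(0.1)
        return "pong"

    assert _RemoteExpertWorker.run_coroutine(_ping()) == "pong"  # blocks until the coroutine finishes
    fut = _RemoteExpertWorker.run_coroutine(_ping(), return_future=True)  # schedules it and returns a Future
    assert fut.result() == "pong"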

+ 2 - 2
hivemind/moe/server/dht_handler.py

@@ -3,7 +3,7 @@ from functools import partial
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, create_remote_experts
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     UID_DELIMITER,
@@ -84,7 +84,7 @@ def get_experts(
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
-    return RemoteExpertWorker.spawn_experts(result, dht, return_future)
+    return create_remote_experts(result, dht, return_future)
 
 
 async def _get_experts(
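
get_experts keeps its signature; only the final wrapping step changes. A hedged usage sketch, with placeholder uids and assuming expiration_time keeps its default as in the released API:

    from hivemind.moe.server.dht_handler import get_experts

    experts = get_experts(dht, ["expert.0", "expert.1"])  # List[Optional[RemoteExpert]]
    experts_future = get_experts(dht, ["expert.0"], return_future=True)  # Future instead of blocking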

+ 3 - 4
tests/test_custom_experts.py

@@ -3,9 +3,8 @@ import os
 import pytest
 import torch
 
-from hivemind import RemoteExpert
 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpertInfo, RemoteExpertWorker
+from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
 from hivemind.moe.server import background_server
 
 CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
@@ -22,7 +21,7 @@ def test_custom_expert(hid_dim=16):
         custom_module_path=CUSTOM_EXPERTS_PATH,
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
-        expert0, expert1 = RemoteExpertWorker.spawn_experts(
+        expert0, expert1 = create_remote_experts(
             [
                 RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
@@ -53,7 +52,7 @@ def test_multihead_expert(hid_dim=16):
         custom_module_path=CUSTOM_EXPERTS_PATH,
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
-        expert0, expert1 = RemoteExpertWorker.spawn_experts(
+        expert0, expert1 = create_remote_experts(
             [
                 RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),

+ 4 - 6
tests/test_moe.py

@@ -1,11 +1,9 @@
-import time
-
 import numpy as np
 import pytest
 import torch
 
 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, create_remote_experts
 from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
@@ -78,7 +76,7 @@ def test_call_many(hidden_dim=16):
         inputs_clone = inputs.clone().detach().requires_grad_(True)
 
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
-        e0, e1, e2, e3, e4 = RemoteExpertWorker.spawn_experts(
+        e0, e1, e2, e3, e4 = create_remote_experts(
             [RemoteExpertInfo(uid=f"expert.{i}", peer_info=server_peer_info) for i in range(5)],
             dht,
         )
@@ -137,7 +135,7 @@ def test_remote_module_call(hidden_dim=16):
         optim_cls=None,
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
-        real_expert, fake_expert = RemoteExpertWorker.spawn_experts(
+        real_expert, fake_expert = create_remote_experts(
             [
                 RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="oiasfjiasjf", peer_info=server_peer_info),
@@ -206,7 +204,7 @@ def test_determinism(hidden_dim=16):
         optim_cls=None,
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
-        expert = RemoteExpertWorker.spawn_experts(
+        expert = create_remote_experts(
             [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)],
             dht=dht,
         )[0]

+ 2 - 2
tests/test_training.py

@@ -9,7 +9,7 @@ from sklearn.datasets import load_digits
 
 from hivemind import DHT
 from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.client.expert import RemoteExpertInfo, RemoteExpertWorker
+from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
 from hivemind.moe.server import background_server
 from hivemind.optim import DecentralizedAdam, DecentralizedSGD
 
@@ -24,7 +24,7 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
         num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
-        expert1, expert2 = RemoteExpertWorker.spawn_experts(
+        expert1, expert2 = create_remote_experts(
             [
                 RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),