Przeglądaj źródła

RemoteExpertWorker into sep file and make it private

Pavel Samygin 3 lat temu
rodzic
commit
3df384abc2

+ 5 - 5
benchmarks/benchmark_throughput.py

@@ -7,7 +7,8 @@ import time
 import torch
 import torch
 
 
 from hivemind.dht import DHT
 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo
+from hivemind.moe.client.remote_expert_worker import _RemoteExpertWorker
 from hivemind.moe.server import ExpertBackend, Server
 from hivemind.moe.server import ExpertBackend, Server
 from hivemind.moe.server.layers import name_to_block
 from hivemind.moe.server.layers import name_to_block
 from hivemind.p2p import P2P, PeerInfo
 from hivemind.p2p import P2P, PeerInfo
@@ -46,11 +47,10 @@ def client_process(
     torch.set_num_threads(1)
     torch.set_num_threads(1)
     can_start.wait()
     can_start.wait()
 
 
-    p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
+    p2p = _RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
+    peer_info = PeerInfo(server_peer_id, server_maddrs)
     experts = [
     experts = [
-        RemoteExpert(
-            expert_info=RemoteExpertInfo(uid=f"expert.{i}", peer_info=PeerInfo(server_peer_id, server_maddrs)), p2p=p2p
-        )
+        RemoteExpert(expert_info=RemoteExpertInfo(uid=f"expert.{i}", peer_info=peer_info), p2p=p2p)
         for i in range(num_experts)
         for i in range(num_experts)
     ]
     ]
 
 

+ 1 - 1
hivemind/moe/client/__init__.py

@@ -1,3 +1,3 @@
-from hivemind.moe.client.expert import RemoteExpert
+from hivemind.moe.client.expert import RemoteExpert, batch_create_remote_experts, create_remote_experts
 from hivemind.moe.client.moe import RemoteMixtureOfExperts
 from hivemind.moe.client.moe import RemoteMixtureOfExperts
 from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts

+ 8 - 3
hivemind/moe/client/beam_search.py

@@ -5,7 +5,12 @@ from functools import partial
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode
 from hivemind.dht import DHT, DHTExpiration, DHTNode
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
+from hivemind.moe.client.expert import (
+    RemoteExpert,
+    RemoteExpertInfo,
+    batch_create_remote_experts,
+    create_remote_experts,
+)
 from hivemind.moe.server.expert_uid import (
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     FLAT_EXPERT,
     PREFIX_PATTERN,
     PREFIX_PATTERN,
@@ -259,7 +264,7 @@ class MoEBeamSearcher:
             return_future,
             return_future,
         )
         )
 
 
-        return RemoteExpertWorker.spawn_experts(result, self.dht, return_future)
+        return create_remote_experts(result, self.dht, return_future)
 
 
     @classmethod
     @classmethod
     async def _find_best_experts(
     async def _find_best_experts(
@@ -382,7 +387,7 @@ class MoEBeamSearcher:
             return_future,
             return_future,
         )
         )
 
 
-        return RemoteExpertWorker.batch_spawn_experts(result, self.dht, return_future)
+        return batch_create_remote_experts(result, self.dht, return_future)
 
 
     @classmethod
     @classmethod
     async def _batch_find_best_experts(
     async def _batch_find_best_experts(

+ 34 - 79
hivemind/moe/client/expert.py

@@ -1,11 +1,8 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
-import os
 from concurrent.futures import Future
 from concurrent.futures import Future
 from dataclasses import dataclass
 from dataclasses import dataclass
-from queue import Queue
-from threading import Thread
-from typing import Any, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 
 import torch
 import torch
 import torch.nn as nn
 import torch.nn as nn
@@ -14,6 +11,7 @@ from torch.autograd.function import once_differentiable
 from hivemind import moe
 from hivemind import moe
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.dht import DHT
+from hivemind.moe.client.remote_expert_worker import _RemoteExpertWorker
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.proto import runtime_pb2
@@ -24,7 +22,6 @@ from hivemind.utils import (
     nested_compare,
     nested_compare,
     nested_flatten,
     nested_flatten,
     nested_pack,
     nested_pack,
-    switch_to_uvloop,
 )
 )
 from hivemind.utils.mpfuture import MPFuture
 from hivemind.utils.mpfuture import MPFuture
 from hivemind.utils.streaming import gather_from_streaming, split_for_streaming
 from hivemind.utils.streaming import gather_from_streaming, split_for_streaming
@@ -88,7 +85,7 @@ class RemoteExpert(nn.Module):
     @property
     @property
     def info(self):
     def info(self):
         if self._rpc_info is None:
         if self._rpc_info is None:
-            outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
+            outputs = _RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
             self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
             self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
         return self._rpc_info
         return self._rpc_info
 
 
@@ -96,87 +93,45 @@ class RemoteExpert(nn.Module):
         return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
         return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
 
 
 
 
-class RemoteExpertWorker:
-    """Local thread for managing async tasks related to RemoteExpert"""
+def _create_remote_experts(infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
+    experts: List[Optional[RemoteExpert]] = []
+    for i in infos:
+        if i is not None:
+            experts.append(RemoteExpert(i, p2p))
+        else:
+            experts.append(None)
+    return experts
 
 
-    _task_queue: Queue = Queue()
-    _event_thread: Optional[Thread] = None
-    _pid: int = -1
 
 
-    @classmethod
-    def _run(cls):
-        loop = switch_to_uvloop()
+def create_remote_experts(
+    infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
+) -> Union[List[Optional[RemoteExpert]], Future]:
+    if return_future:
 
 
-        async def receive_tasks():
-            while True:
-                cor, future = cls._task_queue.get()
-                try:
-                    result = await cor
-                except Exception as e:
-                    future.set_exception(e)
-                    continue
-                if not future.cancelled():
-                    future.set_result(result)
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return _create_remote_experts(await infos_future, p2p)
 
 
-        loop.run_until_complete(receive_tasks())
+        return _RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
 
 
-    @classmethod
-    def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
-        if cls._event_thread is None or cls._pid != os.getpid():
-            cls._pid = os.getpid()
-            cls._event_thread = Thread(target=cls._run, daemon=True)
-            cls._event_thread.start()
+    p2p = _RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+    return _create_remote_experts(infos, p2p)
 
 
-        future = Future()
-        cls._task_queue.put((coro, future))
 
 
-        if return_future:
-            return future
+def batch_create_remote_experts(
+    infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
+    dht: DHT,
+    return_future: bool = False,
+) -> Union[List[List[Optional[RemoteExpert]]], Future]:
+    if return_future:
 
 
-        result = future.result()
-        return result
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return [_create_remote_experts(i, p2p) for i in await infos_future]
 
 
-    @classmethod
-    def _spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
-        experts: List[Optional[RemoteExpert]] = []
-        for i in infos:
-            if i is not None:
-                experts.append(RemoteExpert(i, p2p))
-            else:
-                experts.append(None)
-        return experts
+        return _RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
 
 
-    @classmethod
-    def spawn_experts(
-        cls, infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
-    ) -> Union[List[Optional[RemoteExpert]], Future]:
-        if return_future:
-
-            async def _unpack(infos_future: MPFuture, dht: DHT):
-                p2p = await dht.replicate_p2p()
-                return cls._spawn_experts(await infos_future, p2p)
-
-            return cls.run_coroutine(_unpack(infos, dht), return_future)
-
-        p2p = cls.run_coroutine(dht.replicate_p2p())
-        return cls._spawn_experts(infos, p2p)
-
-    @classmethod
-    def batch_spawn_experts(
-        cls,
-        infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
-        dht: DHT,
-        return_future: bool = False,
-    ) -> Union[List[List[Optional[RemoteExpert]]], Future]:
-        if return_future:
-
-            async def _unpack(infos_future: MPFuture, dht: DHT):
-                p2p = await dht.replicate_p2p()
-                return [cls._spawn_experts(i, p2p) for i in await infos_future]
-
-            return cls.run_coroutine(_unpack(infos, dht), return_future)
-
-        return [cls.spawn_experts(exps, dht) for exps in infos]
+    return [create_remote_experts(exps, dht) for exps in infos]
 
 
 
 
 async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
 async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
@@ -266,7 +221,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
         ctx.save_for_backward(*inputs)
 
 
-        deserialized_outputs = RemoteExpertWorker.run_coroutine(
+        deserialized_outputs = _RemoteExpertWorker.run_coroutine(
             expert_forward(uid, inputs, (p.compression for p in nested_flatten(info["forward_schema"])), stub)
             expert_forward(uid, inputs, (p.compression for p in nested_flatten(info["forward_schema"])), stub)
         )
         )
 
 
@@ -279,7 +234,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
 
 
-        deserialized_grad_inputs = RemoteExpertWorker.run_coroutine(
+        deserialized_grad_inputs = _RemoteExpertWorker.run_coroutine(
             expert_backward(ctx.uid, inputs_and_grad_outputs, (p.compression for p in backward_schema), ctx.stub)
             expert_backward(ctx.uid, inputs_and_grad_outputs, (p.compression for p in backward_schema), ctx.stub)
         )
         )
 
 

+ 4 - 3
hivemind/moe/client/moe.py

@@ -14,11 +14,12 @@ from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.expert import (
 from hivemind.moe.client.expert import (
     DUMMY,
     DUMMY,
     RemoteExpert,
     RemoteExpert,
-    RemoteExpertWorker,
     _get_expert_stub,
     _get_expert_stub,
     expert_backward,
     expert_backward,
     expert_forward,
     expert_forward,
 )
 )
+
+from hivemind.moe.client.remote_expert_worker import _RemoteExpertWorker
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_map, nested_pack
 from hivemind.utils import nested_flatten, nested_map, nested_pack
@@ -234,7 +235,7 @@ class _RemoteCallMany(torch.autograd.Function):
             for j, expert in enumerate(experts_per_sample[i]):
             for j, expert in enumerate(experts_per_sample[i]):
                 compressions = (p.compression for p in nested_flatten(info["forward_schema"]))
                 compressions = (p.compression for p in nested_flatten(info["forward_schema"]))
                 stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
                 stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
-                new_task = RemoteExpertWorker.run_coroutine(
+                new_task = _RemoteExpertWorker.run_coroutine(
                     expert_forward(expert.uid, flat_inputs_per_sample[i], compressions, stub),
                     expert_forward(expert.uid, flat_inputs_per_sample[i], compressions, stub),
                     return_future=True,
                     return_future=True,
                 )
                 )
@@ -326,7 +327,7 @@ class _RemoteCallMany(torch.autograd.Function):
             stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
             stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
             compressions = (p.compression for p in backward_schema)
             compressions = (p.compression for p in backward_schema)
-            new_task = RemoteExpertWorker.run_coroutine(
+            new_task = _RemoteExpertWorker.run_coroutine(
                 expert_backward(expert.uid, inputs_and_grad_outputs, compressions, stub), return_future=True
                 expert_backward(expert.uid, inputs_and_grad_outputs, compressions, stub), return_future=True
             )
             )
             pending_tasks[new_task] = (i, j)
             pending_tasks[new_task] = (i, j)

+ 48 - 0
hivemind/moe/client/remote_expert_worker.py

@@ -0,0 +1,48 @@
+import os
+from concurrent.futures import Future
+from queue import Queue
+from threading import Thread
+from typing import Awaitable, Optional
+
+from hivemind.utils import switch_to_uvloop
+
+
+class _RemoteExpertWorker:
+    """Local thread for managing async tasks related to RemoteExpert"""
+
+    _task_queue: Queue = Queue()
+    _event_thread: Optional[Thread] = None
+    _pid: int = -1
+
+    @classmethod
+    def _run(cls):
+        loop = switch_to_uvloop()
+
+        async def receive_tasks():
+            while True:
+                cor, future = cls._task_queue.get()
+                try:
+                    result = await cor
+                except Exception as e:
+                    future.set_exception(e)
+                    continue
+                if not future.cancelled():
+                    future.set_result(result)
+
+        loop.run_until_complete(receive_tasks())
+
+    @classmethod
+    def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
+        if cls._event_thread is None or cls._pid != os.getpid():
+            cls._pid = os.getpid()
+            cls._event_thread = Thread(target=cls._run, daemon=True)
+            cls._event_thread.start()
+
+        future = Future()
+        cls._task_queue.put((coro, future))
+
+        if return_future:
+            return future
+
+        result = future.result()
+        return result

+ 2 - 2
hivemind/moe/server/dht_handler.py

@@ -3,7 +3,7 @@ from functools import partial
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, create_remote_experts
 from hivemind.moe.server.expert_uid import (
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     FLAT_EXPERT,
     UID_DELIMITER,
     UID_DELIMITER,
@@ -84,7 +84,7 @@ def get_experts(
     """
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
     result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
-    return RemoteExpertWorker.spawn_experts(result, dht, return_future)
+    return create_remote_experts(result, dht, return_future)
 
 
 
 
 async def _get_experts(
 async def _get_experts(

+ 3 - 4
tests/test_custom_experts.py

@@ -3,9 +3,8 @@ import os
 import pytest
 import pytest
 import torch
 import torch
 
 
-from hivemind import RemoteExpert
 from hivemind.dht import DHT
 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpertInfo, RemoteExpertWorker
+from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
 from hivemind.moe.server import background_server
 from hivemind.moe.server import background_server
 
 
 CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
 CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
@@ -22,7 +21,7 @@ def test_custom_expert(hid_dim=16):
         custom_module_path=CUSTOM_EXPERTS_PATH,
         custom_module_path=CUSTOM_EXPERTS_PATH,
     ) as server_peer_info:
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
-        expert0, expert1 = RemoteExpertWorker.spawn_experts(
+        expert0, expert1 = create_remote_experts(
             [
             [
                 RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
@@ -53,7 +52,7 @@ def test_multihead_expert(hid_dim=16):
         custom_module_path=CUSTOM_EXPERTS_PATH,
         custom_module_path=CUSTOM_EXPERTS_PATH,
     ) as server_peer_info:
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
-        expert0, expert1 = RemoteExpertWorker.spawn_experts(
+        expert0, expert1 = create_remote_experts(
             [
             [
                 RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),

+ 4 - 6
tests/test_moe.py

@@ -1,11 +1,9 @@
-import time
-
 import numpy as np
 import numpy as np
 import pytest
 import pytest
 import torch
 import torch
 
 
 from hivemind.dht import DHT
 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, create_remote_experts
 from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
 from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
@@ -78,7 +76,7 @@ def test_call_many(hidden_dim=16):
         inputs_clone = inputs.clone().detach().requires_grad_(True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
 
 
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
-        e0, e1, e2, e3, e4 = RemoteExpertWorker.spawn_experts(
+        e0, e1, e2, e3, e4 = create_remote_experts(
             [RemoteExpertInfo(uid=f"expert.{i}", peer_info=server_peer_info) for i in range(5)],
             [RemoteExpertInfo(uid=f"expert.{i}", peer_info=server_peer_info) for i in range(5)],
             dht,
             dht,
         )
         )
@@ -137,7 +135,7 @@ def test_remote_module_call(hidden_dim=16):
         optim_cls=None,
         optim_cls=None,
     ) as server_peer_info:
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
-        real_expert, fake_expert = RemoteExpertWorker.spawn_experts(
+        real_expert, fake_expert = create_remote_experts(
             [
             [
                 RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="oiasfjiasjf", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="oiasfjiasjf", peer_info=server_peer_info),
@@ -206,7 +204,7 @@ def test_determinism(hidden_dim=16):
         optim_cls=None,
         optim_cls=None,
     ) as server_peer_info:
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
-        expert = RemoteExpertWorker.spawn_experts(
+        expert = create_remote_experts(
             [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)],
             [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)],
             dht=dht,
             dht=dht,
         )[0]
         )[0]

+ 2 - 2
tests/test_training.py

@@ -9,7 +9,7 @@ from sklearn.datasets import load_digits
 
 
 from hivemind import DHT
 from hivemind import DHT
 from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
 from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.client.expert import RemoteExpertInfo, RemoteExpertWorker
+from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
 from hivemind.moe.server import background_server
 from hivemind.moe.server import background_server
 from hivemind.optim import DecentralizedAdam, DecentralizedSGD
 from hivemind.optim import DecentralizedAdam, DecentralizedSGD
 
 
@@ -24,7 +24,7 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
         num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
         num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
     ) as server_peer_info:
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
-        expert1, expert2 = RemoteExpertWorker.spawn_experts(
+        expert1, expert2 = create_remote_experts(
             [
             [
                 RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
                 RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),