
fix moe and tests

Pavel Samygin 3 years ago
parent
commit
5ba4f72dba

+ 1 - 1
hivemind/compression/base.py

@@ -80,7 +80,7 @@ class NoCompression(CompressionBase):
     compression_type = runtime_pb2.CompressionType.NONE
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        array = tensor.numpy()
+        array = tensor.detach().numpy()
         return runtime_pb2.Tensor(
             compression=self.compression_type,
             buffer=array.tobytes(),

+ 0 - 1
hivemind/dht/dht.py

@@ -310,7 +310,6 @@ class DHT(mp.Process):
         Get a replica of a P2P instance used in the DHT process internally.
         The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
         """
-
         if self._p2p_replica is None or self._origin_pid != os.getpid():
             self._origin_pid = os.getpid()
             daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
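
For context, the method above (DHT.replicate_p2p) hands out a P2P replica that shares the DHT process's daemon and is cached per os.getpid(), so a forked process transparently re-creates its own copy. A minimal usage sketch, assuming a locally started DHT and driving the coroutine through RemoteExpertWorker the same way the MoE client code does:

    from hivemind.dht import DHT
    from hivemind.moe.client.expert import RemoteExpertWorker

    dht = DHT(start=True)
    # replicate_p2p() is a coroutine; the worker thread runs it and blocks until the replica is ready
    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())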

+ 1 - 1
hivemind/moe/client/beam_search.py

@@ -385,7 +385,7 @@ class MoEBeamSearcher:
             ),
             return_future,
         )
-
+        print(result)
         if return_future:
             return RemoteExpertWorker.spawn_experts_bulk_future(result, self.dht)
         return RemoteExpertWorker.spawn_experts_bulk(result, self.dht)

+ 80 - 82
hivemind/moe/client/expert.py

@@ -3,7 +3,7 @@ from concurrent.futures import Future
 from dataclasses import dataclass
 from queue import Queue
 from threading import Thread
-from typing import Any, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Any, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple
 
 import torch
 import torch.nn as nn
@@ -99,7 +99,7 @@ class RemoteExpertWorker:
 
     _task_queue: Queue = Queue()
     _event_thread: Optional[Thread] = None
-    _pid: int = 0
+    _pid: int = -1
 
     @classmethod
     def _run(cls):
@@ -113,7 +113,8 @@ class RemoteExpertWorker:
                 except Exception as e:
                     future.set_exception(e)
                     continue
-                future.set_result(result)
+                if not future.cancelled():
+                    future.set_result(result)
 
         loop.run_until_complete(receive_tasks())
 
@@ -151,7 +152,7 @@ class RemoteExpertWorker:
     @classmethod
     def spawn_experts_future(
         cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
-    ) -> MPFuture[List[Optional[RemoteExpert]]]:
+    ) -> Future[List[Optional[RemoteExpert]]]:
         async def _unpack():
             p2p = cls.run_coroutine(dht.replicate_p2p(), True)
             return cls.spawn_experts(await infos, await p2p)
@@ -166,7 +167,7 @@ class RemoteExpertWorker:
 
     @classmethod
     def spawn_experts_bulk_future(
-        cls, infos: MPFuture[Sequence[Sequence[Optional[RemoteExpertInfo]]]], dht: DHT
+        cls, infos: Future[Sequence[Sequence[Optional[RemoteExpertInfo]]]], dht: DHT
     ) -> MPFuture[List[List[Optional[RemoteExpert]]]]:
         async def _unpack():
             return cls.spawn_experts_bulk(await infos, dht)
@@ -174,6 +175,75 @@ class RemoteExpertWorker:
         return cls.run_coroutine(_unpack, True)
 
 
+async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+
+    grad_inputs = await stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda t: runtime_pb2.ExpertRequest(uid=uid, tensors=[t]),
+            as_aiter(*split),
+        ),
+    )
+
+    return await gather_from_rpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+
+
+async def _backward(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+
+
+async def expert_backward(
+    uid: str, inputs_and_grads: Sequence[torch.Tensor], compressions: Iterable, stub
+) -> List[torch.Tensor]:
+    serialized_tensors = (
+        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs_and_grads, compressions)
+    )
+
+    size = 0
+    for t in inputs_and_grads:
+        size += t.element_size() * t.nelement()
+        if size >= DEFAULT_MAX_MSG_SIZE:
+            return await _backward_stream(uid, serialized_tensors, stub)
+    else:
+        return await _backward(uid, serialized_tensors, stub)
+
+
+async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+
+    outputs = await stub.rpc_forward_stream(
+        amap_in_executor(
+            lambda t: runtime_pb2.ExpertRequest(uid=uid, tensors=[t]),
+            as_aiter(*split),
+        ),
+    )
+
+    return await gather_from_rpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
+
+
+async def _forward(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in outputs.tensors]
+
+
+async def expert_forward(uid: str, inputs: Sequence[torch.Tensor], compressions: Iterable, stub) -> List[torch.Tensor]:
+    serialized_tensors = (
+        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs, compressions)
+    )
+    size = 0
+    for t in inputs:
+        size += t.element_size() * t.nelement()
+        if size >= DEFAULT_MAX_MSG_SIZE:
+            return await _forward_stream(uid, serialized_tensors, stub)
+    else:
+        return await _forward(uid, serialized_tensors, stub)
+
+
 class _RemoteModuleCall(torch.autograd.Function):
     """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
 
@@ -193,93 +263,21 @@ class _RemoteModuleCall(torch.autograd.Function):
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
 
-        serialized_tensors = (
-            serialize_torch_tensor(inp, proto.compression)
-            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
+        deserialized_outputs = RemoteExpertWorker.run_coroutine(
+            expert_forward(uid, inputs, (p.compression for p in nested_flatten(info["forward_schema"])), stub)
         )
 
-        size = 0
-        for t in inputs:
-            size += t.element_size() * t.nelement()
-            if size >= DEFAULT_MAX_MSG_SIZE:
-                deserialized_outputs = cls.forward_stream(serialized_tensors, ctx, stub)
-                break
-        else:
-            deserialized_outputs = cls.forward_oneshot(serialized_tensors, ctx, stub)
-
         return tuple(deserialized_outputs)
 
-    @classmethod
-    def forward_stream(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
-        split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
-
-        outputs = RemoteExpertWorker.run_coroutine(
-            stub.rpc_forward_stream(
-                amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t]),
-                    as_aiter(*split),
-                ),
-            )
-        )
-
-        return RemoteExpertWorker.run_coroutine(
-            gather_from_rpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
-        )
-
-    @classmethod
-    def forward_oneshot(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
-
-        outputs = RemoteExpertWorker.run_coroutine(
-            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=list(serialized_tensors)))
-        )
-
-        return [deserialize_torch_tensor(t) for t in outputs.tensors]
-
     @classmethod
     @once_differentiable
     def backward(cls, ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-        serialized_tensors = (
-            serialize_torch_tensor(tensor, proto.compression)
-            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        )
-
-        size = 0
-        for t in inputs_and_grad_outputs:
-            size += t.element_size() * t.nelement()
-            if size >= DEFAULT_MAX_MSG_SIZE:
-                deserialized_grad_inputs = cls.backward_stream(serialized_tensors, ctx)
-                break
-        else:
-            deserialized_grad_inputs = cls.backward_oneshot(serialized_tensors, ctx)
-
-        return (DUMMY, None, None, None, *deserialized_grad_inputs)
-
-    @classmethod
-    @once_differentiable
-    def backward_stream(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
-        split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
-
-        grad_inputs = RemoteExpertWorker.run_coroutine(
-            ctx.stub.rpc_backward_stream(
-                amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t]),
-                    as_aiter(*split),
-                ),
-            )
-        )
 
-        return RemoteExpertWorker.run_coroutine(
-            gather_from_rpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+        deserialized_grad_inputs = RemoteExpertWorker.run_coroutine(
+            expert_backward(ctx.uid, inputs_and_grad_outputs, (p.compression for p in backward_schema), ctx.stub)
         )
 
-    @classmethod
-    @once_differentiable
-    def backward_oneshot(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
-        grad_inputs = RemoteExpertWorker.run_coroutine(
-            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=list(serialized_tensors)))
-        )
-
-        return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+        return (DUMMY, None, None, None, *deserialized_grad_inputs)
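
The new module-level helpers expert_forward / expert_backward serialize each tensor with its schema compression and switch from the one-shot RPC to rpc_forward_stream / rpc_backward_stream once the total payload reaches DEFAULT_MAX_MSG_SIZE. A minimal sketch of calling one directly, assuming `expert` is an already spawned RemoteExpert with a populated info:

    import torch
    from hivemind.moe.client.expert import RemoteExpertWorker, _get_expert_stub, expert_forward
    from hivemind.utils import nested_flatten

    stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
    compressions = (p.compression for p in nested_flatten(expert.info["forward_schema"]))
    inputs = [torch.randn(1, 16)]  # must match the expert's forward schema
    # run_coroutine blocks on the worker's event loop and returns already deserialized tensors
    outputs = RemoteExpertWorker.run_coroutine(expert_forward(expert.uid, inputs, compressions, stub))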

+ 33 - 30
hivemind/moe/client/moe.py

@@ -1,20 +1,26 @@
 from __future__ import annotations
 
 import time
+from concurrent.futures import Future
 from queue import Empty, Queue
 from typing import Any, Dict, List, Optional, Tuple
 
-import grpc
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
+from hivemind.moe.client.expert import (
+    DUMMY,
+    RemoteExpert,
+    RemoteExpertWorker,
+    _get_expert_stub,
+    expert_backward,
+    expert_forward,
+)
 from hivemind.moe.server.expert_uid import UID_DELIMITER
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_map, nested_pack
 from hivemind.utils.logging import get_logger
 
@@ -95,7 +101,7 @@ class RemoteMixtureOfExperts(nn.Module):
         chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
             [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best
         )
-
+        print(chosen_experts)
         if self._expert_info is None:
             try:
                 self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
@@ -104,7 +110,7 @@ class RemoteMixtureOfExperts(nn.Module):
                     "No responding experts found during beam search. Check that UID prefixes and "
                     "the grid size are consistent with running Server instances."
                 )
-            except grpc.RpcError as e:
+            except P2PDaemonError as e:
                 logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
@@ -177,8 +183,8 @@ class RemoteMixtureOfExperts(nn.Module):
         if self._expert_info is None:
             # grab some expert to set ensemble output shape
             proj_device = self.proj.weight.device
-            dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
-            dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.beam_search.grid_size, dim=-1)
+            dummy_scores_concat: torch.Tensor = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
+            dummy_scores = dummy_scores_concat.cpu().detach().split_with_sizes(self.beam_search.grid_size, dim=-1)
             dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
             self._expert_info = dummy_experts[0].info
         return self._expert_info
@@ -223,15 +229,15 @@ class _RemoteCallMany(torch.autograd.Function):
         assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
 
         # dispatch tasks to all remote experts collect responses
-        pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
+        pending_tasks: Dict[Future, Tuple[int, int]] = {}
         for i in range(num_samples):
             for j, expert in enumerate(experts_per_sample[i]):
-                input_tensors = [
-                    serialize_torch_tensor(tensor, proto.compression)
-                    for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
-                ]
-                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.p2p, expert.server_peer_info)
-                new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
+                compressions = (p.compression for p in nested_flatten(info["forward_schema"]))
+                stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
+                new_task = RemoteExpertWorker.run_coroutine(
+                    expert_forward(expert.uid, flat_inputs_per_sample[i], compressions, stub),
+                    return_future=True,
+                )
                 pending_tasks[new_task] = (i, j)
 
         responded_inds, alive_flat_outputs = cls._collect_responses(
@@ -316,14 +322,13 @@ class _RemoteCallMany(torch.autograd.Function):
         for i, j, inputs_ij, grad_outputs_ij in zip(
             alive_ii.cpu().numpy(), alive_jj.cpu().numpy(), inputs_per_expert, grad_outputs_per_expert
         ):
-            expert = expert_per_sample[i.item()][j.item()]
-            stub = _get_expert_stub(expert.endpoint)
+            expert: RemoteExpert = expert_per_sample[i.item()][j.item()]
+            stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
-            tensors_serialized = [
-                serialize_torch_tensor(tensor, proto.compression)
-                for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-            ]
-            new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
+            compressions = (p.compression for p in backward_schema)
+            new_task = RemoteExpertWorker.run_coroutine(
+                expert_backward(expert.uid, inputs_and_grad_outputs, compressions, stub), return_future=True
+            )
             pending_tasks[new_task] = (i, j)
 
         survivor_inds, survivor_grad_inputs = cls._collect_responses(
@@ -358,7 +363,7 @@ class _RemoteCallMany(torch.autograd.Function):
 
     @staticmethod
     def _collect_responses(
-        task_to_indices: Dict[grpc.Future, Tuple[int, int]],
+        task_to_indices: Dict[Future, Tuple[int, int]],
         num_samples: int,
         k_min: int,
         timeout_total: Optional[float],
@@ -408,17 +413,15 @@ class _RemoteCallMany(torch.autograd.Function):
         return finished_indices, finished_outputs
 
 
-def _process_dispatched_task(task: grpc.Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
+def _process_dispatched_task(task: Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
     if task.exception() or task.cancelled():
         logger.warning(f"Task {task} failed: {type(task.exception())}")
         return None
 
-    deserialized_outputs = []
-    for tensor in task.result().tensors:
-        deserialized_tensor = deserialize_torch_tensor(tensor)
-        if detect_anomalies and not deserialized_tensor.isfinite().all():
+    outputs = tuple(task.result())
+    for tensor in outputs:
+        if detect_anomalies and not tensor.isfinite().all():
             logger.error(f"Task {task} failed: output tensor contains nan/inf values")
             return None
-        deserialized_outputs.append(deserialized_tensor)
 
-    return tuple(deserialized_outputs)
+    return outputs
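
With grpc.Future gone, each call dispatched by _RemoteCallMany is now a concurrent.futures.Future obtained via return_future=True, and its result() is the list of already deserialized tensors, so _process_dispatched_task no longer parses protobuf responses. A minimal sketch of the consumer side, assuming `task` is one of the futures dispatched above:

    if task.exception() or task.cancelled():
        outputs = None  # this (sample, expert) pair is treated as failed
    else:
        outputs = tuple(task.result())  # torch tensors, already deserialized
        finite = all(t.isfinite().all() for t in outputs)  # the detect_anomalies check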

+ 2 - 6
hivemind/moe/client/switch_moe.py

@@ -2,12 +2,12 @@ from __future__ import annotations
 
 from typing import List, Tuple
 
-import grpc
 import torch
 
 from hivemind.moe.client.expert import DUMMY, RemoteExpert
 from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.server.expert_uid import UID_DELIMITER
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_pack
 from hivemind.utils.logging import get_logger
 
@@ -80,7 +80,6 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
 
         # Compute scores, find most appropriate experts with beam search
         grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)
-
         grid_dropout_masks = (
             (
                 torch.rand(size=(dim_size,), dtype=input_for_gating.dtype, device=input_for_gating.device)
@@ -96,12 +95,10 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
             )
             for grid_score, dropout_mask in zip(grid_scores, grid_dropout_masks)
         ]
-
         grid_softmax = [torch.softmax(grid_score, dim=-1) for grid_score in grid_scores_dropout]
         chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
             [scores.detach().cpu() for scores in grid_scores_dropout], self.k_best
         )
-
         if self._expert_info is None:
             try:
                 self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
@@ -110,9 +107,8 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
                     "No responding experts found during beam search. Check that UID prefixes and "
                     "the grid size are consistent with running Server instances."
                 )
-            except grpc.RpcError as e:
+            except P2PDaemonError as e:
                 logger.warning(f"Failed to get RemoteSwitchMixtureOfExperts.output_shape: {e}")
-
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
             DUMMY,
             chosen_experts,

+ 2 - 2
hivemind/moe/server/server.py

@@ -24,7 +24,7 @@ from hivemind.moe.server.layers import (
     schedule_name_to_scheduler,
 )
 from hivemind.moe.server.runtime import Runtime
-from hivemind.proto.p2pd_pb2 import PeerInfo
+from hivemind.p2p import PeerInfo
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
@@ -309,7 +309,7 @@ def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
     try:
         runner.start()
         # once the server is ready, runner will send us
-        # either (False, exception) or (True, (dht_peer_id, dht_maddrs))
+        # either (False, exception) or (True, PeerInfo(dht_peer_id, dht_maddrs))
         start_ok, data = pipe.recv()
         if start_ok:
             yield data
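
background_server now yields a hivemind.p2p.PeerInfo rather than an (endpoint, dht_maddrs) tuple, so callers bootstrap their own DHT from server_peer_info.addrs and spawn experts through RemoteExpertWorker. A minimal sketch mirroring the updated tests (parameter values are illustrative):

    import torch
    from hivemind.dht import DHT
    from hivemind.moe.client.expert import RemoteExpertInfo, RemoteExpertWorker
    from hivemind.moe.server import background_server

    with background_server(num_experts=1, device="cpu", hidden_dim=16, num_handlers=1) as server_peer_info:
        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
        (expert,) = RemoteExpertWorker.spawn_experts(
            [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)], dht=dht
        )
        out = expert(torch.randn(1, 16))  # forward pass through the remote expert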

+ 20 - 8
tests/test_custom_experts.py

@@ -4,6 +4,8 @@ import pytest
 import torch
 
 from hivemind import RemoteExpert
+from hivemind.dht import DHT
+from hivemind.moe.client.expert import RemoteExpertInfo, RemoteExpertWorker
 from hivemind.moe.server import background_server
 
 CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
@@ -17,11 +19,16 @@ def test_custom_expert(hid_dim=16):
         device="cpu",
         hidden_dim=hid_dim,
         num_handlers=2,
-        no_dht=True,
         custom_module_path=CUSTOM_EXPERTS_PATH,
-    ) as (server_endpoint, _):
-        expert0 = RemoteExpert("expert.0", server_endpoint)
-        expert1 = RemoteExpert("expert.1", server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert0, expert1 = RemoteExpertWorker.spawn_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
 
         for batch_size in (1, 4):
             batch = torch.randn(batch_size, hid_dim)
@@ -43,11 +50,16 @@ def test_multihead_expert(hid_dim=16):
         device="cpu",
         hidden_dim=hid_dim,
         num_handlers=2,
-        no_dht=True,
         custom_module_path=CUSTOM_EXPERTS_PATH,
-    ) as (server_endpoint, _):
-        expert0 = RemoteExpert("expert.0", server_endpoint)
-        expert1 = RemoteExpert("expert.1", server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert0, expert1 = RemoteExpertWorker.spawn_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
 
         for batch_size in (1, 4):
             batch = (

+ 42 - 28
tests/test_moe.py

@@ -1,13 +1,16 @@
-import grpc
+import time
+
 import numpy as np
 import pytest
 import torch
 
 from hivemind.dht import DHT
-from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.client.moe import DUMMY, _RemoteCallMany
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
+from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
+from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
 from hivemind.moe.server.layers import name_to_block
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 
@@ -18,8 +21,8 @@ def test_moe():
     ]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
         dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
 
@@ -35,9 +38,8 @@ def test_no_experts():
     ]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
-
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
         dmoe = RemoteSwitchMixtureOfExperts(
             in_features=16,
             grid_size=(4, 4, 4),
@@ -71,12 +73,16 @@ def test_call_many(hidden_dim=16):
         num_handlers=1,
         hidden_dim=hidden_dim,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
+    ) as server_peer_info:
         inputs = torch.randn(4, hidden_dim, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
-        e0, e1, e2, e3, e4 = [RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
-        e5 = RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
+
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        e0, e1, e2, e3, e4 = RemoteExpertWorker.spawn_experts(
+            [RemoteExpertInfo(uid=f"expert.{i}", peer_info=server_peer_info) for i in range(5)],
+            dht,
+        )
+        e5 = RemoteExpert(RemoteExpertInfo(f"thisshouldnotexist", server_peer_info), None)
 
         mask, expert_outputs = _RemoteCallMany.apply(
             DUMMY,
@@ -129,11 +135,15 @@ def test_remote_module_call(hidden_dim=16):
         num_handlers=1,
         hidden_dim=hidden_dim,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
-        real_expert = RemoteExpert("expert.0", server_endpoint)
-        fake_expert = RemoteExpert("oiasfjiasjf", server_endpoint)
-
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        real_expert, fake_expert = RemoteExpertWorker.spawn_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="oiasfjiasjf", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
         out1 = real_expert(torch.randn(1, hidden_dim))
         assert out1.shape == (1, hidden_dim)
         dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
@@ -144,9 +154,9 @@ def test_remote_module_call(hidden_dim=16):
         out3_again.norm().backward()
         assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
 
-        with pytest.raises(grpc.RpcError):
+        with pytest.raises(P2PDaemonError):
             real_expert(torch.randn(3, 11))
-        with pytest.raises(grpc.RpcError):
+        with pytest.raises(P2PDaemonError):
             fake_expert(dummy_x)
 
 
@@ -154,11 +164,11 @@ def test_remote_module_call(hidden_dim=16):
 def test_beam_search_correctness():
     all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
     dht = DHT(start=True)
-    assert all(declare_experts(dht, all_expert_uids, endpoint="fake-endpoint"))
+    assert all(declare_experts(dht, all_expert_uids, dht.peer_id))
 
     dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")
 
-    for i in range(25):
+    for _ in range(25):
         input = torch.randn(32)
         grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
 
@@ -173,7 +183,7 @@ def test_beam_search_correctness():
         # reference: independently find :beam_size: best experts with exhaustive search
         all_scores = dmoe.compute_expert_scores(
             [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
-            [[RemoteExpert(uid, "") for uid in all_expert_uids]],
+            [[RemoteExpert(RemoteExpertInfo(uid, None), None) for uid in all_expert_uids]],
         )[0]
         true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
 
@@ -194,9 +204,12 @@ def test_determinism(hidden_dim=16):
         num_handlers=1,
         hidden_dim=hidden_dim,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
-        expert = RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert = RemoteExpertWorker.spawn_experts(
+            [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)],
+            dht=dht,
+        )[0]
 
         out = expert(xx, mask)
         out_rerun = expert(xx, mask)
@@ -220,7 +233,7 @@ def test_compute_expert_scores():
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
             [
-                RemoteExpert(uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337")
+                RemoteExpert(RemoteExpertInfo(f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", None), None)
                 for expert_i in range(len(ii[batch_i]))
             ]
             for batch_i in range(len(ii))
@@ -261,9 +274,10 @@ def test_client_anomaly_detection():
     server.start()
     try:
         server.ready.wait()
+        dht_experts = DHT(initial_peers=dht.get_visible_maddrs(), start=True)
 
         dmoe = RemoteMixtureOfExperts(
-            in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
+            in_features=16, grid_size=(3,), dht=dht_experts, k_best=3, uid_prefix="expert.", detect_anomalies=True
         )
 
         input = torch.randn(1, 16)
@@ -280,7 +294,7 @@ def test_client_anomaly_detection():
             inf_loss.backward()
 
         dmoe = RemoteMixtureOfExperts(
-            in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
+            in_features=16, grid_size=(4,), dht=dht_experts, k_best=4, uid_prefix="expert.", detect_anomalies=True
         )
         output = dmoe(input)
         assert output.isfinite().all()

+ 17 - 11
tests/test_training.py

@@ -8,7 +8,8 @@ import torch.nn.functional as F
 from sklearn.datasets import load_digits
 
 from hivemind import DHT
-from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client.expert import RemoteExpertInfo, RemoteExpertWorker
 from hivemind.moe.server import background_server
 from hivemind.optim import DecentralizedAdam, DecentralizedSGD
 
@@ -19,12 +20,17 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
     X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
     SGD = partial(torch.optim.SGD, lr=0.05)
 
-    with background_server(num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1, no_dht=True) as (
-        server_endpoint,
-        _,
-    ):
-        expert1 = RemoteExpert("expert.0", server_endpoint)
-        expert2 = RemoteExpert("expert.1", server_endpoint)
+    with background_server(
+        num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert1, expert2 = RemoteExpertWorker.spawn_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
         model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
 
         opt = SGD(model.parameters(), lr=0.05)
@@ -54,8 +60,8 @@ def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
         moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix="expert.", k_best=2)
         model = nn.Sequential(moe, nn.Linear(64, 2))
@@ -107,8 +113,8 @@ def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_expert
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
         model = SwitchNetwork(dht, 64, 2, num_experts)
         opt = SGD(model.parameters(), lr=0.05)