
fix moe and tests

Pavel Samygin, 3 years ago
parent commit 5ba4f72dba

+ 1 - 1
hivemind/compression/base.py

@@ -80,7 +80,7 @@ class NoCompression(CompressionBase):
     compression_type = runtime_pb2.CompressionType.NONE
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        array = tensor.numpy()
+        array = tensor.detach().numpy()
         return runtime_pb2.Tensor(
             compression=self.compression_type,
             buffer=array.tobytes(),
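
Note on the one-line change above: PyTorch will not convert a tensor that still requires grad, so compressing the output of a live module used to raise at this point. A minimal illustration of the failure mode (not part of the commit):

    import torch

    t = torch.ones(3, requires_grad=True)
    # t.numpy() raises RuntimeError: "Can't call numpy() on Tensor that requires grad.
    # Use tensor.detach().numpy() instead."
    array = t.detach().numpy()  # detach from the autograd graph, then view as NumPy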

+ 0 - 1
hivemind/dht/dht.py

@@ -310,7 +310,6 @@ class DHT(mp.Process):
         Get a replica of a P2P instance used in the DHT process internally.
         The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
         """
-
         if self._p2p_replica is None or self._origin_pid != os.getpid():
             self._origin_pid = os.getpid()
             daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
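
The docstring above describes the DHT's P2P replica accessor, which this commit starts calling from the MoE client as dht.replicate_p2p(). A rough usage sketch, assuming that method name (it matches the calls added in hivemind/moe/client/expert.py below):

    from hivemind.dht import DHT
    from hivemind.moe.client.expert import RemoteExpertWorker

    dht = DHT(start=True)
    # replicate_p2p() is awaited on the worker's event loop; the replica talks to the
    # same P2P daemon as the DHT and is only valid while `dht` stays alive
    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())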

+ 1 - 1
hivemind/moe/client/beam_search.py

@@ -385,7 +385,7 @@ class MoEBeamSearcher:
             ),
             return_future,
         )
-
+        print(result)
         if return_future:
             return RemoteExpertWorker.spawn_experts_bulk_future(result, self.dht)
         return RemoteExpertWorker.spawn_experts_bulk(result, self.dht)

+ 80 - 82
hivemind/moe/client/expert.py

@@ -3,7 +3,7 @@ from concurrent.futures import Future
 from dataclasses import dataclass
 from queue import Queue
 from threading import Thread
-from typing import Any, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Any, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple
 
 import torch
 import torch.nn as nn
@@ -99,7 +99,7 @@ class RemoteExpertWorker:
 
     _task_queue: Queue = Queue()
     _event_thread: Optional[Thread] = None
-    _pid: int = 0
+    _pid: int = -1
 
     @classmethod
     def _run(cls):
@@ -113,7 +113,8 @@ class RemoteExpertWorker:
                 except Exception as e:
                     future.set_exception(e)
                     continue
-                future.set_result(result)
+                if not future.cancelled():
+                    future.set_result(result)
 
         loop.run_until_complete(receive_tasks())
 
@@ -151,7 +152,7 @@ class RemoteExpertWorker:
     @classmethod
     def spawn_experts_future(
         cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
-    ) -> MPFuture[List[Optional[RemoteExpert]]]:
+    ) -> Future[List[Optional[RemoteExpert]]]:
         async def _unpack():
             p2p = cls.run_coroutine(dht.replicate_p2p(), True)
             return cls.spawn_experts(await infos, await p2p)
@@ -166,7 +167,7 @@ class RemoteExpertWorker:
 
     @classmethod
     def spawn_experts_bulk_future(
-        cls, infos: MPFuture[Sequence[Sequence[Optional[RemoteExpertInfo]]]], dht: DHT
+        cls, infos: Future[Sequence[Sequence[Optional[RemoteExpertInfo]]]], dht: DHT
     ) -> MPFuture[List[List[Optional[RemoteExpert]]]]:
         async def _unpack():
             return cls.spawn_experts_bulk(await infos, dht)
@@ -174,6 +175,75 @@ class RemoteExpertWorker:
         return cls.run_coroutine(_unpack, True)
 
 
+async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+
+    grad_inputs = await stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda t: runtime_pb2.ExpertRequest(uid=uid, tensors=[t]),
+            as_aiter(*split),
+        ),
+    )
+
+    return await gather_from_rpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+
+
+async def _backward(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+
+
+async def expert_backward(
+    uid: str, inputs_and_grads: Sequence[torch.Tensor], compressions: Iterable, stub
+) -> List[torch.Tensor]:
+    serialized_tensors = (
+        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs_and_grads, compressions)
+    )
+
+    size = 0
+    for t in inputs_and_grads:
+        size += t.element_size() * t.nelement()
+        if size >= DEFAULT_MAX_MSG_SIZE:
+            return await _backward_stream(uid, serialized_tensors, stub)
+    else:
+        return await _backward(uid, serialized_tensors, stub)
+
+
+async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+
+    outputs = await stub.rpc_forward_stream(
+        amap_in_executor(
+            lambda t: runtime_pb2.ExpertRequest(uid=uid, tensors=[t]),
+            as_aiter(*split),
+        ),
+    )
+
+    return await gather_from_rpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
+
+
+async def _forward(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in outputs.tensors]
+
+
+async def expert_forward(uid: str, inputs: Sequence[torch.Tensor], compressions: Iterable, stub) -> List[torch.Tensor]:
+    serialized_tensors = (
+        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs, compressions)
+    )
+    size = 0
+    for t in inputs:
+        size += t.element_size() * t.nelement()
+        if size >= DEFAULT_MAX_MSG_SIZE:
+            return await _forward_stream(uid, serialized_tensors, stub)
+    else:
+        return await _forward(uid, serialized_tensors, stub)
+
+
 class _RemoteModuleCall(torch.autograd.Function):
     """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
 
@@ -193,93 +263,21 @@ class _RemoteModuleCall(torch.autograd.Function):
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
 
-        serialized_tensors = (
-            serialize_torch_tensor(inp, proto.compression)
-            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
+        deserialized_outputs = RemoteExpertWorker.run_coroutine(
+            expert_forward(uid, inputs, (p.compression for p in nested_flatten(info["forward_schema"])), stub)
         )
 
-        size = 0
-        for t in inputs:
-            size += t.element_size() * t.nelement()
-            if size >= DEFAULT_MAX_MSG_SIZE:
-                deserialized_outputs = cls.forward_stream(serialized_tensors, ctx, stub)
-                break
-        else:
-            deserialized_outputs = cls.forward_oneshot(serialized_tensors, ctx, stub)
-
         return tuple(deserialized_outputs)
 
-    @classmethod
-    def forward_stream(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
-        split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
-
-        outputs = RemoteExpertWorker.run_coroutine(
-            stub.rpc_forward_stream(
-                amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t]),
-                    as_aiter(*split),
-                ),
-            )
-        )
-
-        return RemoteExpertWorker.run_coroutine(
-            gather_from_rpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
-        )
-
-    @classmethod
-    def forward_oneshot(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
-
-        outputs = RemoteExpertWorker.run_coroutine(
-            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=list(serialized_tensors)))
-        )
-
-        return [deserialize_torch_tensor(t) for t in outputs.tensors]
-
     @classmethod
     @once_differentiable
     def backward(cls, ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-        serialized_tensors = (
-            serialize_torch_tensor(tensor, proto.compression)
-            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        )
-
-        size = 0
-        for t in inputs_and_grad_outputs:
-            size += t.element_size() * t.nelement()
-            if size >= DEFAULT_MAX_MSG_SIZE:
-                deserialized_grad_inputs = cls.backward_stream(serialized_tensors, ctx)
-                break
-        else:
-            deserialized_grad_inputs = cls.backward_oneshot(serialized_tensors, ctx)
-
-        return (DUMMY, None, None, None, *deserialized_grad_inputs)
-
-    @classmethod
-    @once_differentiable
-    def backward_stream(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
-        split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
-
-        grad_inputs = RemoteExpertWorker.run_coroutine(
-            ctx.stub.rpc_backward_stream(
-                amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t]),
-                    as_aiter(*split),
-                ),
-            )
-        )
 
 
-        return RemoteExpertWorker.run_coroutine(
-            gather_from_rpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+        deserialized_grad_inputs = RemoteExpertWorker.run_coroutine(
+            expert_backward(ctx.uid, inputs_and_grad_outputs, (p.compression for p in backward_schema), ctx.stub)
         )
 
-    @classmethod
-    @once_differentiable
-    def backward_oneshot(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
-        grad_inputs = RemoteExpertWorker.run_coroutine(
-            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=list(serialized_tensors)))
-        )
-
-        return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+        return (DUMMY, None, None, None, *deserialized_grad_inputs)
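
The expert_forward / expert_backward helpers introduced above pick between a one-shot RPC and a chunked streaming RPC by summing tensor byte sizes against DEFAULT_MAX_MSG_SIZE. A rough sketch of how the autograd code drives them; the `expert`, `inputs`, and `info` names mirror the surrounding code and are illustrative:

    stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
    compressions = (p.compression for p in nested_flatten(info["forward_schema"]))
    outputs = RemoteExpertWorker.run_coroutine(
        expert_forward(expert.uid, inputs, compressions, stub)
    )
    # payloads above DEFAULT_MAX_MSG_SIZE are chunked with split_for_streaming() and
    # sent via rpc_forward_stream; smaller ones go out as a single rpc_forward request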

+ 33 - 30
hivemind/moe/client/moe.py

@@ -1,20 +1,26 @@
 from __future__ import annotations
 
 import time
+from concurrent.futures import Future
 from queue import Empty, Queue
 from typing import Any, Dict, List, Optional, Tuple
 
-import grpc
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
+from hivemind.moe.client.expert import (
+    DUMMY,
+    RemoteExpert,
+    RemoteExpertWorker,
+    _get_expert_stub,
+    expert_backward,
+    expert_forward,
+)
 from hivemind.moe.server.expert_uid import UID_DELIMITER
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_map, nested_pack
 from hivemind.utils.logging import get_logger
 
@@ -95,7 +101,7 @@ class RemoteMixtureOfExperts(nn.Module):
         chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
             [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best
         )
-
+        print(chosen_experts)
         if self._expert_info is None:
             try:
                 self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
@@ -104,7 +110,7 @@ class RemoteMixtureOfExperts(nn.Module):
                     "No responding experts found during beam search. Check that UID prefixes and "
                     "No responding experts found during beam search. Check that UID prefixes and "
                     "the grid size are consistent with running Server instances."
                     "the grid size are consistent with running Server instances."
                 )
-            except grpc.RpcError as e:
+            except P2PDaemonError as e:
                 logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
@@ -177,8 +183,8 @@ class RemoteMixtureOfExperts(nn.Module):
         if self._expert_info is None:
             # grab some expert to set ensemble output shape
             proj_device = self.proj.weight.device
-            dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
-            dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.beam_search.grid_size, dim=-1)
+            dummy_scores_concat: torch.Tensor = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
+            dummy_scores = dummy_scores_concat.cpu().detach().split_with_sizes(self.beam_search.grid_size, dim=-1)
             dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
             self._expert_info = dummy_experts[0].info
         return self._expert_info
@@ -223,15 +229,15 @@ class _RemoteCallMany(torch.autograd.Function):
         assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
 
         # dispatch tasks to all remote experts collect responses
-        pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
+        pending_tasks: Dict[Future, Tuple[int, int]] = {}
         for i in range(num_samples):
             for j, expert in enumerate(experts_per_sample[i]):
-                input_tensors = [
-                    serialize_torch_tensor(tensor, proto.compression)
-                    for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
-                ]
-                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.p2p, expert.server_peer_info)
-                new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
+                compressions = (p.compression for p in nested_flatten(info["forward_schema"]))
+                stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
+                new_task = RemoteExpertWorker.run_coroutine(
+                    expert_forward(expert.uid, flat_inputs_per_sample[i], compressions, stub),
+                    return_future=True,
+                )
                 pending_tasks[new_task] = (i, j)
 
         responded_inds, alive_flat_outputs = cls._collect_responses(
@@ -316,14 +322,13 @@ class _RemoteCallMany(torch.autograd.Function):
         for i, j, inputs_ij, grad_outputs_ij in zip(
             alive_ii.cpu().numpy(), alive_jj.cpu().numpy(), inputs_per_expert, grad_outputs_per_expert
         ):
-            expert = expert_per_sample[i.item()][j.item()]
-            stub = _get_expert_stub(expert.endpoint)
+            expert: RemoteExpert = expert_per_sample[i.item()][j.item()]
+            stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
-            tensors_serialized = [
-                serialize_torch_tensor(tensor, proto.compression)
-                for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-            ]
-            new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
+            compressions = (p.compression for p in backward_schema)
+            new_task = RemoteExpertWorker.run_coroutine(
+                expert_backward(expert.uid, inputs_and_grad_outputs, compressions, stub), return_future=True
+            )
             pending_tasks[new_task] = (i, j)
 
         survivor_inds, survivor_grad_inputs = cls._collect_responses(
@@ -358,7 +363,7 @@ class _RemoteCallMany(torch.autograd.Function):
 
     @staticmethod
     def _collect_responses(
-        task_to_indices: Dict[grpc.Future, Tuple[int, int]],
+        task_to_indices: Dict[Future, Tuple[int, int]],
         num_samples: int,
         k_min: int,
         timeout_total: Optional[float],
@@ -408,17 +413,15 @@ class _RemoteCallMany(torch.autograd.Function):
         return finished_indices, finished_outputs
 
 
-def _process_dispatched_task(task: grpc.Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
+def _process_dispatched_task(task: Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
     if task.exception() or task.cancelled():
         logger.warning(f"Task {task} failed: {type(task.exception())}")
         return None
 
-    deserialized_outputs = []
-    for tensor in task.result().tensors:
-        deserialized_tensor = deserialize_torch_tensor(tensor)
-        if detect_anomalies and not deserialized_tensor.isfinite().all():
+    outputs = tuple(task.result())
+    for tensor in outputs:
+        if detect_anomalies and not tensor.isfinite().all():
             logger.error(f"Task {task} failed: output tensor contains nan/inf values")
             return None
-        deserialized_outputs.append(deserialized_tensor)
 
-    return tuple(deserialized_outputs)
+    return outputs
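
With grpc.Future gone, _RemoteCallMany above dispatches every expert call through RemoteExpertWorker.run_coroutine(..., return_future=True) and tracks ordinary concurrent.futures.Future objects. A condensed sketch of the dispatch-and-collect pattern; `experts`, `flat_inputs`, and `info` are illustrative stand-ins for the variables used above:

    pending_tasks = {}  # Future -> sample index
    for idx, expert in enumerate(experts):
        stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
        compressions = (p.compression for p in nested_flatten(info["forward_schema"]))
        # schedules the RPC on the worker's event loop and returns a Future right away
        task = RemoteExpertWorker.run_coroutine(
            expert_forward(expert.uid, flat_inputs, compressions, stub), return_future=True
        )
        pending_tasks[task] = idx

    # _process_dispatched_task waits on the Future and returns the output tensors, or
    # None if the task failed or produced non-finite values with detect_anomalies=True
    results = {i: _process_dispatched_task(t, detect_anomalies=True) for t, i in pending_tasks.items()}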

+ 2 - 6
hivemind/moe/client/switch_moe.py

@@ -2,12 +2,12 @@ from __future__ import annotations
 
 from typing import List, Tuple
 
-import grpc
 import torch
 
 from hivemind.moe.client.expert import DUMMY, RemoteExpert
 from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.server.expert_uid import UID_DELIMITER
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_pack
 from hivemind.utils.logging import get_logger
 
@@ -80,7 +80,6 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
 
         # Compute scores, find most appropriate experts with beam search
         grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)
-
         grid_dropout_masks = (
             (
                 torch.rand(size=(dim_size,), dtype=input_for_gating.dtype, device=input_for_gating.device)
@@ -96,12 +95,10 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
             )
             for grid_score, dropout_mask in zip(grid_scores, grid_dropout_masks)
         ]
-
         grid_softmax = [torch.softmax(grid_score, dim=-1) for grid_score in grid_scores_dropout]
         chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
             [scores.detach().cpu() for scores in grid_scores_dropout], self.k_best
         )
-
         if self._expert_info is None:
             try:
                 self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
@@ -110,9 +107,8 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
                     "No responding experts found during beam search. Check that UID prefixes and "
                     "No responding experts found during beam search. Check that UID prefixes and "
                     "the grid size are consistent with running Server instances."
                     "the grid size are consistent with running Server instances."
                 )
-            except grpc.RpcError as e:
+            except P2PDaemonError as e:
                 logger.warning(f"Failed to get RemoteSwitchMixtureOfExperts.output_shape: {e}")
-
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
             DUMMY,
             chosen_experts,

+ 2 - 2
hivemind/moe/server/server.py

@@ -24,7 +24,7 @@ from hivemind.moe.server.layers import (
     schedule_name_to_scheduler,
 )
 from hivemind.moe.server.runtime import Runtime
-from hivemind.proto.p2pd_pb2 import PeerInfo
+from hivemind.p2p import PeerInfo
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
@@ -309,7 +309,7 @@ def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
     try:
         runner.start()
         # once the server is ready, runner will send us
-        # either (False, exception) or (True, (dht_peer_id, dht_maddrs))
+        # either (False, exception) or (True, PeerInfo(dht_peer_id, dht_maddrs))
         start_ok, data = pipe.recv()
         if start_ok:
             yield data
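
Downstream, background_server now yields that PeerInfo directly instead of a (server_endpoint, dht_maddrs) tuple; the updated tests below consume it like this (sketch mirrors the test code):

    with background_server(num_experts=2, device="cpu", hidden_dim=64, num_handlers=1) as server_peer_info:
        # server_peer_info.addrs are the server's DHT multiaddrs; use them to bootstrap a client-side DHT
        dht = DHT(initial_peers=server_peer_info.addrs, start=True)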

+ 20 - 8
tests/test_custom_experts.py

@@ -4,6 +4,8 @@ import pytest
 import torch
 
 from hivemind import RemoteExpert
+from hivemind.dht import DHT
+from hivemind.moe.client.expert import RemoteExpertInfo, RemoteExpertWorker
 from hivemind.moe.server import background_server
 
 CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
@@ -17,11 +19,16 @@ def test_custom_expert(hid_dim=16):
         device="cpu",
         hidden_dim=hid_dim,
         num_handlers=2,
-        no_dht=True,
         custom_module_path=CUSTOM_EXPERTS_PATH,
-    ) as (server_endpoint, _):
-        expert0 = RemoteExpert("expert.0", server_endpoint)
-        expert1 = RemoteExpert("expert.1", server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert0, expert1 = RemoteExpertWorker.spawn_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
 
         for batch_size in (1, 4):
             batch = torch.randn(batch_size, hid_dim)
@@ -43,11 +50,16 @@ def test_multihead_expert(hid_dim=16):
         device="cpu",
         hidden_dim=hid_dim,
         num_handlers=2,
-        no_dht=True,
         custom_module_path=CUSTOM_EXPERTS_PATH,
-    ) as (server_endpoint, _):
-        expert0 = RemoteExpert("expert.0", server_endpoint)
-        expert1 = RemoteExpert("expert.1", server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert0, expert1 = RemoteExpertWorker.spawn_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
 
         for batch_size in (1, 4):
             batch = (

+ 42 - 28
tests/test_moe.py

@@ -1,13 +1,16 @@
-import grpc
+import time
+
 import numpy as np
 import pytest
 import torch
 
 from hivemind.dht import DHT
-from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.client.moe import DUMMY, _RemoteCallMany
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
+from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
+from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
 from hivemind.moe.server.layers import name_to_block
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 
@@ -18,8 +21,8 @@ def test_moe():
     ]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
         dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
 
@@ -35,9 +38,8 @@ def test_no_experts():
     ]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
-
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
         dmoe = RemoteSwitchMixtureOfExperts(
             in_features=16,
             grid_size=(4, 4, 4),
@@ -71,12 +73,16 @@ def test_call_many(hidden_dim=16):
         num_handlers=1,
         hidden_dim=hidden_dim,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
+    ) as server_peer_info:
         inputs = torch.randn(4, hidden_dim, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
-        e0, e1, e2, e3, e4 = [RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
-        e5 = RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
+
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        e0, e1, e2, e3, e4 = RemoteExpertWorker.spawn_experts(
+            [RemoteExpertInfo(uid=f"expert.{i}", peer_info=server_peer_info) for i in range(5)],
+            dht,
+        )
+        e5 = RemoteExpert(RemoteExpertInfo(f"thisshouldnotexist", server_peer_info), None)
 
         mask, expert_outputs = _RemoteCallMany.apply(
             DUMMY,
@@ -129,11 +135,15 @@ def test_remote_module_call(hidden_dim=16):
         num_handlers=1,
         hidden_dim=hidden_dim,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
-        real_expert = RemoteExpert("expert.0", server_endpoint)
-        fake_expert = RemoteExpert("oiasfjiasjf", server_endpoint)
-
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        real_expert, fake_expert = RemoteExpertWorker.spawn_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="oiasfjiasjf", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
         out1 = real_expert(torch.randn(1, hidden_dim))
         assert out1.shape == (1, hidden_dim)
         dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
@@ -144,9 +154,9 @@ def test_remote_module_call(hidden_dim=16):
         out3_again.norm().backward()
         assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
 
-        with pytest.raises(grpc.RpcError):
+        with pytest.raises(P2PDaemonError):
             real_expert(torch.randn(3, 11))
-        with pytest.raises(grpc.RpcError):
+        with pytest.raises(P2PDaemonError):
             fake_expert(dummy_x)
 
 
@@ -154,11 +164,11 @@ def test_remote_module_call(hidden_dim=16):
 def test_beam_search_correctness():
     all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
     dht = DHT(start=True)
-    assert all(declare_experts(dht, all_expert_uids, endpoint="fake-endpoint"))
+    assert all(declare_experts(dht, all_expert_uids, dht.peer_id))
 
     dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")
 
-    for i in range(25):
+    for _ in range(25):
         input = torch.randn(32)
         grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
 
@@ -173,7 +183,7 @@ def test_beam_search_correctness():
         # reference: independently find :beam_size: best experts with exhaustive search
         all_scores = dmoe.compute_expert_scores(
             [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
-            [[RemoteExpert(uid, "") for uid in all_expert_uids]],
+            [[RemoteExpert(RemoteExpertInfo(uid, None), None) for uid in all_expert_uids]],
         )[0]
         true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
 
@@ -194,9 +204,12 @@ def test_determinism(hidden_dim=16):
         num_handlers=1,
         hidden_dim=hidden_dim,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
-        expert = RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert = RemoteExpertWorker.spawn_experts(
+            [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)],
+            dht=dht,
+        )[0]
 
         out = expert(xx, mask)
         out_rerun = expert(xx, mask)
@@ -220,7 +233,7 @@ def test_compute_expert_scores():
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
             [
-                RemoteExpert(uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337")
+                RemoteExpert(RemoteExpertInfo(f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", None), None)
                 for expert_i in range(len(ii[batch_i]))
             ]
             for batch_i in range(len(ii))
@@ -261,9 +274,10 @@ def test_client_anomaly_detection():
     server.start()
     try:
         server.ready.wait()
+        dht_experts = DHT(initial_peers=dht.get_visible_maddrs(), start=True)
 
         dmoe = RemoteMixtureOfExperts(
-            in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
+            in_features=16, grid_size=(3,), dht=dht_experts, k_best=3, uid_prefix="expert.", detect_anomalies=True
         )
 
         input = torch.randn(1, 16)
@@ -280,7 +294,7 @@ def test_client_anomaly_detection():
             inf_loss.backward()
 
         dmoe = RemoteMixtureOfExperts(
-            in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
+            in_features=16, grid_size=(4,), dht=dht_experts, k_best=4, uid_prefix="expert.", detect_anomalies=True
         )
         output = dmoe(input)
         assert output.isfinite().all()

+ 17 - 11
tests/test_training.py

@@ -8,7 +8,8 @@ import torch.nn.functional as F
 from sklearn.datasets import load_digits
 
 from hivemind import DHT
-from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client.expert import RemoteExpertInfo, RemoteExpertWorker
 from hivemind.moe.server import background_server
 from hivemind.optim import DecentralizedAdam, DecentralizedSGD
 
@@ -19,12 +20,17 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
     X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
     SGD = partial(torch.optim.SGD, lr=0.05)
 
-    with background_server(num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1, no_dht=True) as (
-        server_endpoint,
-        _,
-    ):
-        expert1 = RemoteExpert("expert.0", server_endpoint)
-        expert2 = RemoteExpert("expert.1", server_endpoint)
+    with background_server(
+        num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert1, expert2 = RemoteExpertWorker.spawn_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
         model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
 
         opt = SGD(model.parameters(), lr=0.05)
@@ -54,8 +60,8 @@ def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
         moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix="expert.", k_best=2)
         model = nn.Sequential(moe, nn.Linear(64, 2))
@@ -107,8 +113,8 @@ def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_expert
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
         model = SwitchNetwork(dht, 64, 2, num_experts)
         opt = SGD(model.parameters(), lr=0.05)