review issues fix

Pavel Samygin 3 years ago
parent
commit
ca892e9bc6

+ 12 - 11
benchmarks/benchmark_throughput.py

@@ -7,7 +7,7 @@ import time
 import torch

 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
 from hivemind.moe.server import ExpertBackend, Server, layers
 from hivemind.p2p import P2P, PeerInfo
 from hivemind.utils.limits import increase_file_limit
@@ -34,7 +34,8 @@ def print_device_info(device=None):
 def client_process(
     can_start,
     benchmarking_failed,
-    server_peer_info,
+    server_maddrs,
+    server_peer_id,
     num_experts,
     batch_size,
     hid_dim,
@@ -44,9 +45,13 @@ def client_process(
     torch.set_num_threads(1)
     can_start.wait()

-    p2p = RemoteExpertWorker.run_coroutine(P2P.create())
-    RemoteExpertWorker.run_coroutine(p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
-    experts = [RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info, p2p=p2p) for i in range(num_experts)]
+    p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
+    experts = [
+        RemoteExpert(
+            expert_info=RemoteExpertInfo(uid=f"expert.{i}", peer_info=PeerInfo(server_peer_id, server_maddrs)), p2p=p2p
+        )
+        for i in range(num_experts)
+    ]

     try:
         dummy_batch = torch.randn(batch_size, hid_dim)
@@ -86,11 +91,6 @@ def benchmark_throughput(

     try:
         server_dht = DHT(start=True)
-        server_dht_peer_info = PeerInfo(
-            peer_id=server_dht.peer_id,
-            addrs=[addr.decapsulate("/p2p/" + addr.get("p2p")) for addr in server_dht.get_visible_maddrs()],
-        )
-
         clients = [
             mp.Process(
                 target=client_process,
@@ -98,7 +98,8 @@ def benchmark_throughput(
                 args=(
                     can_start,
                     benchmarking_failed,
-                    server_dht_peer_info,
+                    server_dht.get_visible_maddrs(),
+                    server_dht.peer_id,
                     num_experts,
                     batch_size,
                     hid_dim,

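Note: for readers applying this change, a minimal usage sketch (not part of the commit) of the new client-side construction; server_maddrs and server_peer_id are placeholders for values taken from a running server's DHT:

    # Hypothetical sketch: connect to a server and build an expert handle,
    # mirroring the updated client_process above.
    from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
    from hivemind.p2p import P2P, PeerInfo

    p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
    expert = RemoteExpert(
        expert_info=RemoteExpertInfo(uid="expert.0", peer_info=PeerInfo(server_peer_id, server_maddrs)),
        p2p=p2p,
    )
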
+ 4 - 3
hivemind/moe/client/beam_search.py

@@ -17,6 +17,7 @@ from hivemind.moe.server.expert_uid import (
     UidEndpoint,
     is_valid_prefix,
 )
+from hivemind.p2p import PeerInfo
 from hivemind.utils import get_dht_time, get_logger
 from hivemind.utils.mpfuture import MPFuture

@@ -146,7 +147,7 @@ class MoEBeamSearcher:
                 maybe_prefix_data = await pending_task
                 if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
                     successors = {
-                        coord: UidEndpoint(*match.value)
+                        coord: UidEndpoint(match.value[0], PeerInfo.from_tuple(match.value[1]))
                         for coord, match in maybe_prefix_data.value.items()
                         if isinstance(coord, Coordinate)
                         and isinstance(getattr(match, "value", None), list)
@@ -213,7 +214,7 @@ class MoEBeamSearcher:
         for prefix, found in dht_responses.items():
             if found and isinstance(found.value, dict):
                 successors[prefix] = {
-                    coord: UidEndpoint(*match.value)
+                    coord: UidEndpoint(match.value[0], PeerInfo.from_tuple(match.value[1]))
                     for coord, match in found.value.items()
                     if isinstance(coord, Coordinate)
                     and 0 <= coord < grid_size
@@ -329,7 +330,7 @@ class MoEBeamSearcher:
                 unique_experts.add(uid_endpoint.uid)

         best_experts = [
-            RemoteExpertInfo(uid_endpoint.uid, *uid_endpoint.endpoint)
+            RemoteExpertInfo(uid_endpoint.uid, uid_endpoint.peer_info)
             for _, uid_endpoint in sorted(best_experts_heap, reverse=True)
         ]
         return best_experts

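Note: the DHT stores each successor value as [uid, (peer_id_base58, multiaddr_strings)]; instead of splatting that raw tuple into UidEndpoint, the searcher now rehydrates a PeerInfo. A rough sketch of the conversion, with a made-up match value:

    # Assumed DHT value shape after this commit; the peer id string is a placeholder.
    from hivemind.moe.server.expert_uid import UidEndpoint
    from hivemind.p2p import PeerInfo

    match_value = ["ffn_expert.0.3", ("QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN", ("/ip4/127.0.0.1/tcp/31337",))]
    uid_endpoint = UidEndpoint(match_value[0], PeerInfo.from_tuple(match_value[1]))
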
+ 37 - 49
hivemind/moe/client/expert.py

@@ -3,11 +3,10 @@ from concurrent.futures import Future
 from dataclasses import dataclass
 from queue import Queue
 from threading import Thread
-from typing import Any, Awaitable, Dict, List, Optional, Sequence, Tuple
+from typing import Any, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple

 import torch
 import torch.nn as nn
-from multiaddr import Multiaddr
 from torch.autograd.function import once_differentiable

 import hivemind
@@ -15,7 +14,6 @@ from hivemind.compression import deserialize_torch_tensor, serialize_torch_tenso
 from hivemind.dht import DHT
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
-from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
 from hivemind.proto import runtime_pb2
 from hivemind.utils import (
     MSGPackSerializer,
@@ -26,7 +24,7 @@ from hivemind.utils import (
     nested_pack,
     switch_to_uvloop,
 )
-from hivemind.utils.grpc import gather_from_grpc, split_for_streaming
+from hivemind.utils.grpc import gather_from_rpc, split_for_streaming
 from hivemind.utils.mpfuture import MPFuture

 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
@@ -39,14 +37,7 @@ def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHand
 @dataclass(frozen=True)
 class RemoteExpertInfo:
     uid: str
-    peer_id: str
-    addrs: Sequence[str]
-
-    @property
-    def as_peer_info(self) -> Tuple[str, PeerInfo]:
-        return self.uid, PeerInfo(
-            peer_id=PeerID.from_base58(self.peer_id), addrs=tuple(Multiaddr(a) for a in self.addrs)
-        )
+    peer_info: PeerInfo


 class RemoteExpert(nn.Module):
@@ -58,10 +49,18 @@ class RemoteExpert(nn.Module):
     :param uid: unique expert identifier
     """

-    def __init__(self, uid, server_peer_info: PeerInfo, p2p: P2P):
+    def __init__(self, expert_info: RemoteExpertInfo, p2p: P2P):
         super().__init__()
-        self.uid, self.server_peer_info, self.p2p = uid, server_peer_info, p2p
-        self._info = None
+        self._info, self.p2p = expert_info, p2p
+        self._rpc_info = None
+
+    @property
+    def uid(self):
+        return self._info.uid
+
+    @property
+    def server_peer_info(self):
+        return self._info.peer_info

     @property
     def stub(self) -> StubBase:
@@ -86,10 +85,10 @@ class RemoteExpert(nn.Module):

     @property
     def info(self):
-        if self._info is None:
+        if self._rpc_info is None:
             outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
-            self._info = MSGPackSerializer.loads(outputs.serialized_info)
-        return self._info
+            self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+        return self._rpc_info

     def extra_repr(self):
         return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
@@ -139,8 +138,7 @@ class RemoteExpertWorker:
         experts: List[Optional[RemoteExpert]] = []
         for i in infos:
             if i is not None:
-                uid, peer_info = i.as_peer_info
-                experts.append(RemoteExpert(uid, peer_info, p2p))
+                experts.append(RemoteExpert(i, p2p))
             else:
                 experts.append(None)
         return experts
@@ -195,16 +193,16 @@ class _RemoteModuleCall(torch.autograd.Function):
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)

-        serialized_tensors = [
+        serialized_tensors = (
             serialize_torch_tensor(inp, proto.compression)
             for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
-        ]
+        )

         size = 0
         for t in inputs:
             size += t.element_size() * t.nelement()
             if size >= DEFAULT_MAX_MSG_SIZE:
-                deserialized_outputs = cls.forward_partial(serialized_tensors, ctx, stub)
+                deserialized_outputs = cls.forward_stream(serialized_tensors, ctx, stub)
                 break
         else:
             deserialized_outputs = cls.forward_oneshot(serialized_tensors, ctx, stub)
@@ -212,32 +210,27 @@ class _RemoteModuleCall(torch.autograd.Function):
         return tuple(deserialized_outputs)

     @classmethod
-    def forward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
-        split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
+    def forward_stream(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
+        split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))

         outputs = RemoteExpertWorker.run_coroutine(
-            stub.rpc_forward_partial(
+            stub.rpc_forward_stream(
                 amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(
-                        uid=ctx.uid,
-                        tensors=[
-                            t,
-                        ],
-                    ),
+                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t]),
                     as_aiter(*split),
                 ),
             )
         )

         return RemoteExpertWorker.run_coroutine(
-            gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
+            gather_from_rpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
         )

     @classmethod
-    def forward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
+    def forward_oneshot(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:

         outputs = RemoteExpertWorker.run_coroutine(
-            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=list(serialized_tensors)))
         )

         return [deserialize_torch_tensor(t) for t in outputs.tensors]
@@ -248,16 +241,16 @@ class _RemoteModuleCall(torch.autograd.Function):
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-        serialized_tensors = [
+        serialized_tensors = (
             serialize_torch_tensor(tensor, proto.compression)
             for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        ]
+        )

         size = 0
         for t in inputs_and_grad_outputs:
             size += t.element_size() * t.nelement()
             if size >= DEFAULT_MAX_MSG_SIZE:
-                deserialized_grad_inputs = cls.backward_partial(serialized_tensors, ctx)
+                deserialized_grad_inputs = cls.backward_stream(serialized_tensors, ctx)
                 break
         else:
             deserialized_grad_inputs = cls.backward_oneshot(serialized_tensors, ctx)
@@ -266,32 +259,27 @@ class _RemoteModuleCall(torch.autograd.Function):

     @classmethod
     @once_differentiable
-    def backward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
+    def backward_stream(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
         split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))

         grad_inputs = RemoteExpertWorker.run_coroutine(
-            ctx.stub.rpc_backward_partial(
+            ctx.stub.rpc_backward_stream(
                 amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(
-                        uid=ctx.uid,
-                        tensors=[
-                            t,
-                        ],
-                    ),
+                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t]),
                     as_aiter(*split),
                 ),
             )
         )

         return RemoteExpertWorker.run_coroutine(
-            gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+            gather_from_rpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
         )

     @classmethod
     @once_differentiable
-    def backward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
+    def backward_oneshot(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
         grad_inputs = RemoteExpertWorker.run_coroutine(
-            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=list(serialized_tensors)))
         )

         return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]

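Note: serialized_tensors is now a lazy generator, consumed exactly once by whichever branch wins the size check, and the *_partial methods were renamed to *_stream. On the streaming path each serialized tensor is chunked to fit under the daemon's message limit; a sketch of how one oversized tensor travels, with t standing for a single serialized tensor:

    # Sketch: chunk one serialized tensor and wrap each chunk in its own
    # single-tensor ExpertRequest, as forward_stream/backward_stream do above.
    from hivemind.proto import runtime_pb2
    from hivemind.utils.grpc import split_for_streaming

    parts = split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)
    requests = (runtime_pb2.ExpertRequest(uid=uid, tensors=[p]) for p in parts)
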
+ 1 - 1
hivemind/moe/client/moe.py

@@ -230,7 +230,7 @@ class _RemoteCallMany(torch.autograd.Function):
                     serialize_torch_tensor(tensor, proto.compression)
                     for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
                 ]
-                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
+                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.p2p, expert.server_peer_info)
                 new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
                 pending_tasks[new_task] = (i, j)


+ 6 - 14
hivemind/moe/server/connection_handler.py

@@ -13,7 +13,7 @@ from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils import MPFuture, MSGPackSerializer, as_aiter, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.grpc import gather_from_grpc, split_for_streaming
+from hivemind.utils.grpc import gather_from_rpc, split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor

 logger = get_logger(__name__)
@@ -78,7 +78,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> Tuple[str, List[torch.Tensor]]:
         unpacker = self._RequestUnpacker()
-        inputs = await gather_from_grpc(requests, unpacker, deserialize_torch_tensor)
+        inputs = await gather_from_rpc(requests, unpacker, deserialize_torch_tensor)
         return unpacker.uid, inputs

     async def _process_inputs(
@@ -99,7 +99,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             tensors=await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
         )

-    async def rpc_forward_partial(
+    async def rpc_forward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         uid, inputs = await self._gather_inputs(requests, context)
@@ -111,11 +111,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         ]

         async for part in as_aiter(*output_split):
-            yield runtime_pb2.ExpertResponse(
-                tensors=[
-                    part,
-                ],
-            )
+            yield runtime_pb2.ExpertResponse(tensors=[part])

     async def rpc_backward(
         self, request: runtime_pb2.ExpertRequest, context: P2PContext
@@ -126,7 +122,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             tensors=await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
         )

-    async def rpc_backward_partial(
+    async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         uid, inputs_and_grads = await self._gather_inputs(requests, context)
@@ -138,8 +134,4 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         ]

         async for part in as_aiter(*output_split):
-            yield runtime_pb2.ExpertResponse(
-                tensors=[
-                    part,
-                ]
-            )
+            yield runtime_pb2.ExpertResponse(tensors=[part])

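Note: the server mirrors the client's chunking: the renamed rpc_forward_stream and rpc_backward_stream re-split their serialized outputs and yield one partial tensor per response. A condensed sketch of that loop; serialized_outputs is an assumed name for the list built just above the elided hunk:

    # Sketch of the streaming response pattern consolidated by this commit.
    output_split = [
        part for tensor in serialized_outputs
        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE // 2)
    ]
    async for part in as_aiter(*output_split):
        yield runtime_pb2.ExpertResponse(tensors=[part])
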
+ 2 - 3
hivemind/moe/server/dht_handler.py

@@ -14,7 +14,7 @@ from hivemind.moe.server.expert_uid import (
     is_valid_uid,
     split_uid,
 )
-from hivemind.p2p import PeerID
+from hivemind.p2p import PeerID, PeerInfo
 from hivemind.utils import MPFuture, get_dht_time


@@ -101,6 +101,5 @@ async def _get_experts(
     for i, uid in enumerate(uids):
         elem = found[uid]
         if elem is not None and isinstance(elem.value, tuple):
-            peer_id, addrs = elem.value
-            experts[i] = RemoteExpertInfo(uid, peer_id, addrs)
+            experts[i] = RemoteExpertInfo(uid, PeerInfo.from_tuple(elem.value))
     return experts

+ 2 - 2
hivemind/moe/server/expert_uid.py

@@ -1,10 +1,10 @@
 import re
 from typing import NamedTuple, Tuple, Union

-from hivemind.utils import Endpoint
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerInfo

 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
-UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("endpoint", Endpoint)])
+UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("peer_info", PeerInfo)])
 UID_DELIMITER = "."  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
 FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
 UID_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$")  # e.g. ffn_expert.98.76.54 - prefix + some dims

+ 3 - 3
hivemind/moe/server/server.py

@@ -24,8 +24,8 @@ from hivemind.moe.server.layers import (
     schedule_name_to_scheduler,
 )
 from hivemind.moe.server.runtime import Runtime
+from hivemind.proto.p2pd_pb2 import PeerInfo
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import Endpoint
 from hivemind.utils.logging import get_logger
 from hivemind.utils.tensor_descr import BatchTensorDescriptor

@@ -302,7 +302,7 @@ class Server(threading.Thread):


 @contextmanager
-def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[Endpoint, List[Multiaddr]]:
+def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
     """A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit"""
     """A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit"""
     pipe, runners_pipe = mp.Pipe(duplex=True)
     pipe, runners_pipe = mp.Pipe(duplex=True)
     runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
     runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
@@ -334,7 +334,7 @@ def _server_runner(pipe, *args, **kwargs):

     try:
         dht_maddrs = server.dht.get_visible_maddrs()
-        pipe.send((True, (server.dht.peer_id, dht_maddrs)))
+        pipe.send((True, PeerInfo(server.dht.peer_id, dht_maddrs)))
         pipe.recv()  # wait for shutdown signal

     finally:

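Note: background_server now yields a single PeerInfo instead of an (endpoint, maddrs) pair, so callers read the peer id and multiaddresses from one object. A hypothetical caller (the kwargs here are illustrative, not taken from this commit):

    # Hypothetical usage; the context manager now yields a PeerInfo.
    with background_server(num_experts=1, device="cpu") as server_peer_info:
        print(server_peer_info.peer_id, server_peer_info.addrs)
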
+ 4 - 5
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -5,14 +5,13 @@ Author: Kevin Mai-Husan Chia
 """
 """
 
 
 import hashlib
 import hashlib
-from typing import Any, Sequence, Union
+from typing import Any, Sequence, Tuple, Union

 import base58
 import multihash
 from multiaddr import Multiaddr, protocols

 from hivemind.proto import p2pd_pb2
-from hivemind.utils import Endpoint

 # NOTE: On inlining...
 # See: https://github.com/libp2p/specs/issues/138
@@ -130,9 +129,9 @@ class PeerInfo:
         return PeerInfo(peer_id, addrs)

     @classmethod
-    def from_endpoint(cls, endpoint: Endpoint) -> "PeerInfo":
-        peer_id = PeerID.from_base58(endpoint[0])
-        addrs = [Multiaddr(a) for a in endpoint[1]]
+    def from_tuple(cls, value: Tuple[str, Sequence[str]]) -> "PeerInfo":
+        peer_id = PeerID.from_base58(value[0])
+        addrs = [Multiaddr(addr) for addr in value[1]]
         return PeerInfo(peer_id, addrs)

     def __str__(self):

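Note: from_tuple replaces from_endpoint and accepts the msgpack-friendly shape stored in the DHT: a base58 peer id plus multiaddr strings. A quick illustration with a made-up peer id:

    # Illustration only; the base58 string below is a placeholder peer id.
    from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerInfo

    info = PeerInfo.from_tuple(("QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",
                                ["/ip4/127.0.0.1/tcp/31337"]))
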
+ 0 - 1
hivemind/utils/__init__.py

@@ -4,7 +4,6 @@ from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
-from hivemind.utils.networking import *
 from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor

+ 4 - 5
hivemind/utils/grpc.py

@@ -27,7 +27,6 @@ import torch

 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
-from hivemind.utils.networking import Endpoint
 from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration, get_dht_time

 logger = get_logger(__name__)
@@ -45,7 +44,7 @@ GRPC_KEEPALIVE_OPTIONS = (


 class ChannelInfo(NamedTuple):
-    target: Endpoint
+    target: str
     aio: bool
     options: Tuple[Tuple[str, str], ...]
     credentials: Optional[grpc.ChannelCredentials]
@@ -90,7 +89,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
     @classmethod
     def get_stub(
         cls,
-        target: Endpoint,
+        target: str,
         stub_type: Type[Stub],
         *,
         aio: bool,
@@ -137,7 +136,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
     @classmethod
     def _create_channel(
         cls,
-        target: Endpoint,
+        target: str,
         aio: bool,
         extra_options: Tuple[Tuple[str, Any], ...],
         channel_credentials: Optional[grpc.ChannelCredentials],
@@ -228,7 +227,7 @@ def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.
 RpcMessage = TypeVar("RpcMessage")


-async def gather_from_grpc(
+async def gather_from_rpc(
     stream: AsyncIterator[RpcMessage],
     key: Callable[[RpcMessage], Iterable[runtime_pb2.Tensor]],
     deserializer: Callable[[runtime_pb2.Tensor], torch.Tensor],

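Note: the helper is renamed from gather_from_grpc to gather_from_rpc since it now serves libp2p streams as well as gRPC; its call shape is unchanged. A sketch of a call, with stream standing for any AsyncIterator of ExpertResponse messages:

    # Sketch: reassemble full tensors from a stream of partial-tensor messages.
    tensors = await gather_from_rpc(stream, lambda msg: msg.tensors, deserialize_torch_tensor)
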
+ 0 - 80
hivemind/utils/networking.py

@@ -1,80 +0,0 @@
-import socket
-from contextlib import closing
-from ipaddress import ip_address
-from typing import Optional, Sequence, Tuple
-
-from multiaddr import Multiaddr
-
-Hostname, Port = str, int  # flavour types
-Endpoint = (
-    Tuple[  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
-        str, Tuple[str, ...]
-    ]
-)
-LOCALHOST = "127.0.0.1"
-
-
-def get_port(endpoint: Endpoint) -> Optional[Port]:
-    """get port or None if port is undefined"""
-    # TODO: find a standard way to get port, make sure it works in malformed ports
-    try:
-        return int(endpoint[endpoint.rindex(":") + 1 :], base=10)
-    except ValueError:  # :* or not specified
-        return None
-
-
-def replace_port(endpoint: Endpoint, new_port: Port) -> Endpoint:
-    assert endpoint.endswith(":*") or get_port(endpoint) is not None, endpoint
-    return f"{endpoint[:endpoint.rindex(':')]}:{new_port}"
-
-
-def strip_port(endpoint: Endpoint) -> Hostname:
-    """Removes port from the end of endpoint. If port is not specified, does nothing"""
-    maybe_port = endpoint[endpoint.rindex(":") + 1 :]
-    return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
-
-
-def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
-    """
-    Finds a tcp port that can be occupied with a socket with *params and use *opt options.
-
-    :note: Using this function is discouraged since it often leads to a race condition
-           with the "Address is already in use" error if the code is run in parallel.
-    """
-    try:
-        with closing(socket.socket(*params)) as sock:
-            sock.bind(("", 0))
-            sock.setsockopt(*opt)
-            return sock.getsockname()[1]
-    except Exception as e:
-        raise e
-
-
-def choose_ip_address(
-    maddrs: Sequence[Multiaddr], prefer_global: bool = True, protocol_priority: Sequence[str] = ("ip4", "ip6")
-) -> Hostname:
-    """
-    Currently, some components of hivemind are not converted to work over libp2p and use classical networking.
-    To allow other peers reach a server when needed, these components announce a machine's IP address.
-
-    This function automatically selects the best IP address to announce among publicly visible multiaddrs
-    of this machine identified by libp2p (typically, using the ``P2P.get_visible_maddrs()`` method),
-    so a user does not need to define this address manually (unless the user wants to).
-
-    The best IP address is chosen using the following logic:
-      - Prefer IP addresses from global address blocks
-        (in terms of https://docs.python.org/3/library/ipaddress.html#ipaddress.IPv4Address.is_global)
-      - Among the IP addresses of the same globality status, prefer IPv4 addresses over IPv6
-
-    If the default logic does not suit you, it is recommended to set the announced IP address manually.
-    """
-
-    for need_global in [prefer_global, not prefer_global]:
-        for protocol in protocol_priority:
-            for addr in maddrs:
-                if protocol in addr.protocols():
-                    value_for_protocol = addr[protocol]
-                    if ip_address(value_for_protocol).is_global == need_global:
-                        return value_for_protocol
-
-    raise ValueError(f"No IP address found among given multiaddrs: {maddrs}")

+ 5 - 6
tests/test_dht_experts.py

@@ -6,11 +6,11 @@ import numpy as np
 import pytest

 import hivemind
-from hivemind import LOCALHOST
 from hivemind.dht import DHTNode
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.server import declare_experts, get_experts
 from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_prefix, is_valid_uid, split_uid
+from hivemind.p2p import PeerInfo


 @pytest.mark.forked
@@ -106,12 +106,11 @@ def test_dht_single_node():
     successors = beam_search.get_active_successors(["e.1.2.", "e.2.", "e.4.5."])
     assert len(successors["e.1.2."]) == 2

-    addrs = tuple(str(a.decapsulate("/p2p/" + a.get("p2p"))) for a in node.get_visible_maddrs())
-    endpoint = (node.peer_id.to_base58(), addrs)
+    peer_info = PeerInfo(node.peer_id, [a.decapsulate("/p2p/" + a.get("p2p")) for a in node.get_visible_maddrs()])
-    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", endpoint)
-    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", endpoint)
-    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", endpoint)
+    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", peer_info)
+    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", peer_info)
+    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", peer_info)
     assert successors["e.4.5."] == {}
     assert successors["e.4.5."] == {}
 
 
     initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)
     initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)