review issues fix

Pavel Samygin 3 years ago
parent
commit
ca892e9bc6

+ 12 - 11
benchmarks/benchmark_throughput.py

@@ -7,7 +7,7 @@ import time
 import torch
 
 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
 from hivemind.moe.server import ExpertBackend, Server, layers
 from hivemind.p2p import P2P, PeerInfo
 from hivemind.utils.limits import increase_file_limit
@@ -34,7 +34,8 @@ def print_device_info(device=None):
 def client_process(
     can_start,
     benchmarking_failed,
-    server_peer_info,
+    server_maddrs,
+    server_peer_id,
     num_experts,
     batch_size,
     hid_dim,
@@ -44,9 +45,13 @@ def client_process(
     torch.set_num_threads(1)
     can_start.wait()
 
-    p2p = RemoteExpertWorker.run_coroutine(P2P.create())
-    RemoteExpertWorker.run_coroutine(p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
-    experts = [RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info, p2p=p2p) for i in range(num_experts)]
+    p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
+    experts = [
+        RemoteExpert(
+            expert_info=RemoteExpertInfo(uid=f"expert.{i}", peer_info=PeerInfo(server_peer_id, server_maddrs)), p2p=p2p
+        )
+        for i in range(num_experts)
+    ]
 
     try:
         dummy_batch = torch.randn(batch_size, hid_dim)
@@ -86,11 +91,6 @@ def benchmark_throughput(
 
     try:
         server_dht = DHT(start=True)
-        server_dht_peer_info = PeerInfo(
-            peer_id=server_dht.peer_id,
-            addrs=[addr.decapsulate("/p2p/" + addr.get("p2p")) for addr in server_dht.get_visible_maddrs()],
-        )
-
         clients = [
             mp.Process(
                 target=client_process,
@@ -98,7 +98,8 @@ def benchmark_throughput(
                 args=(
                     can_start,
                     benchmarking_failed,
-                    server_dht_peer_info,
+                    server_dht.get_visible_maddrs(),
+                    server_dht.peer_id,
                     num_experts,
                     batch_size,
                     hid_dim,
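
A minimal sketch of the new client-side wiring this hunk introduces: the client now dials the server via P2P.create(initial_peers=...) and wraps the uid plus server coordinates in RemoteExpertInfo/PeerInfo instead of passing a bare PeerInfo. The helper name and argument values below are made up for illustration.

    from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
    from hivemind.p2p import P2P, PeerInfo

    def connect_to_experts(server_maddrs, server_peer_id, num_experts):
        # dialing the listed maddrs replaces the explicit p2p._client.connect(...) call
        p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
        server_info = PeerInfo(server_peer_id, server_maddrs)
        return [
            RemoteExpert(expert_info=RemoteExpertInfo(uid=f"expert.{i}", peer_info=server_info), p2p=p2p)
            for i in range(num_experts)
        ]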

+ 4 - 3
hivemind/moe/client/beam_search.py

@@ -17,6 +17,7 @@ from hivemind.moe.server.expert_uid import (
     UidEndpoint,
     is_valid_prefix,
 )
+from hivemind.p2p import PeerInfo
 from hivemind.utils import get_dht_time, get_logger
 from hivemind.utils.mpfuture import MPFuture
 
@@ -146,7 +147,7 @@ class MoEBeamSearcher:
                 maybe_prefix_data = await pending_task
                 if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
                     successors = {
-                        coord: UidEndpoint(*match.value)
+                        coord: UidEndpoint(match.value[0], PeerInfo.from_tuple(match.value[1]))
                         for coord, match in maybe_prefix_data.value.items()
                         if isinstance(coord, Coordinate)
                         and isinstance(getattr(match, "value", None), list)
@@ -213,7 +214,7 @@ class MoEBeamSearcher:
         for prefix, found in dht_responses.items():
             if found and isinstance(found.value, dict):
                 successors[prefix] = {
-                    coord: UidEndpoint(*match.value)
+                    coord: UidEndpoint(match.value[0], PeerInfo.from_tuple(match.value[1]))
                     for coord, match in found.value.items()
                     if isinstance(coord, Coordinate)
                     and 0 <= coord < grid_size
@@ -329,7 +330,7 @@ class MoEBeamSearcher:
                 unique_experts.add(uid_endpoint.uid)
 
         best_experts = [
-            RemoteExpertInfo(uid_endpoint.uid, *uid_endpoint.endpoint)
+            RemoteExpertInfo(uid_endpoint.uid, uid_endpoint.peer_info)
             for _, uid_endpoint in sorted(best_experts_heap, reverse=True)
         ]
         return best_experts
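
The DHT entry for each expert still stores a plain (peer_id, addrs) tuple; the searcher now rebuilds a PeerInfo from it instead of unpacking the tuple into UidEndpoint directly. A minimal sketch of the conversion, assuming the stored value keeps the [uid, (peer_id_base58, maddr_strings)] shape used elsewhere in this commit:

    from hivemind.moe.server.expert_uid import UidEndpoint
    from hivemind.p2p import PeerInfo

    def to_uid_endpoint(match_value):
        # match_value: [uid, (peer_id_base58, (maddr_str, ...))] as stored in the DHT
        uid, endpoint_tuple = match_value
        return UidEndpoint(uid, PeerInfo.from_tuple(endpoint_tuple))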

+ 37 - 49
hivemind/moe/client/expert.py

@@ -3,11 +3,10 @@ from concurrent.futures import Future
 from dataclasses import dataclass
 from queue import Queue
 from threading import Thread
-from typing import Any, Awaitable, Dict, List, Optional, Sequence, Tuple
+from typing import Any, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple
 
 import torch
 import torch.nn as nn
-from multiaddr import Multiaddr
 from torch.autograd.function import once_differentiable
 
 import hivemind
@@ -15,7 +14,6 @@ from hivemind.compression import deserialize_torch_tensor, serialize_torch_tenso
 from hivemind.dht import DHT
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
-from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
 from hivemind.proto import runtime_pb2
 from hivemind.utils import (
     MSGPackSerializer,
@@ -26,7 +24,7 @@ from hivemind.utils import (
     nested_pack,
     switch_to_uvloop,
 )
-from hivemind.utils.grpc import gather_from_grpc, split_for_streaming
+from hivemind.utils.grpc import gather_from_rpc, split_for_streaming
 from hivemind.utils.mpfuture import MPFuture
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
@@ -39,14 +37,7 @@ def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHand
 @dataclass(frozen=True)
 class RemoteExpertInfo:
     uid: str
-    peer_id: str
-    addrs: Sequence[str]
-
-    @property
-    def as_peer_info(self) -> Tuple[str, PeerInfo]:
-        return self.uid, PeerInfo(
-            peer_id=PeerID.from_base58(self.peer_id), addrs=tuple(Multiaddr(a) for a in self.addrs)
-        )
+    peer_info: PeerInfo
 
 
 class RemoteExpert(nn.Module):
@@ -58,10 +49,18 @@ class RemoteExpert(nn.Module):
     :param uid: unique expert identifier
     """
 
-    def __init__(self, uid, server_peer_info: PeerInfo, p2p: P2P):
+    def __init__(self, expert_info: RemoteExpertInfo, p2p: P2P):
         super().__init__()
-        self.uid, self.server_peer_info, self.p2p = uid, server_peer_info, p2p
-        self._info = None
+        self._info, self.p2p = expert_info, p2p
+        self._rpc_info = None
+
+    @property
+    def uid(self):
+        return self._info.uid
+
+    @property
+    def server_peer_info(self):
+        return self._info.peer_info
 
     @property
     def stub(self) -> StubBase:
@@ -86,10 +85,10 @@ class RemoteExpert(nn.Module):
 
     @property
     def info(self):
-        if self._info is None:
+        if self._rpc_info is None:
             outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
-            self._info = MSGPackSerializer.loads(outputs.serialized_info)
-        return self._info
+            self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+        return self._rpc_info
 
     def extra_repr(self):
         return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
@@ -139,8 +138,7 @@ class RemoteExpertWorker:
         experts: List[Optional[RemoteExpert]] = []
         for i in infos:
             if i is not None:
-                uid, peer_info = i.as_peer_info
-                experts.append(RemoteExpert(uid, peer_info, p2p))
+                experts.append(RemoteExpert(i, p2p))
             else:
                 experts.append(None)
         return experts
@@ -195,16 +193,16 @@ class _RemoteModuleCall(torch.autograd.Function):
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
 
-        serialized_tensors = [
+        serialized_tensors = (
             serialize_torch_tensor(inp, proto.compression)
             for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
-        ]
+        )
 
         size = 0
         for t in inputs:
             size += t.element_size() * t.nelement()
             if size >= DEFAULT_MAX_MSG_SIZE:
-                deserialized_outputs = cls.forward_partial(serialized_tensors, ctx, stub)
+                deserialized_outputs = cls.forward_stream(serialized_tensors, ctx, stub)
                 break
         else:
             deserialized_outputs = cls.forward_oneshot(serialized_tensors, ctx, stub)
@@ -212,32 +210,27 @@ class _RemoteModuleCall(torch.autograd.Function):
         return tuple(deserialized_outputs)
 
     @classmethod
-    def forward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
-        split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
+    def forward_stream(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
+        split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
 
         outputs = RemoteExpertWorker.run_coroutine(
-            stub.rpc_forward_partial(
+            stub.rpc_forward_stream(
                 amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(
-                        uid=ctx.uid,
-                        tensors=[
-                            t,
-                        ],
-                    ),
+                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t]),
                     as_aiter(*split),
                 ),
             )
         )
 
         return RemoteExpertWorker.run_coroutine(
-            gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
+            gather_from_rpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
         )
 
     @classmethod
-    def forward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
+    def forward_oneshot(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
 
         outputs = RemoteExpertWorker.run_coroutine(
-            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=list(serialized_tensors)))
         )
 
         return [deserialize_torch_tensor(t) for t in outputs.tensors]
@@ -248,16 +241,16 @@ class _RemoteModuleCall(torch.autograd.Function):
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-        serialized_tensors = [
+        serialized_tensors = (
             serialize_torch_tensor(tensor, proto.compression)
             for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        ]
+        )
 
         size = 0
         for t in inputs_and_grad_outputs:
             size += t.element_size() * t.nelement()
             if size >= DEFAULT_MAX_MSG_SIZE:
-                deserialized_grad_inputs = cls.backward_partial(serialized_tensors, ctx)
+                deserialized_grad_inputs = cls.backward_stream(serialized_tensors, ctx)
                 break
         else:
             deserialized_grad_inputs = cls.backward_oneshot(serialized_tensors, ctx)
@@ -266,32 +259,27 @@ class _RemoteModuleCall(torch.autograd.Function):
 
     @classmethod
     @once_differentiable
-    def backward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
+    def backward_stream(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
         split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
 
         grad_inputs = RemoteExpertWorker.run_coroutine(
-            ctx.stub.rpc_backward_partial(
+            ctx.stub.rpc_backward_stream(
                 amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(
-                        uid=ctx.uid,
-                        tensors=[
-                            t,
-                        ],
-                    ),
+                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t]),
                     as_aiter(*split),
                 ),
             )
         )
 
         return RemoteExpertWorker.run_coroutine(
-            gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+            gather_from_rpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
         )
 
     @classmethod
     @once_differentiable
-    def backward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
+    def backward_oneshot(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
         grad_inputs = RemoteExpertWorker.run_coroutine(
-            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=list(serialized_tensors)))
         )
 
         return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
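
The renamed forward_stream/backward_stream paths are still selected by total payload size; a condensed sketch of the dispatch (simplified from _RemoteModuleCall, with the callables standing in for the oneshot and streaming RPC helpers):

    from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE

    def dispatch_by_size(tensors, serialized_tensors, stream_fn, oneshot_fn):
        # serialized_tensors is now a lazy generator, so only the chosen path consumes it
        size = 0
        for t in tensors:
            size += t.element_size() * t.nelement()
            if size >= DEFAULT_MAX_MSG_SIZE:
                return stream_fn(serialized_tensors)  # large payload: rpc_forward_stream / rpc_backward_stream
        return oneshot_fn(serialized_tensors)  # fits in one message: rpc_forward / rpc_backward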

+ 1 - 1
hivemind/moe/client/moe.py

@@ -230,7 +230,7 @@ class _RemoteCallMany(torch.autograd.Function):
                     serialize_torch_tensor(tensor, proto.compression)
                     for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
                 ]
-                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
+                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.p2p, expert.server_peer_info)
                 new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
                 pending_tasks[new_task] = (i, j)
 

+ 6 - 14
hivemind/moe/server/connection_handler.py

@@ -13,7 +13,7 @@ from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils import MPFuture, MSGPackSerializer, as_aiter, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.grpc import gather_from_grpc, split_for_streaming
+from hivemind.utils.grpc import gather_from_rpc, split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 logger = get_logger(__name__)
@@ -78,7 +78,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> Tuple[str, List[torch.Tensor]]:
         unpacker = self._RequestUnpacker()
-        inputs = await gather_from_grpc(requests, unpacker, deserialize_torch_tensor)
+        inputs = await gather_from_rpc(requests, unpacker, deserialize_torch_tensor)
         return unpacker.uid, inputs
 
     async def _process_inputs(
@@ -99,7 +99,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             tensors=await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
         )
 
-    async def rpc_forward_partial(
+    async def rpc_forward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         uid, inputs = await self._gather_inputs(requests, context)
@@ -111,11 +111,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         ]
 
         async for part in as_aiter(*output_split):
-            yield runtime_pb2.ExpertResponse(
-                tensors=[
-                    part,
-                ],
-            )
+            yield runtime_pb2.ExpertResponse(tensors=[part])
 
     async def rpc_backward(
         self, request: runtime_pb2.ExpertRequest, context: P2PContext
@@ -126,7 +122,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             tensors=await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
         )
 
-    async def rpc_backward_partial(
+    async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         uid, inputs_and_grads = await self._gather_inputs(requests, context)
@@ -138,8 +134,4 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         ]
 
         async for part in as_aiter(*output_split):
-            yield runtime_pb2.ExpertResponse(
-                tensors=[
-                    part,
-                ]
-            )
+            yield runtime_pb2.ExpertResponse(tensors=[part])
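
Both renamed stream handlers follow the same pattern: split each serialized tensor into parts under the message-size limit and yield one response per part. A rough standalone sketch of that chunking step:

    from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
    from hivemind.proto import runtime_pb2
    from hivemind.utils import as_aiter
    from hivemind.utils.grpc import split_for_streaming

    async def stream_response(serialized_outputs):
        # serialized_outputs: runtime_pb2.Tensor messages produced by the expert's pool
        parts = [p for t in serialized_outputs for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
        async for part in as_aiter(*parts):
            yield runtime_pb2.ExpertResponse(tensors=[part])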

+ 2 - 3
hivemind/moe/server/dht_handler.py

@@ -14,7 +14,7 @@ from hivemind.moe.server.expert_uid import (
     is_valid_uid,
     split_uid,
 )
-from hivemind.p2p import PeerID
+from hivemind.p2p import PeerID, PeerInfo
 from hivemind.utils import MPFuture, get_dht_time
 
 
@@ -101,6 +101,5 @@ async def _get_experts(
     for i, uid in enumerate(uids):
         elem = found[uid]
         if elem is not None and isinstance(elem.value, tuple):
-            peer_id, addrs = elem.value
-            experts[i] = RemoteExpertInfo(uid, peer_id, addrs)
+            experts[i] = RemoteExpertInfo(uid, PeerInfo.from_tuple(elem.value))
     return experts

+ 2 - 2
hivemind/moe/server/expert_uid.py

@@ -1,10 +1,10 @@
 import re
 from typing import NamedTuple, Tuple, Union
 
-from hivemind.utils import Endpoint
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerInfo
 
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
-UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("endpoint", Endpoint)])
+UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("peer_info", PeerInfo)])
 UID_DELIMITER = "."  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
 FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
 UID_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$")  # e.g. ffn_expert.98.76.54 - prefix + some dims

+ 3 - 3
hivemind/moe/server/server.py

@@ -24,8 +24,8 @@ from hivemind.moe.server.layers import (
     schedule_name_to_scheduler,
 )
 from hivemind.moe.server.runtime import Runtime
+from hivemind.proto.p2pd_pb2 import PeerInfo
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import Endpoint
 from hivemind.utils.logging import get_logger
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
@@ -302,7 +302,7 @@ class Server(threading.Thread):
 
 
 @contextmanager
-def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[Endpoint, List[Multiaddr]]:
+def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
     """A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit"""
     pipe, runners_pipe = mp.Pipe(duplex=True)
     runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
@@ -334,7 +334,7 @@ def _server_runner(pipe, *args, **kwargs):
 
     try:
         dht_maddrs = server.dht.get_visible_maddrs()
-        pipe.send((True, (server.dht.peer_id, dht_maddrs)))
+        pipe.send((True, PeerInfo(server.dht.peer_id, dht_maddrs)))
         pipe.recv()  # wait for shutdown signal
 
     finally:
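
With this change background_server yields a single PeerInfo object rather than a (peer_id, maddrs) pair. Assuming PeerInfo here resolves to the hivemind.p2p datastructure (with peer_id/addrs fields) and using placeholder server arguments, a caller would look roughly like:

    from hivemind.moe.server import background_server

    # expert names and sizes are placeholders for illustration
    with background_server(expert_uids=["expert.0"], expert_cls="ffn", hid_dim=64) as server_peer_info:
        print(server_peer_info.peer_id, server_peer_info.addrs)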

+ 4 - 5
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -5,14 +5,13 @@ Author: Kevin Mai-Husan Chia
 """
 
 import hashlib
-from typing import Any, Sequence, Union
+from typing import Any, Sequence, Tuple, Union
 
 import base58
 import multihash
 from multiaddr import Multiaddr, protocols
 
 from hivemind.proto import p2pd_pb2
-from hivemind.utils import Endpoint
 
 # NOTE: On inlining...
 # See: https://github.com/libp2p/specs/issues/138
@@ -130,9 +129,9 @@ class PeerInfo:
         return PeerInfo(peer_id, addrs)
 
     @classmethod
-    def from_endpoint(cls, endpoint: Endpoint) -> "PeerInfo":
-        peer_id = PeerID.from_base58(endpoint[0])
-        addrs = [Multiaddr(a) for a in endpoint[1]]
+    def from_tuple(cls, value: Tuple[str, Sequence[str]]) -> "PeerInfo":
+        peer_id = PeerID.from_base58(value[0])
+        addrs = [Multiaddr(addr) for addr in value[1]]
         return PeerInfo(peer_id, addrs)
 
     def __str__(self):
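
from_endpoint becomes from_tuple, taking the raw (base58 peer id, address strings) pair in the form it is stored in the DHT. A minimal usage sketch with made-up values:

    from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerInfo

    stored = ("QmTestPeer1d", ("/ip4/192.0.2.1/tcp/31337",))  # (peer_id_base58, multiaddr strings), values made up
    peer_info = PeerInfo.from_tuple(stored)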

+ 0 - 1
hivemind/utils/__init__.py

@@ -4,7 +4,6 @@ from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
-from hivemind.utils.networking import *
 from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor

+ 4 - 5
hivemind/utils/grpc.py

@@ -27,7 +27,6 @@ import torch
 
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
-from hivemind.utils.networking import Endpoint
 from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration, get_dht_time
 
 logger = get_logger(__name__)
@@ -45,7 +44,7 @@ GRPC_KEEPALIVE_OPTIONS = (
 
 
 class ChannelInfo(NamedTuple):
-    target: Endpoint
+    target: str
     aio: bool
     options: Tuple[Tuple[str, str], ...]
     credentials: Optional[grpc.ChannelCredentials]
@@ -90,7 +89,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
     @classmethod
     def get_stub(
         cls,
-        target: Endpoint,
+        target: str,
         stub_type: Type[Stub],
         *,
         aio: bool,
@@ -137,7 +136,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
     @classmethod
     def _create_channel(
         cls,
-        target: Endpoint,
+        target: str,
         aio: bool,
         extra_options: Tuple[Tuple[str, Any], ...],
         channel_credentials: Optional[grpc.ChannelCredentials],
@@ -228,7 +227,7 @@ def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.
 RpcMessage = TypeVar("RpcMessage")
 
 
-async def gather_from_grpc(
+async def gather_from_rpc(
     stream: AsyncIterator[RpcMessage],
     key: Callable[[RpcMessage], Iterable[runtime_pb2.Tensor]],
     deserializer: Callable[[runtime_pb2.Tensor], torch.Tensor],
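
gather_from_grpc is renamed to gather_from_rpc, reflecting that it now also backs the libp2p streaming path; the signature is unchanged. A rough usage sketch, where response_stream is assumed to be an async iterator of ExpertResponse messages:

    from hivemind.compression import deserialize_torch_tensor
    from hivemind.utils.grpc import gather_from_rpc

    async def collect_tensors(response_stream):
        # key picks the tensor parts out of each message; deserializer restores torch tensors
        return await gather_from_rpc(response_stream, lambda r: r.tensors, deserialize_torch_tensor)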

+ 0 - 80
hivemind/utils/networking.py

@@ -1,80 +0,0 @@
-import socket
-from contextlib import closing
-from ipaddress import ip_address
-from typing import Optional, Sequence, Tuple
-
-from multiaddr import Multiaddr
-
-Hostname, Port = str, int  # flavour types
-Endpoint = (
-    Tuple[  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
-        str, Tuple[str, ...]
-    ]
-)
-LOCALHOST = "127.0.0.1"
-
-
-def get_port(endpoint: Endpoint) -> Optional[Port]:
-    """get port or None if port is undefined"""
-    # TODO: find a standard way to get port, make sure it works in malformed ports
-    try:
-        return int(endpoint[endpoint.rindex(":") + 1 :], base=10)
-    except ValueError:  # :* or not specified
-        return None
-
-
-def replace_port(endpoint: Endpoint, new_port: Port) -> Endpoint:
-    assert endpoint.endswith(":*") or get_port(endpoint) is not None, endpoint
-    return f"{endpoint[:endpoint.rindex(':')]}:{new_port}"
-
-
-def strip_port(endpoint: Endpoint) -> Hostname:
-    """Removes port from the end of endpoint. If port is not specified, does nothing"""
-    maybe_port = endpoint[endpoint.rindex(":") + 1 :]
-    return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
-
-
-def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
-    """
-    Finds a tcp port that can be occupied with a socket with *params and use *opt options.
-
-    :note: Using this function is discouraged since it often leads to a race condition
-           with the "Address is already in use" error if the code is run in parallel.
-    """
-    try:
-        with closing(socket.socket(*params)) as sock:
-            sock.bind(("", 0))
-            sock.setsockopt(*opt)
-            return sock.getsockname()[1]
-    except Exception as e:
-        raise e
-
-
-def choose_ip_address(
-    maddrs: Sequence[Multiaddr], prefer_global: bool = True, protocol_priority: Sequence[str] = ("ip4", "ip6")
-) -> Hostname:
-    """
-    Currently, some components of hivemind are not converted to work over libp2p and use classical networking.
-    To allow other peers reach a server when needed, these components announce a machine's IP address.
-
-    This function automatically selects the best IP address to announce among publicly visible multiaddrs
-    of this machine identified by libp2p (typically, using the ``P2P.get_visible_maddrs()`` method),
-    so a user does not need to define this address manually (unless the user wants to).
-
-    The best IP address is chosen using the following logic:
-      - Prefer IP addresses from global address blocks
-        (in terms of https://docs.python.org/3/library/ipaddress.html#ipaddress.IPv4Address.is_global)
-      - Among the IP addresses of the same globality status, prefer IPv4 addresses over IPv6
-
-    If the default logic does not suit you, it is recommended to set the announced IP address manually.
-    """
-
-    for need_global in [prefer_global, not prefer_global]:
-        for protocol in protocol_priority:
-            for addr in maddrs:
-                if protocol in addr.protocols():
-                    value_for_protocol = addr[protocol]
-                    if ip_address(value_for_protocol).is_global == need_global:
-                        return value_for_protocol
-
-    raise ValueError(f"No IP address found among given multiaddrs: {maddrs}")

+ 5 - 6
tests/test_dht_experts.py

@@ -6,11 +6,11 @@ import numpy as np
 import pytest
 
 import hivemind
-from hivemind import LOCALHOST
 from hivemind.dht import DHTNode
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.server import declare_experts, get_experts
 from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_prefix, is_valid_uid, split_uid
+from hivemind.p2p import PeerInfo
 
 
 @pytest.mark.forked
@@ -106,12 +106,11 @@ def test_dht_single_node():
     successors = beam_search.get_active_successors(["e.1.2.", "e.2.", "e.4.5."])
     assert len(successors["e.1.2."]) == 2
 
-    addrs = tuple(str(a.decapsulate("/p2p/" + a.get("p2p"))) for a in node.get_visible_maddrs())
-    endpoint = (node.peer_id.to_base58(), addrs)
+    peer_info = PeerInfo(node.peer_id, [a.decapsulate("/p2p/" + a.get("p2p")) for a in node.get_visible_maddrs()])
 
-    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", endpoint)
-    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", endpoint)
-    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", endpoint)
+    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", peer_info)
+    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", peer_info)
+    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", peer_info)
     assert successors["e.4.5."] == {}
 
     initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)