
fix typing for backward compatibility, fix tests for p2p

Pavel Samygin 3 years ago
parent
commit
4ff7c954f4

+ 26 - 8
hivemind/moe/client/beam_search.py

@@ -17,7 +17,8 @@ from hivemind.moe.server.expert_uid import (
     UidEndpoint,
     is_valid_prefix,
 )
-from hivemind.utils import MPFuture, get_dht_time, get_logger
+from hivemind.p2p import PeerInfo
+from hivemind.utils import get_dht_time, get_logger, LazyFutureCaller, LazyValue
 
 logger = get_logger(__name__)
 
@@ -230,7 +231,7 @@ class MoEBeamSearcher:
 
     def find_best_experts(
         self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
-    ) -> Union[List[RemoteExpert], MPFuture[RemoteExpert]]:
+    ) -> Union[List[RemoteExpert], LazyFutureCaller]:
         """
         Find and return :beam_size: active experts with highest scores, use both local cache and DHT
 
@@ -245,7 +246,7 @@ class MoEBeamSearcher:
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
         assert len(grid_scores) == len(self.grid_size) and beam_size > 0
-        return self.dht.run_coroutine(
+        result = self.dht.run_coroutine(
             partial(
                 self._find_best_experts,
                 prefix=self.uid_prefix,
@@ -257,6 +258,12 @@ class MoEBeamSearcher:
             ),
             return_future,
         )
+        if return_future:
+            return LazyFutureCaller(
+                result,
+                lambda experts: [e.get() for e in experts],
+            )
+        return [e.get() for e in result]
 
     @classmethod
     async def _find_best_experts(
@@ -269,7 +276,7 @@ class MoEBeamSearcher:
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> List[RemoteExpert]:
+    ) -> List[LazyValue[RemoteExpert]]:
         num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
@@ -322,7 +329,14 @@ class MoEBeamSearcher:
                 push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
                 unique_experts.add(uid_endpoint.uid)
 
-        best_experts = [RemoteExpert(*uid_endpoint) for score, uid_endpoint in sorted(best_experts_heap, reverse=True)]
+        best_experts = [
+            LazyValue(init=partial(
+                RemoteExpert,
+                uid=uid_endpoint.uid,
+                server_peer_info=PeerInfo.from_endpoint(uid_endpoint.endpoint),
+            ))
+            for _, uid_endpoint in sorted(best_experts_heap, reverse=True)
+        ]
         return best_experts
 
     @staticmethod
@@ -351,7 +365,7 @@ class MoEBeamSearcher:
 
     def batch_find_best_experts(
         self, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, return_future: bool = False
-    ) -> Union[List[List[RemoteExpert]], MPFuture]:
+    ) -> Union[List[List[RemoteExpert]], LazyFutureCaller]:
         """
         Find and return :beam_size: active experts with highest scores, use both local cache and DHT
 
@@ -364,7 +378,7 @@ class MoEBeamSearcher:
         :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
-        return self.dht.run_coroutine(
+        result = self.dht.run_coroutine(
             partial(
                 self._batch_find_best_experts,
                 prefix=self.uid_prefix,
@@ -376,6 +390,10 @@ class MoEBeamSearcher:
             return_future,
         )
 
+        if return_future:
+            return LazyFutureCaller(result, lambda res: [[e.get() for e in exps] for exps in res])
+        return [[e.get() for e in exps] for exps in result]
+
     @classmethod
     async def _batch_find_best_experts(
         cls,
@@ -386,7 +404,7 @@ class MoEBeamSearcher:
         beam_size: int,
         negative_caching: bool,
         num_workers: Optional[int],
-    ) -> Sequence[Sequence[RemoteExpert]]:
+    ) -> Sequence[Sequence[LazyValue[RemoteExpert]]]:
         batch_grid_scores = [
             [tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
         ]
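
The net effect of these hunks: instead of building RemoteExpert objects inside the DHT coroutine (which may run in a separate process), the search now returns picklable LazyValue wrappers, and the caller materializes experts via .get(). A minimal sketch of the pattern, assuming peer_info is a valid PeerInfo for a reachable expert server (import paths taken from this diff):

    from functools import partial

    from hivemind.moe.client.expert import RemoteExpert
    from hivemind.utils import LazyValue

    # peer_info: a PeerInfo for a running expert server (assumed available here)
    lazy_expert = LazyValue(init=partial(RemoteExpert, uid="expert.5", server_peer_info=peer_info))
    expert = lazy_expert.get()  # RemoteExpert is constructed only here, in the caller's process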

+ 9 - 10
hivemind/moe/client/expert.py

@@ -1,8 +1,7 @@
 from concurrent.futures import Future
 from queue import Queue
 from threading import Thread
-from typing import Any, Awaitable, Dict, Optional, Tuple
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Awaitable, Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -148,8 +147,8 @@ class _RemoteModuleCall(torch.autograd.Function):
 
     @classmethod
     def forward_partial(
-        cls, serialized_tensors: list[runtime_pb2.Tensor], ctx, stub
-    ) -> list[torch.Tensor]:
+        cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub
+    ) -> List[torch.Tensor]:
         split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
 
         outputs = cls.run_coroutine(
@@ -167,8 +166,8 @@ class _RemoteModuleCall(torch.autograd.Function):
 
     @classmethod
     def forward_oneshot(
-        cls, serialized_tensors: list[runtime_pb2.Tensor], ctx, stub
-    ) -> list[torch.Tensor]:
+        cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub
+    ) -> List[torch.Tensor]:
 
         outputs = cls.run_coroutine(
             stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
@@ -201,8 +200,8 @@ class _RemoteModuleCall(torch.autograd.Function):
     @classmethod
     @once_differentiable
     def backward_partial(
-        cls, serialized_tensors: list[runtime_pb2.Tensor], ctx
-    ) -> list[torch.Tensor]:
+        cls, serialized_tensors: List[runtime_pb2.Tensor], ctx
+    ) -> List[torch.Tensor]:
         split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
 
         grad_inputs = cls.run_coroutine(
@@ -221,8 +220,8 @@ class _RemoteModuleCall(torch.autograd.Function):
     @classmethod
     @once_differentiable
     def backward_oneshot(
-        cls, serialized_tensors: list[runtime_pb2.Tensor], ctx
-    ) -> list[torch.Tensor]:
+        cls, serialized_tensors: List[runtime_pb2.Tensor], ctx
+    ) -> List[torch.Tensor]:
         grad_inputs = cls.run_coroutine(
             ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
         )
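
The list[...] -> List[...] edits in this file (and in connection_handler.py and grpc.py below) are the "backward compatibility" part of the commit message: PEP 585 built-in generics are only subscriptable at runtime from Python 3.9, so these annotations raise TypeError at import time on Python 3.7/3.8 unless the module opts in via "from __future__ import annotations". A minimal illustration:

    from typing import List

    def ok(xs: List[int]) -> List[int]:  # works on Python 3.7+
        return xs

    # On Python < 3.9, merely defining this function raises TypeError
    # ("'type' object is not subscriptable"):
    # def broken(xs: list[int]) -> list[int]:
    #     return xs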

+ 3 - 3
hivemind/moe/server/connection_handler.py

@@ -1,6 +1,6 @@
 import asyncio
 import multiprocessing as mp
-from typing import AsyncIterator, Dict, Iterable, Union
+from typing import AsyncIterator, Dict, Iterable, List, Tuple, Union
 
 import torch
 
@@ -76,13 +76,13 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
 
     async def _gather_inputs(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
-    ) -> tuple[str, list[torch.Tensor]]:
+    ) -> Tuple[str, List[torch.Tensor]]:
         unpacker = self._RequestUnpacker()
         inputs = await gather_from_grpc(requests, unpacker, deserialize_torch_tensor)
         return unpacker.uid, inputs
 
     async def _process_inputs(
-        self, inputs: list[torch.Tensor], pool: TaskPool, schema: Union[BatchTensorDescriptor, tuple[BatchTensorDescriptor, ...]]
+        self, inputs: List[torch.Tensor], pool: TaskPool, schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]]
     ):
         return [
             serialize_torch_tensor(t, p.compression, allow_inplace=True)

+ 1 - 2
hivemind/moe/server/dht_handler.py

@@ -17,7 +17,6 @@ from hivemind.moe.server.expert_uid import (
 )
 from hivemind.p2p import PeerID, PeerInfo
 from hivemind.utils import get_dht_time, LazyFutureCaller, LazyValue
-from hivemind.utils.mpfuture import MPFuture
 
 
 class DHTHandlerThread(threading.Thread):
@@ -57,7 +56,7 @@ def declare_experts(
 
 
 async def _declare_experts(
-    dht: DHT, node: DHTNode, uids: List[ExpertUID], peer_id: PeerID, addrs: tuple[str], expiration: DHTExpiration
+    dht: DHT, node: DHTNode, uids: List[ExpertUID], peer_id: PeerID, addrs: Tuple[str, ...], expiration: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     expiration_time = get_dht_time() + expiration

+ 7 - 0
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -12,6 +12,7 @@ import multihash
 from multiaddr import Multiaddr, protocols
 
 from hivemind.proto import p2pd_pb2
+from hivemind.utils import Endpoint
 
 # NOTE: On inlining...
 # See: https://github.com/libp2p/specs/issues/138
@@ -128,6 +129,12 @@ class PeerInfo:
         addrs = [Multiaddr(addr) for addr in peer_info_pb.addrs]
         return PeerInfo(peer_id, addrs)
 
+    @classmethod
+    def from_endpoint(cls, endpoint: Endpoint) -> "PeerInfo":
+        peer_id = PeerID.from_base58(endpoint[0])
+        addrs = [Multiaddr(a) for a in endpoint[1]]
+        return PeerInfo(peer_id, addrs)
+
     def __str__(self):
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
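
A short sketch of the new helper together with the tuple-shaped Endpoint introduced in hivemind/utils/networking.py below; the peer ID and multiaddr here are illustrative placeholders:

    from hivemind.p2p import PeerID, PeerInfo

    endpoint = (
        "QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",  # base58 peer ID (example value)
        ("/ip4/127.0.0.1/tcp/1337",),                      # tuple of multiaddr strings
    )
    peer_info = PeerInfo.from_endpoint(endpoint)
    assert peer_info.peer_id == PeerID.from_base58(endpoint[0])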
 

+ 2 - 2
hivemind/utils/grpc.py

@@ -7,7 +7,7 @@ from __future__ import annotations
 import os
 import threading
 import torch
-from typing import Callable, AsyncIterator, Any, Dict, Iterable, Iterator, NamedTuple, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, AsyncIterator, Callable, Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Type, TypeVar, Union
 
 import grpc
 
@@ -217,7 +217,7 @@ async def gather_from_grpc(
     stream: AsyncIterator[RpcMessage],
     key: Callable[[RpcMessage], Iterable[runtime_pb2.Tensor]],
     deserializer: Callable[[runtime_pb2.Tensor], torch.Tensor],
-) -> list[torch.Tensor]:
+) -> List[torch.Tensor]:
     tensors = []
     parts = []
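
For context, a condensed sketch of how connection_handler.py above consumes this helper; request_stream is assumed to be an AsyncIterator of runtime_pb2.ExpertRequest messages, and the key below is simplified (the real handler passes a _RequestUnpacker so it can also capture the request uid):

    tensors = await gather_from_grpc(
        request_stream,
        key=lambda msg: msg.tensors,  # select the serialized tensor parts of each message
        deserializer=deserialize_torch_tensor,
    )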
 

+ 3 - 3
hivemind/utils/lazy_value.py

@@ -17,14 +17,14 @@ class _Empty(Generic[T]):
 
 class LazyValue(Generic[T]):
 
-    def __init__(self, value: T = _Empty(), init: Optional[Callable[[], T]] = None):
+    def __init__(self, value: T = _Empty(), init: Optional[Callable[..., T]] = None):
         assert value != _Empty() or init is not None, "One should provide either value or initializer"
         self.value = value
         self.init = init
 
-    def get(self) -> T:
+    def get(self, *args, **kwargs) -> T:
         if self.value == _Empty():
-            self.value = self.init()
+            self.value = self.init(*args, **kwargs)
 
         return self.value
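
Widening init to Callable[..., T] lets callers forward construction arguments through get(). A small sketch, relying on the cache-after-first-call behavior above (and assuming _Empty instances compare equal to each other, as the guard in get() implies):

    lazy = LazyValue(init=lambda scale=1: scale * 42)
    assert lazy.get(scale=2) == 84  # init runs once, with the caller's arguments
    assert lazy.get() == 84         # cached: init is not called again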
 

+ 2 - 2
hivemind/utils/networking.py

@@ -1,12 +1,12 @@
 import socket
 from contextlib import closing
 from ipaddress import ip_address
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Tuple
 
 from multiaddr import Multiaddr
 
 Hostname, Port = str, int  # flavour types
-Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
+Endpoint = Tuple[str, Tuple[str, ...]]  # (base58 peer ID, tuple of multiaddr strings), e.g. ("QmPeer...", ("/ip4/1.2.3.4/tcp/1337",))
 LOCALHOST = "127.0.0.1"
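
The test changes below show how this tuple form is assembled from a live node; condensed here, with node assumed to be a started hivemind.DHT:

    addrs = tuple(str(a.decapsulate("/p2p/" + a.get("p2p"))) for a in node.get_visible_maddrs())
    endpoint = (node.peer_id.to_base58(), addrs)  # matches the new Endpoint alias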
 
 

+ 23 - 18
tests/test_dht_experts.py

@@ -25,17 +25,18 @@ def test_store_get_experts(n_peers=10):
     expert_uids = [f"my_expert.{i}" for i in range(50)]
     batch_size = 10
     for batch_start in range(0, len(expert_uids), batch_size):
-        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size], "localhost:1234")
+        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size], first_peer.peer_id)
 
     found = get_experts(other_peer, random.sample(expert_uids, 5) + ["foo", "bar"])
     assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
     assert all(res is None for res in found[-2:]), "Found non-existing experts"
 
-    other_expert, other_port = "my_other_expert.1337", random.randint(1000, 9999)
-    declare_experts(other_peer, [other_expert], f"that_host:{other_port}")
+    other_expert = "my_other_expert.1337"
+    declare_experts(other_peer, [other_expert], other_peer.peer_id)
     first_notfound, first_found = get_experts(first_peer, ["foobar", other_expert])
     assert isinstance(first_found, hivemind.RemoteExpert)
-    assert first_found.endpoint == f"that_host:{other_port}"
+    assert first_found.server_peer_info.peer_id == other_peer.peer_id
+    assert first_notfound is None
 
     # test graceful shutdown
     first_peer.shutdown()
@@ -43,8 +44,8 @@ def test_store_get_experts(n_peers=10):
     time.sleep(1.0)
     remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
     remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
-    assert all(declare_experts(remaining_peer1, ["new_expert.1"], "dummy"))
-    assert get_experts(remaining_peer2, ["new_expert.1"])[0].endpoint == "dummy"
+    assert all(declare_experts(remaining_peer1, ["new_expert.1"], remaining_peer1.peer_id))
+    assert get_experts(remaining_peer2, ["new_expert.1"])[0].server_peer_info.peer_id == remaining_peer1.peer_id
 
 
 @pytest.mark.forked
@@ -59,11 +60,11 @@ def test_beam_search(
         {"expert." + ".".join([str(random.randint(0, dim - 1)) for dim in grid_dims]) for _ in range(total_experts)}
     )
     for batch_start in range(0, len(real_experts), batch_size):
+        dht_ = random.choice(dht)
         declare_experts(
-            random.choice(dht),
+            dht_,
             real_experts[batch_start : batch_start + batch_size],
-            wait=True,
-            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}",
+            peer_id=dht_.peer_id,
         )
 
     neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(dht, min(3, len(dht)))], [])
@@ -89,22 +90,26 @@ def test_dht_single_node():
     node = hivemind.DHT(start=True)
     beam_search = MoEBeamSearcher(node, "expert.", grid_size=(10,))
 
-    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"], f"{hivemind.LOCALHOST}:1337").values())
-    assert len(declare_experts(node, ["ffn.1", "ffn.2"], endpoint="that_place")) == 4
-    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"], f"{hivemind.LOCALHOST}:42")) == 7
+    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"], node.peer_id).values())
+    assert len(declare_experts(node, ["ffn.1", "ffn.2"], node.peer_id)) == 4
+    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"], node.peer_id)) == 7
 
     for expert in get_experts(node, ["expert.3", "expert.2"]):
-        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
+        assert expert.server_peer_info.peer_id == node.peer_id
 
-    assert all(declare_experts(node, ["expert.5", "expert.2"], f"{hivemind.LOCALHOST}:1337").values())
+    assert all(declare_experts(node, ["expert.5", "expert.2"], node.peer_id).values())
     found_experts = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2)
     assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ["expert.5", "expert.3"]
 
     successors = beam_search.get_active_successors(["e.1.2.", "e.2.", "e.4.5."])
     assert len(successors["e.1.2."]) == 2
-    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", f"{LOCALHOST}:42")
-    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", f"{LOCALHOST}:42")
-    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", f"{LOCALHOST}:42")
+
+    addrs = tuple(str(a.decapsulate("/p2p/" + a.get("p2p"))) for a in node.get_visible_maddrs())
+    endpoint = (node.peer_id.to_base58(), addrs)
+
+    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", endpoint)
+    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", endpoint)
+    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", endpoint)
     assert successors["e.4.5."] == {}
 
     initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)
@@ -194,7 +199,7 @@ async def test_negative_caching(n_peers=10):
     peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]
 
     writer_peer = random.choice(peers)
-    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], "myaddr:1234").values())
+    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], writer_peer.peer_id).values())
 
     neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
     neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)