
fix beam search

Pavel Samygin, 3 years ago
commit a9f99c20da
2 changed files with 44 additions and 32 deletions
  1. hivemind/moe/client/beam_search.py (+12 / -20)
  2. hivemind/moe/client/expert.py (+32 / -12)

+ 12 - 20
hivemind/moe/client/beam_search.py

@@ -5,7 +5,7 @@ from functools import partial
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode
-from hivemind.moe.client.expert import RemoteExpert, _RemoteModuleCall
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     PREFIX_PATTERN,
@@ -17,8 +17,8 @@ from hivemind.moe.server.expert_uid import (
     UidEndpoint,
     is_valid_prefix,
 )
-from hivemind.p2p import PeerInfo
-from hivemind.utils import LazyFutureCaller, LazyValue, get_dht_time, get_logger
+from hivemind.utils import get_dht_time, get_logger
+from hivemind.utils.mpfuture import MPFuture
 
 logger = get_logger(__name__)
 
@@ -231,7 +231,7 @@ class MoEBeamSearcher:
 
     def find_best_experts(
         self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
-    ) -> Union[List[RemoteExpert], LazyFutureCaller]:
+    ) -> Union[List[RemoteExpert], MPFuture[List[RemoteExpert]]]:
         """
         Find and return :beam_size: active experts with highest scores, use both local cache and DHT
 
@@ -259,11 +259,10 @@ class MoEBeamSearcher:
             return_future,
         )
 
-        p2p = _RemoteModuleCall.run_coroutine(self.dht.replicate_p2p())
         if return_future:
-            return LazyFutureCaller(result, lambda lst: [l.get(p2p=p2p) for l in lst])
+            return RemoteExpertWorker.spawn_experts_future(result, self.dht)
 
-        return [r.get(p2p=p2p) for r in result]
+        return RemoteExpertWorker.spawn_experts(result, self.dht)
 
     @classmethod
     async def _find_best_experts(
@@ -276,7 +275,7 @@ class MoEBeamSearcher:
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> List[LazyValue[RemoteExpert]]:
+    ) -> List[RemoteExpertInfo]:
         num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
@@ -330,13 +329,7 @@ class MoEBeamSearcher:
                 unique_experts.add(uid_endpoint.uid)
 
         best_experts = [
-            LazyValue(
-                init=partial(
-                    RemoteExpert,
-                    uid=uid_endpoint.uid,
-                    server_peer_info=PeerInfo.from_endpoint(uid_endpoint.endpoint),
-                )
-            )
+            RemoteExpertInfo(uid_endpoint.uid, *uid_endpoint.endpoint)
             for _, uid_endpoint in sorted(best_experts_heap, reverse=True)
         ]
         return best_experts
@@ -367,7 +360,7 @@ class MoEBeamSearcher:
 
     def batch_find_best_experts(
         self, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, return_future: bool = False
-    ) -> Union[List[List[RemoteExpert]], LazyFutureCaller]:
+    ) -> Union[List[List[RemoteExpert]], MPFuture[List[List[RemoteExpert]]]]:
         """
         Find and return :beam_size: active experts with highest scores, use both local cache and DHT
 
@@ -392,11 +385,10 @@ class MoEBeamSearcher:
             return_future,
         )
 
-        p2p = _RemoteModuleCall.run_coroutine(self.dht.replicate_p2p())
 
         if return_future:
-            return LazyFutureCaller(result, lambda res: [[e.get(p2p=p2p) for e in exps] for exps in res])
-        return [[e.get(p2p=p2p) for e in exps] for exps in result]
+            return RemoteExpertWorker.spawn_experts_bulk_future(result, self.dht)
+        return RemoteExpertWorker.spawn_experts_bulk(result, self.dht)
 
     @classmethod
     async def _batch_find_best_experts(
@@ -408,7 +400,7 @@ class MoEBeamSearcher:
         beam_size: int,
         negative_caching: bool,
         num_workers: Optional[int],
-    ) -> Sequence[Sequence[LazyValue[RemoteExpert]]]:
+    ) -> Sequence[Sequence[RemoteExpertInfo]]:
         batch_grid_scores = [
             [tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
         ]
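
With this change the searcher no longer fetches a p2p handle at the call site: the internal search returns plain RemoteExpertInfo records and RemoteExpertWorker turns them into RemoteExpert objects, so return_future=True now hands back an ordinary MPFuture. A minimal usage sketch, assuming a running DHT with a populated expert grid; the constructor arguments and scores below are illustrative, not taken from this diff:

from hivemind.dht import DHT
from hivemind.moe.client.beam_search import MoEBeamSearcher

dht = DHT(start=True)
# Constructor arguments are placeholders for this sketch; check MoEBeamSearcher's signature.
searcher = MoEBeamSearcher(dht, "expert.", (3, 3))

grid_scores = [[0.1, 0.5, 0.3], [0.2, 0.4, 0.9]]  # one row of scores per grid dimension

# Synchronous call: returns List[RemoteExpert] directly.
experts = searcher.find_best_experts(grid_scores, beam_size=2)

# Future-based call: returns MPFuture[List[RemoteExpert]] that resolves to the same handles.
future = searcher.find_best_experts(grid_scores, beam_size=2, return_future=True)
experts = future.result()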

+ 32 - 12
hivemind/moe/client/expert.py

@@ -1,7 +1,6 @@
 import os
 from concurrent.futures import Future
 from dataclasses import dataclass
-from lib2to3.pgen2.token import OP
 from queue import Queue
 from threading import Thread
 from typing import Any, Awaitable, Dict, List, Optional, Sequence, Tuple
@@ -136,17 +135,7 @@ class RemoteExpertWorker:
         return result
 
     @classmethod
-    def spawn_experts_future(
-        cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
-    ) -> MPFuture[List[Optional[RemoteExpert]]]:
-        async def _unpack():
-            return cls.spawn_experts(await infos, dht)
-
-        return cls.run_coroutine(_unpack, True)
-
-    @classmethod
-    def spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], dht: DHT) -> List[Optional[RemoteExpert]]:
-        p2p = cls.run_coroutine(dht.replicate_p2p())
+    def _spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
         experts: List[Optional[RemoteExpert]] = []
         for i in infos:
             if i is not None:
@@ -156,6 +145,37 @@ class RemoteExpertWorker:
                 experts.append(None)
         return experts
 
+    @classmethod
+    def spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], dht: DHT) -> List[Optional[RemoteExpert]]:
+        p2p = cls.run_coroutine(dht.replicate_p2p())
+        return cls._spawn_experts(infos, p2p)
+
+    @classmethod
+    def spawn_experts_future(
+        cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
+    ) -> MPFuture[List[Optional[RemoteExpert]]]:
+        async def _unpack():
+            p2p = cls.run_coroutine(dht.replicate_p2p(), True)
+            return cls._spawn_experts(await infos, await p2p)
+
+        return cls.run_coroutine(_unpack(), True)
+
+    @classmethod
+    def spawn_experts_bulk(
+        cls, infos: Sequence[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
+    ) -> List[List[Optional[RemoteExpert]]]:
+        return [cls.spawn_experts(exps, dht) for exps in infos]
+
+    @classmethod
+    def spawn_experts_bulk_future(
+        cls, infos: MPFuture[Sequence[Sequence[Optional[RemoteExpertInfo]]]], dht: DHT
+    ) -> MPFuture[List[List[Optional[RemoteExpert]]]]:
+        async def _unpack():
+            return cls.spawn_experts_bulk(await infos, dht)
+
+        return cls.run_coroutine(_unpack(), True)
+
+
 
 class _RemoteModuleCall(torch.autograd.Function):
     """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""