
fix beam search

Pavel Samygin, 3 years ago
parent
commit
a9f99c20da
2 changed files with 44 additions and 32 deletions
  1. hivemind/moe/client/beam_search.py (+12 −20)
  2. hivemind/moe/client/expert.py (+32 −12)

+ 12 - 20
hivemind/moe/client/beam_search.py

@@ -5,7 +5,7 @@ from functools import partial
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union

 from hivemind.dht import DHT, DHTExpiration, DHTNode
-from hivemind.moe.client.expert import RemoteExpert, _RemoteModuleCall
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     PREFIX_PATTERN,
@@ -17,8 +17,8 @@ from hivemind.moe.server.expert_uid import (
     UidEndpoint,
     is_valid_prefix,
 )
-from hivemind.p2p import PeerInfo
-from hivemind.utils import LazyFutureCaller, LazyValue, get_dht_time, get_logger
+from hivemind.utils import get_dht_time, get_logger
+from hivemind.utils.mpfuture import MPFuture

 logger = get_logger(__name__)

@@ -231,7 +231,7 @@ class MoEBeamSearcher:

     def find_best_experts(
         self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
-    ) -> Union[List[RemoteExpert], LazyFutureCaller]:
+    ) -> Union[List[RemoteExpert], MPFuture[List[RemoteExpert]]]:
         """
         Find and return :beam_size: active experts with highest scores, use both local cache and DHT

@@ -259,11 +259,10 @@ class MoEBeamSearcher:
             return_future,
         )

-        p2p = _RemoteModuleCall.run_coroutine(self.dht.replicate_p2p())
         if return_future:
-            return LazyFutureCaller(result, lambda lst: [l.get(p2p=p2p) for l in lst])
+            return RemoteExpertWorker.spawn_experts_future(result, self.dht)

-        return [r.get(p2p=p2p) for r in result]
+        return RemoteExpertWorker.spawn_experts(result, self.dht)

     @classmethod
     async def _find_best_experts(
@@ -276,7 +275,7 @@ class MoEBeamSearcher:
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> List[LazyValue[RemoteExpert]]:
+    ) -> List[RemoteExpertInfo]:
         num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)

         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
@@ -330,13 +329,7 @@ class MoEBeamSearcher:
                 unique_experts.add(uid_endpoint.uid)

         best_experts = [
-            LazyValue(
-                init=partial(
-                    RemoteExpert,
-                    uid=uid_endpoint.uid,
-                    server_peer_info=PeerInfo.from_endpoint(uid_endpoint.endpoint),
-                )
-            )
+            RemoteExpertInfo(uid_endpoint.uid, *uid_endpoint.endpoint)
             for _, uid_endpoint in sorted(best_experts_heap, reverse=True)
         ]
         return best_experts
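
Side note on the replacement type: the diff only shows that each result is now built as RemoteExpertInfo(uid_endpoint.uid, *uid_endpoint.endpoint), i.e. a plain, picklable record instead of a LazyValue closure, so it can cross the MPFuture process boundary before any P2P handle exists. A minimal sketch of what such a record could look like (the fields after uid are assumptions, not visible in this diff):

    # Hypothetical sketch only: the real field layout is defined in
    # hivemind/moe/client/expert.py and is not shown in this diff.
    from dataclasses import dataclass
    from typing import Sequence

    @dataclass(frozen=True)
    class RemoteExpertInfo:
        uid: str                   # expert UID, e.g. "expert.3.7"
        peer_id: str               # assumed: the server's libp2p peer id
        addrs: Sequence[str] = ()  # assumed: the server's multiaddrs
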
@@ -367,7 +360,7 @@

     def batch_find_best_experts(
         self, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, return_future: bool = False
-    ) -> Union[List[List[RemoteExpert]], LazyFutureCaller]:
+    ) -> Union[List[List[RemoteExpert]], MPFuture[List[List[RemoteExpert]]]]:
         """
         Find and return :beam_size: active experts with highest scores, use both local cache and DHT

@@ -392,11 +385,10 @@ class MoEBeamSearcher:
             return_future,
         )

-        p2p = _RemoteModuleCall.run_coroutine(self.dht.replicate_p2p())

         if return_future:
-            return LazyFutureCaller(result, lambda res: [[e.get(p2p=p2p) for e in exps] for exps in res])
-        return [[e.get(p2p=p2p) for e in exps] for exps in result]
+            return RemoteExpertWorker.spawn_experts_bulk_future(result, self.dht)
+        return RemoteExpertWorker.spawn_experts_bulk(result, self.dht)

     @classmethod
     async def _batch_find_best_experts(
@@ -408,7 +400,7 @@ class MoEBeamSearcher:
         beam_size: int,
         negative_caching: bool,
         num_workers: Optional[int],
-    ) -> Sequence[Sequence[LazyValue[RemoteExpert]]]:
+    ) -> Sequence[Sequence[RemoteExpertInfo]]:
         batch_grid_scores = [
             [tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
         ]
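
Net effect on the public API of beam_search.py: both search methods now hand back RemoteExpertInfo records and let RemoteExpertWorker materialize them into RemoteExpert handles, and the return_future path yields a standard MPFuture instead of a LazyFutureCaller. A minimal usage sketch, assuming a DHT swarm that already hosts experts under the "expert." prefix (the scores and beam size are illustrative):

    # Usage sketch under the assumptions stated above; not part of this commit.
    from hivemind import DHT
    from hivemind.moe.client.beam_search import MoEBeamSearcher

    dht = DHT(start=True)  # in practice, pass initial_peers of a swarm with experts
    searcher = MoEBeamSearcher(dht, uid_prefix="expert.")

    grid_scores = [[0.1, 0.9, 0.3], [0.5, 0.4, 0.8]]  # one row of scores per grid dimension

    # Blocking: returns List[RemoteExpert], connected through the DHT's replicated P2P.
    experts = searcher.find_best_experts(grid_scores, beam_size=2)

    # Non-blocking: now an MPFuture, so .result() works across processes.
    future = searcher.find_best_experts(grid_scores, beam_size=2, return_future=True)
    experts = future.result()
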

+ 32 - 12
hivemind/moe/client/expert.py

@@ -1,7 +1,6 @@
 import os
 from concurrent.futures import Future
 from dataclasses import dataclass
-from lib2to3.pgen2.token import OP
 from queue import Queue
 from threading import Thread
 from typing import Any, Awaitable, Dict, List, Optional, Sequence, Tuple
@@ -136,17 +135,7 @@ class RemoteExpertWorker:
         return result

     @classmethod
-    def spawn_experts_future(
-        cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
-    ) -> MPFuture[List[Optional[RemoteExpert]]]:
-        async def _unpack():
-            return cls.spawn_experts(await infos, dht)
-
-        return cls.run_coroutine(_unpack, True)
-
-    @classmethod
-    def spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], dht: DHT) -> List[Optional[RemoteExpert]]:
-        p2p = cls.run_coroutine(dht.replicate_p2p())
+    def _spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
         experts: List[Optional[RemoteExpert]] = []
         for i in infos:
             if i is not None:
@@ -156,6 +145,37 @@ class RemoteExpertWorker:
                 experts.append(None)
         return experts

+    @classmethod
+    def spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], dht: DHT) -> List[Optional[RemoteExpert]]:
+        p2p = cls.run_coroutine(dht.replicate_p2p())
+        return cls._spawn_experts(infos, p2p)
+
+    @classmethod
+    def spawn_experts_future(
+        cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
+    ) -> MPFuture[List[Optional[RemoteExpert]]]:
+        async def _unpack():
+            p2p = cls.run_coroutine(dht.replicate_p2p(), True)
+            return cls._spawn_experts(await infos, await p2p)
+
+        return cls.run_coroutine(_unpack, True)
+
+    @classmethod
+    def spawn_experts_bulk(
+        cls, infos: Sequence[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
+    ) -> List[List[Optional[RemoteExpert]]]:
+        return [cls.spawn_experts(exps, dht) for exps in infos]
+
+    @classmethod
+    def spawn_experts_bulk_future(
+        cls, infos: MPFuture[Sequence[Sequence[Optional[RemoteExpertInfo]]]], dht: DHT
+    ) -> MPFuture[List[List[Optional[RemoteExpert]]]]:
+        async def _unpack():
+            return cls.spawn_experts_bulk(await infos, dht)
+
+        return cls.run_coroutine(_unpack, True)
+
+

 class _RemoteModuleCall(torch.autograd.Function):
     """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""