|
@@ -5,7 +5,7 @@ from functools import partial
|
|
|
from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
|
|
|
|
|
|
from hivemind.dht import DHT, DHTExpiration, DHTNode
|
|
|
-from hivemind.moe.client.expert import RemoteExpert, _RemoteModuleCall
|
|
|
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
|
|
|
from hivemind.moe.server.expert_uid import (
|
|
|
FLAT_EXPERT,
|
|
|
PREFIX_PATTERN,
|
|
@@ -17,8 +17,8 @@ from hivemind.moe.server.expert_uid import (
|
|
|
UidEndpoint,
|
|
|
is_valid_prefix,
|
|
|
)
|
|
|
-from hivemind.p2p import PeerInfo
|
|
|
-from hivemind.utils import LazyFutureCaller, LazyValue, get_dht_time, get_logger
|
|
|
+from hivemind.utils import get_dht_time, get_logger
|
|
|
+from hivemind.utils.mpfuture import MPFuture
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
@@ -231,7 +231,7 @@ class MoEBeamSearcher:
|
|
|
|
|
|
def find_best_experts(
|
|
|
self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
|
|
|
- ) -> Union[List[RemoteExpert], LazyFutureCaller]:
|
|
|
+ ) -> Union[List[RemoteExpert], MPFuture[List[RemoteExpert]]]:
|
|
|
"""
|
|
|
Find and return :beam_size: active experts with highest scores, use both local cache and DHT
|
|
|
|
|
@@ -259,11 +259,10 @@ class MoEBeamSearcher:
|
|
|
return_future,
|
|
|
)
|
|
|
|
|
|
- p2p = _RemoteModuleCall.run_coroutine(self.dht.replicate_p2p())
|
|
|
if return_future:
|
|
|
- return LazyFutureCaller(result, lambda lst: [l.get(p2p=p2p) for l in lst])
|
|
|
+ return RemoteExpertWorker.spawn_experts_future(result, self.dht)
|
|
|
|
|
|
- return [r.get(p2p=p2p) for r in result]
|
|
|
+ return RemoteExpertWorker.spawn_experts(result, self.dht)
|
|
|
|
|
|
@classmethod
|
|
|
async def _find_best_experts(
|
|
@@ -276,7 +275,7 @@ class MoEBeamSearcher:
|
|
|
negative_caching: bool,
|
|
|
cache_expiration: DHTExpiration,
|
|
|
num_workers: Optional[int] = None,
|
|
|
- ) -> List[LazyValue[RemoteExpert]]:
|
|
|
+ ) -> List[RemoteExpertInfo]:
|
|
|
num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)
|
|
|
|
|
|
# form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
|
|
@@ -330,13 +329,7 @@ class MoEBeamSearcher:
|
|
|
unique_experts.add(uid_endpoint.uid)
|
|
|
|
|
|
best_experts = [
|
|
|
- LazyValue(
|
|
|
- init=partial(
|
|
|
- RemoteExpert,
|
|
|
- uid=uid_endpoint.uid,
|
|
|
- server_peer_info=PeerInfo.from_endpoint(uid_endpoint.endpoint),
|
|
|
- )
|
|
|
- )
|
|
|
+ RemoteExpertInfo(uid_endpoint.uid, *uid_endpoint.endpoint)
|
|
|
for _, uid_endpoint in sorted(best_experts_heap, reverse=True)
|
|
|
]
|
|
|
return best_experts
|
|
@@ -367,7 +360,7 @@ class MoEBeamSearcher:
|
|
|
|
|
|
def batch_find_best_experts(
|
|
|
self, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, return_future: bool = False
|
|
|
- ) -> Union[List[List[RemoteExpert]], LazyFutureCaller]:
|
|
|
+ ) -> Union[List[List[RemoteExpert]], MPFuture[List[List[RemoteExpert]]]]:
|
|
|
"""
|
|
|
Find and return :beam_size: active experts with highest scores, use both local cache and DHT
|
|
|
|
|
@@ -392,11 +385,10 @@ class MoEBeamSearcher:
|
|
|
return_future,
|
|
|
)
|
|
|
|
|
|
- p2p = _RemoteModuleCall.run_coroutine(self.dht.replicate_p2p())
|
|
|
|
|
|
if return_future:
|
|
|
- return LazyFutureCaller(result, lambda res: [[e.get(p2p=p2p) for e in exps] for exps in res])
|
|
|
- return [[e.get(p2p=p2p) for e in exps] for exps in result]
|
|
|
+ return RemoteExpertWorker.spawn_experts_bulk_future(result, self.dht)
|
|
|
+ return RemoteExpertWorker.spawn_experts_bulk(result, self.dht)
|
|
|
|
|
|
@classmethod
|
|
|
async def _batch_find_best_experts(
|
|
@@ -408,7 +400,7 @@ class MoEBeamSearcher:
|
|
|
beam_size: int,
|
|
|
negative_caching: bool,
|
|
|
num_workers: Optional[int],
|
|
|
- ) -> Sequence[Sequence[LazyValue[RemoteExpert]]]:
|
|
|
+ ) -> Sequence[Sequence[RemoteExpertInfo]]:
|
|
|
batch_grid_scores = [
|
|
|
[tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
|
|
|
]
|