瀏覽代碼

add p2p from dht into RemoteExpert creation on client side

Pavel Samygin 3 年之前
父節點
當前提交
c5160465b5

二進制
hivemind/hivemind_cli/p2pd_old


二進制
hivemind/hivemind_cli/p2pd_old2


+ 9 - 5
hivemind/moe/client/beam_search.py

@@ -5,7 +5,7 @@ from functools import partial
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode
-from hivemind.moe.client.expert import RemoteExpert
+from hivemind.moe.client.expert import RemoteExpert, _RemoteModuleCall
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     PREFIX_PATTERN,
@@ -258,12 +258,14 @@ class MoEBeamSearcher:
             ),
             return_future,
         )
+
+        p2p = _RemoteModuleCall.run_coroutine(self.dht.replicate_p2p())
         if return_future:
             return LazyFutureCaller(
                 result,
-                lambda lst: [l.get() for l in lst]
+                lambda lst: [l.get(p2p=p2p) for l in lst]
             )
-        return [r.get() for r in result]
+        return [r.get(p2p=p2p) for r in result]
 
     @classmethod
     async def _find_best_experts(
@@ -390,9 +392,11 @@ class MoEBeamSearcher:
             return_future,
         )
 
+        p2p = _RemoteModuleCall.run_coroutine(self.dht.replicate_p2p())
+
         if return_future:
-            return LazyFutureCaller(result, lambda res: [[e.get() for e in exps] for exps in res])
-        return [[e.get() for e in exps] for exps in result]
+            return LazyFutureCaller(result, lambda res: [[e.get(p2p=p2p) for e in exps] for exps in res])
+        return [[e.get(p2p=p2p) for e in exps] for exps in result]
 
     @classmethod
     async def _batch_find_best_experts(

+ 1 - 3
hivemind/moe/client/expert.py

@@ -38,12 +38,10 @@ class RemoteExpert(nn.Module):
 
         if p2p is None:
             self.p2p = _RemoteModuleCall.run_coroutine(P2P.create())
+            _RemoteModuleCall.run_coroutine(self.p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
         else:
             self.p2p = p2p
 
-        if connect:
-            _RemoteModuleCall.run_coroutine(self.p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
-
     @property
     def stub(self) -> StubBase:
         return _get_expert_stub(self.p2p, self.server_peer_info)

+ 3 - 2
hivemind/moe/server/dht_handler.py

@@ -4,7 +4,7 @@ from multiaddr import Multiaddr
 from typing import Union, Dict, List, Optional, Sequence, Tuple
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
-from hivemind.moe.client.expert import RemoteExpert
+from hivemind.moe.client.expert import RemoteExpert, _RemoteModuleCall
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     UID_DELIMITER,
@@ -83,10 +83,11 @@ def get_experts(
     :returns: a list of [RemoteExpert if found else None]
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+    p2p = _RemoteModuleCall.run_coroutine(dht.replicate_p2p())
     result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
 
     def _unwrap_experts(vals: List[Optional[LazyValue[RemoteExpert]]]) -> List[Optional[RemoteExpert]]:
-        return [val.get() if val is not None else None for val in vals]
+        return [val.get(p2p=p2p) if val is not None else None for val in vals]
 
     if return_future:
         return LazyFutureCaller(result, _unwrap_experts)