Browse Source

add p2p from dht into RemoteExpert creation on client side

Pavel Samygin 3 years ago
parent
commit
c5160465b5

BIN
hivemind/hivemind_cli/p2pd_old


BIN
hivemind/hivemind_cli/p2pd_old2


+ 9 - 5
hivemind/moe/client/beam_search.py

@@ -5,7 +5,7 @@ from functools import partial
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode
 from hivemind.dht import DHT, DHTExpiration, DHTNode
-from hivemind.moe.client.expert import RemoteExpert
+from hivemind.moe.client.expert import RemoteExpert, _RemoteModuleCall
 from hivemind.moe.server.expert_uid import (
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     FLAT_EXPERT,
     PREFIX_PATTERN,
     PREFIX_PATTERN,
@@ -258,12 +258,14 @@ class MoEBeamSearcher:
             ),
             ),
             return_future,
             return_future,
         )
         )
+
+        p2p = _RemoteModuleCall.run_coroutine(self.dht.replicate_p2p())
         if return_future:
         if return_future:
             return LazyFutureCaller(
             return LazyFutureCaller(
                 result,
                 result,
-                lambda lst: [l.get() for l in lst]
+                lambda lst: [l.get(p2p=p2p) for l in lst]
             )
             )
-        return [r.get() for r in result]
+        return [r.get(p2p=p2p) for r in result]
 
 
     @classmethod
     @classmethod
     async def _find_best_experts(
     async def _find_best_experts(
@@ -390,9 +392,11 @@ class MoEBeamSearcher:
             return_future,
             return_future,
         )
         )
 
 
+        p2p = _RemoteModuleCall.run_coroutine(self.dht.replicate_p2p())
+
         if return_future:
         if return_future:
-            return LazyFutureCaller(result, lambda res: [[e.get() for e in exps] for exps in res])
-        return [[e.get() for e in exps] for exps in result]
+            return LazyFutureCaller(result, lambda res: [[e.get(p2p=p2p) for e in exps] for exps in res])
+        return [[e.get(p2p=p2p) for e in exps] for exps in result]
 
 
     @classmethod
     @classmethod
     async def _batch_find_best_experts(
     async def _batch_find_best_experts(

+ 1 - 3
hivemind/moe/client/expert.py

@@ -38,12 +38,10 @@ class RemoteExpert(nn.Module):
 
 
         if p2p is None:
         if p2p is None:
             self.p2p = _RemoteModuleCall.run_coroutine(P2P.create())
             self.p2p = _RemoteModuleCall.run_coroutine(P2P.create())
+            _RemoteModuleCall.run_coroutine(self.p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
         else:
         else:
             self.p2p = p2p
             self.p2p = p2p
 
 
-        if connect:
-            _RemoteModuleCall.run_coroutine(self.p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
-
     @property
     @property
     def stub(self) -> StubBase:
     def stub(self) -> StubBase:
         return _get_expert_stub(self.p2p, self.server_peer_info)
         return _get_expert_stub(self.p2p, self.server_peer_info)

+ 3 - 2
hivemind/moe/server/dht_handler.py

@@ -4,7 +4,7 @@ from multiaddr import Multiaddr
 from typing import Union, Dict, List, Optional, Sequence, Tuple
 from typing import Union, Dict, List, Optional, Sequence, Tuple
 
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
-from hivemind.moe.client.expert import RemoteExpert
+from hivemind.moe.client.expert import RemoteExpert, _RemoteModuleCall
 from hivemind.moe.server.expert_uid import (
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     FLAT_EXPERT,
     UID_DELIMITER,
     UID_DELIMITER,
@@ -83,10 +83,11 @@ def get_experts(
     :returns: a list of [RemoteExpert if found else None]
     :returns: a list of [RemoteExpert if found else None]
     """
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+    p2p = _RemoteModuleCall.run_coroutine(dht.replicate_p2p())
     result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
     result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
 
 
     def _unwrap_experts(vals: List[Optional[LazyValue[RemoteExpert]]]) -> List[Optional[RemoteExpert]]:
     def _unwrap_experts(vals: List[Optional[LazyValue[RemoteExpert]]]) -> List[Optional[RemoteExpert]]:
-        return [val.get() if val is not None else None for val in vals]
+        return [val.get(p2p=p2p) if val is not None else None for val in vals]
 
 
     if return_future:
     if return_future:
         return LazyFutureCaller(result, _unwrap_experts)
         return LazyFutureCaller(result, _unwrap_experts)