瀏覽代碼

Fix DHT handler: add addrs to the info used for RemoteExpert init, and construct RemoteExpert lazily (LazyValue) on the client side, because of a problem with serializing the socket inside RemoteExpert

Pavel Samygin 3 年之前
父節點
當前提交
2987c14fd0
共有 3 個文件被更改,包括 68 次插入11 次删除
  1. 24 11
      hivemind/moe/server/dht_handler.py
  2. 1 0
      hivemind/utils/__init__.py
  3. 43 0
      hivemind/utils/lazy_value.py

+ 24 - 11
hivemind/moe/server/dht_handler.py

@@ -1,6 +1,7 @@
 import threading
 import threading
 from functools import partial
 from functools import partial
-from typing import Dict, List, Optional, Sequence, Tuple
+from multiaddr import Multiaddr
+from typing import Union, Dict, List, Optional, Sequence, Tuple
 
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.client.expert import RemoteExpert
@@ -15,7 +16,8 @@ from hivemind.moe.server.expert_uid import (
     split_uid,
     split_uid,
 )
 )
 from hivemind.p2p import PeerID, PeerInfo
 from hivemind.p2p import PeerID, PeerInfo
-from hivemind.utils import get_dht_time
+from hivemind.utils import get_dht_time, LazyFutureCaller, LazyValue
+from hivemind.utils.mpfuture import MPFuture
 
 
 
 
 class DHTHandlerThread(threading.Thread):
 class DHTHandlerThread(threading.Thread):
@@ -48,23 +50,24 @@ def declare_experts(
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     for uid in uids:
     for uid in uids:
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
+    addrs = tuple(str(a.decapsulate("/p2p/" + a.get("p2p"))) for a in dht.get_visible_maddrs())
     return dht.run_coroutine(
     return dht.run_coroutine(
-        partial(_declare_experts, uids=list(uids), peer_id=peer_id, expiration=expiration), return_future=not wait
+        partial(_declare_experts, uids=list(uids), peer_id=peer_id, addrs=addrs, expiration=expiration), return_future=not wait
     )
     )
 
 
 
 
 async def _declare_experts(
 async def _declare_experts(
-    dht: DHT, node: DHTNode, uids: List[ExpertUID], peer_id: PeerID, expiration: DHTExpiration
+    dht: DHT, node: DHTNode, uids: List[ExpertUID], peer_id: PeerID, addrs: tuple[str], expiration: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
 ) -> Dict[ExpertUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     expiration_time = get_dht_time() + expiration
     expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     for uid in uids:
     for uid in uids:
-        data_to_store[uid, None] = peer_id.to_base58()
+        data_to_store[uid, None] = (peer_id.to_base58(), addrs)
         prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}"
         prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}"
         for i in range(prefix.count(UID_DELIMITER) - 1):
         for i in range(prefix.count(UID_DELIMITER) - 1):
             prefix, last_coord = split_uid(prefix)
             prefix, last_coord = split_uid(prefix)
-            data_to_store[prefix, last_coord] = [uid, peer_id.to_base58()]
+            data_to_store[prefix, last_coord] = [uid, (peer_id.to_base58(), addrs)]
 
 
     keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
     keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
     store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
     store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
@@ -73,7 +76,7 @@ async def _declare_experts(
 
 
 def get_experts(
 def get_experts(
     dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
     dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
-) -> List[Optional[RemoteExpert]]:
+) -> Union[List[Optional[RemoteExpert]], LazyFutureCaller[Optional[LazyValue[RemoteExpert]], Optional[RemoteExpert]]]:
     """
     """
     :param uids: find experts with these ids from across the DHT
     :param uids: find experts with these ids from across the DHT
     :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
     :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
@@ -81,12 +84,19 @@ def get_experts(
     :returns: a list of [RemoteExpert if found else None]
     :returns: a list of [RemoteExpert if found else None]
     """
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-    return dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
+    result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
+
+    def _unwrap_experts(vals: List[Optional[LazyValue[RemoteExpert]]]) -> List[Optional[RemoteExpert]]:
+        return [val.get() if val is not None else None for val in vals]
+
+    if return_future:
+        return LazyFutureCaller(result, _unwrap_experts)
+    return _unwrap_experts(result)
 
 
 
 
 async def _get_experts(
 async def _get_experts(
     dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
     dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
-) -> List[Optional[RemoteExpert]]:
+) -> List[Optional[LazyValue[RemoteExpert]]]:
     if expiration_time is None:
     if expiration_time is None:
         expiration_time = get_dht_time()
         expiration_time = get_dht_time()
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
@@ -94,6 +104,9 @@ async def _get_experts(
 
 
     experts: List[Optional[RemoteExpert]] = [None] * len(uids)
     experts: List[Optional[RemoteExpert]] = [None] * len(uids)
     for i, uid in enumerate(uids):
     for i, uid in enumerate(uids):
-        if found[uid] is not None and isinstance(found[uid].value, PeerID):
-            experts[i] = RemoteExpert(uid, PeerInfo(peer_id=found[uid].value, addrs=[]))
+        if (elem := found[uid]) is not None and \
+            isinstance(elem.value, tuple):
+            peer_id, addrs = elem.value
+            peer_info = PeerInfo(peer_id=PeerID.from_base58(peer_id), addrs=tuple(Multiaddr(a) for a in addrs))
+            experts[i] = LazyValue(init=partial(RemoteExpert, uid=uid, server_peer_info=peer_info))
     return experts
     return experts

+ 1 - 0
hivemind/utils/__init__.py

@@ -9,3 +9,4 @@ from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *
 from hivemind.utils.timed_storage import *
+from hivemind.utils.lazy_value import LazyValue, LazyFutureCaller

+ 43 - 0
hivemind/utils/lazy_value.py

@@ -0,0 +1,43 @@
+from typing import Any, Generic, TypeVar, Callable, Optional, Union
+
+from hivemind.utils.mpfuture import MPFuture
+
+T = TypeVar("T")
+
+class _Empty(Generic[T]):
+    """Singleton sentinel meaning "no value has been set"; every instantiation yields the same object."""
+
+    # the one shared instance, created on the first instantiation
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        existing = cls._instance
+        if existing is not None:
+            return existing
+        created = super().__new__(cls, *args, **kwargs)
+        cls._instance = created
+        return created
+
+
+
+class LazyValue(Generic[T]):
+    """
+    A value that is either supplied up front or built on first access.
+
+    :param value: a ready-made value; when given, ``init`` is not called by :meth:`get`
+    :param init: zero-argument callable that produces the value on the first :meth:`get`
+
+    At least one of ``value`` / ``init`` must be provided.
+    """
+
+    def __init__(self, value: T = _Empty(), init: Optional[Callable[[], T]] = None):
+        # Compare against the sentinel by identity: ``==`` / ``!=`` would dispatch to the value's own
+        # comparison operators, which may be overloaded (e.g. element-wise comparison on array-like values).
+        assert value is not _Empty() or init is not None, "One should provide either value or initializer"
+        self.value = value
+        self.init = init
+
+    def get(self) -> T:
+        """Return the stored value, calling ``init`` to create and cache it if it has not been set yet."""
+        if self.value is _Empty():
+            self.value = self.init()
+
+        return self.value
+
+RT = TypeVar("RT")
+
+class LazyFutureCaller(Generic[T, RT]):
+    """
+    Thin wrapper around a future that post-processes its result with an optional callback.
+
+    The callback is applied inside :meth:`result`, each time that method is called.
+
+    :param future: the future whose result is retrieved
+    :param callback: if given, applied to the future's result before it is returned
+    """
+
+    def __init__(self, future: MPFuture[T], callback: Optional[Callable[[T], RT]] = None):
+        self._fut = future
+        self._cb = callback
+
+    def result(self, timeout: Optional[float] = None) -> Union[T, RT]:
+        """
+        Retrieve the wrapped future's result and return it, transformed by the callback if one was given.
+
+        :param timeout: if not None, forwarded as the single positional argument of the wrapped future's ``result``;
+          assumes that method accepts a timeout the way ``concurrent.futures.Future.result`` does
+        :returns: ``callback(result)`` if a callback was given, otherwise the raw result
+        """
+        # keep the argument-free call on the default path so existing callers see unchanged behavior
+        result = self._fut.result() if timeout is None else self._fut.result(timeout)
+        if self._cb is not None:
+            return self._cb(result)
+        return result