remote_block.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from concurrent.futures import Future
  2. from functools import partial
  3. from typing import List, Optional, Union, Sequence
  4. from hivemind.moe.client import RemoteExpert
  5. from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
  6. from hivemind.moe.expert_uid import ExpertUID
  7. from hivemind.moe.server.dht_handler import _get_experts
  8. from hivemind.p2p import StubBase, P2P
  9. from hivemind.proto.runtime_pb2 import ExpertInfo
  10. from hivemind.dht import DHTExpiration, DHT
  11. from hivemind.utils import MPFuture
  12. from src.server.handler import TransformerConnectionHandler
  13. class RemoteTransformerBlock(RemoteExpert):
  14. @property
  15. def stub(self) -> StubBase:
  16. return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
  17. def get_remote_module(
  18. dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
  19. ) -> Union[List[Optional[RemoteTransformerBlock]], MPFuture[List[Optional[RemoteTransformerBlock]]]]:
  20. """
  21. :param uids: find experts with these ids from across the DHT
  22. :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
  23. :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
  24. :returns: a list of [RemoteTransformerBlock if found else None]
  25. """
  26. assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
  27. result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
  28. return create_remote_module(result, dht, return_future)
  29. def create_remote_module(
  30. infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
  31. ) -> Union[List[Optional[RemoteTransformerBlock]], Future]:
  32. if return_future:
  33. async def _unpack(infos_future: MPFuture, dht: DHT):
  34. p2p = await dht.replicate_p2p()
  35. return _create_remote_experts(await infos_future, p2p)
  36. return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
  37. p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
  38. return _create_remote_experts(infos, p2p)
  39. def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteTransformerBlock]]:
  40. experts: List[Optional[RemoteTransformerBlock]] = []
  41. for info in infos:
  42. if info is not None:
  43. experts.append(RemoteTransformerBlock(info, p2p))
  44. else:
  45. experts.append(None)
  46. return experts