remote_block.py

from concurrent.futures import Future
from functools import partial
from typing import List, Optional, Union, Sequence

import torch
from hivemind.moe.client import RemoteExpert
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertUID
from hivemind.moe.server.dht_handler import _get_experts
from hivemind.p2p import StubBase, P2P
from hivemind.proto.runtime_pb2 import ExpertInfo
from hivemind.dht import DHT
from hivemind.utils import MPFuture, DHTExpiration

from src import DistributedBloomConfig
from src.server.backend import MAX_LENGTH
from src.server.handler import TransformerConnectionHandler
class RemoteTransformerBlockSession(RemoteExpert):
    """A class that interacts with a specific remote server for forward/backward or inference"""

    def __init__(self, config: DistributedBloomConfig, info: ExpertInfo, p2p: P2P):
        super().__init__(info, p2p)
        self._config = config
        # pre-allocated buffer for inference inputs, sized for the server-side maximum sequence length
        self._inputs_cache = torch.empty(1, MAX_LENGTH, config.hidden_size, dtype=config.dtype)
        # RemoteTransformerStream is presumably defined elsewhere in the client (inference streaming session)
        self._active_stream: Optional["RemoteTransformerStream"] = None

    @property
    def stub(self) -> StubBase:
        return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
def get_remote_module(
    dht: DHT,
    uids: List[ExpertUID],
    config: DistributedBloomConfig,
    expiration_time: Optional[DHTExpiration] = None,
    return_future: bool = False,
) -> Union[List[Optional[RemoteTransformerBlockSession]], MPFuture[List[Optional[RemoteTransformerBlockSession]]]]:
    """
    :param dht: a running DHT instance used to look up the requested block uids
    :param uids: find experts with these ids from across the DHT
    :param config: model config passed to each resulting block session (determines buffer shapes and dtype)
    :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
    :returns: a list of [RemoteTransformerBlockSession if found else None]
    """
    assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
    result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
    return create_remote_module(result, dht, config, return_future)
def create_remote_module(
    infos: Union[Sequence[Optional[ExpertInfo]], MPFuture],
    dht: DHT,
    config: DistributedBloomConfig,
    return_future: bool = False,
) -> Union[List[Optional[RemoteTransformerBlockSession]], Future]:
    if return_future:

        async def _unpack(infos_future: MPFuture, dht: DHT):
            p2p = await dht.replicate_p2p()
            return _create_remote_experts(await infos_future, p2p, config)

        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
    return _create_remote_experts(infos, p2p, config)
def _create_remote_experts(
    infos: Sequence[Optional[ExpertInfo]], p2p: P2P, config: DistributedBloomConfig
) -> List[Optional[RemoteTransformerBlockSession]]:
    experts: List[Optional[RemoteTransformerBlockSession]] = []
    for info in infos:
        if info is not None:
            experts.append(RemoteTransformerBlockSession(config, info, p2p))
        else:
            experts.append(None)
    return experts
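

# Usage sketch (not part of the original module): how a client process might request
# remote block sessions. The checkpoint name, block uids and the initial peer address
# below are illustrative placeholders; this assumes a server announcing blocks under
# these uids in the same DHT and a transformers-style from_pretrained loader on
# DistributedBloomConfig.
if __name__ == "__main__":
    config = DistributedBloomConfig.from_pretrained("bigscience/bloom-6b3")  # placeholder checkpoint
    dht = DHT(
        initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/QmServerPeerIdGoesHere"],  # placeholder multiaddress
        client_mode=True,
        start=True,
    )

    # blocking lookup: blocks that were not found come back as None
    blocks = get_remote_module(dht, ["bloom.0", "bloom.1"], config)

    # background lookup: returns a future that resolves to the same list
    blocks_future = get_remote_module(dht, ["bloom.0", "bloom.1"], config, return_future=True)
    blocks = blocks_future.result()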