# remote_block.py — client-side wrappers for a remote transformer block served over hivemind
  1. from concurrent.futures import Future
  2. from functools import partial
  3. from typing import List, Optional, Union, Sequence
  4. import torch
  5. from hivemind.moe.client import RemoteExpert
  6. from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
  7. from hivemind.moe.expert_uid import ExpertUID
  8. from hivemind.moe.server.dht_handler import _get_experts
  9. from hivemind.p2p import StubBase, P2P
  10. from hivemind.proto.runtime_pb2 import ExpertInfo
  11. from hivemind.dht import DHT
  12. from hivemind.utils import MPFuture, DHTExpiration
  13. from src.server.handler import TransformerConnectionHandler
  14. class RemoteTransformerBlock(RemoteExpert):
  15. """A class that interacts with a specific remote server for forward/backward or inference"""
  16. def __init__(self, info: ExpertInfo, p2p: P2P):
  17. super().__init__(info, p2p)
  18. # self._config = config
  19. # self._inputs_cache = torch.empty(1, MAX_LENGTH, config.hidden_size, dtype=config.dtype)
  20. # self._active_stream: Optional[RemoteTransformerStream] = None
  21. @property
  22. def stub(self) -> StubBase:
  23. return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
  24. def get_remote_module(
  25. dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
  26. ) -> Union[List[Optional[RemoteTransformerBlock]], MPFuture[List[Optional[RemoteTransformerBlock]]]]:
  27. """
  28. :param uids: find experts with these ids from across the DHT
  29. :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
  30. :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
  31. :returns: a list of [RemoteTransformerBlock if found else None]
  32. """
  33. assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
  34. result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
  35. return create_remote_module(result, dht, return_future)
  36. def create_remote_module(
  37. infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
  38. ) -> Union[List[Optional[RemoteTransformerBlock]], Future]:
  39. if return_future:
  40. async def _unpack(infos_future: MPFuture, dht: DHT):
  41. p2p = await dht.replicate_p2p()
  42. return _create_remote_experts(await infos_future, p2p)
  43. return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
  44. p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
  45. return _create_remote_experts(infos, p2p)
  46. def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteTransformerBlock]]:
  47. experts: List[Optional[RemoteTransformerBlock]] = []
  48. for info in infos:
  49. if info is not None:
  50. experts.append(RemoteTransformerBlock(info, p2p))
  51. else:
  52. experts.append(None)
  53. return experts