remote_block.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. # Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
  2. from __future__ import annotations
  3. import asyncio
  4. import random
  5. from typing import Any, AsyncIterator, Dict, Optional
  6. import torch
  7. from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
  8. from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
  9. from hivemind.moe.expert_uid import ExpertInfo
  10. from hivemind.p2p import P2P, StubBase
  11. from hivemind.proto import runtime_pb2
  12. from hivemind.utils import anext, get_logger, nested_flatten, use_hivemind_log_handler
  13. from src.data_structures import RemoteModuleInfo, RPCInfo
  14. from src.dht_utils import ModuleUID
  15. from src.server.handler import TransformerConnectionHandler
  # presumably installs hivemind's log handler on the root logger (per the "in_root_logger" mode) - verify in hivemind
  use_hivemind_log_handler("in_root_logger")
  # Name the logger after the module's dotted import path (the convention recommended by the logging docs)
  # instead of its file path: logger names are split on dots, so a path like ".../remote_block.py" yields
  # a malformed logger hierarchy.
  logger = get_logger(__name__)
  class RemoteTransformerBlock(RemoteExpert):
      """A class that interacts with a remote module on a specific server for forward/backward or inference"""

      def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
          # Pick one of the servers listed in peers_info.servers uniformly at random.
          # NOTE(review): random.choice raises IndexError when peers_info.servers is empty.
          peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers.keys())))  # TODO replace this
          super().__init__(peer_info, p2p)

      @property
      def stub(self) -> StubBase:
          """RPC stub for TransformerConnectionHandler on the server identified by ``self.peer_id``."""
          return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)

      def forward(self, inputs: torch.Tensor, **kwargs):
          """Run the remote forward pass on ``inputs``.

          Extra keyword arguments are accepted for interface compatibility, but each must be None or False;
          any other value fails the assertion below.
          """
          for k, v in kwargs.items():
              assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
          return super().forward(inputs)

      def inference_session(self) -> RemoteTransformerBlockInferenceSession:
          """Initialize a new inference session with the specified remote server"""
          return RemoteExpertWorker.run_coroutine(
              RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info)
          )

      def begin_inference_session(self):
          """Deprecated alias of ``inference_session``: logs a warning, then delegates to it."""
          logger.warning("begin_inference_session was renamed to just inference_session")
          return self.inference_session()
  class RemoteTransformerBlockInferenceSession:
      """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""

      def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
          self.uid, self.rpc_info = uid, rpc_info
          # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
          # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
          self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
          self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
          self.stepped = False  # set by _step once at least one request has been queued
          self.closed = False

      @classmethod
      async def _create(
          cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None
      ) -> RemoteTransformerBlockInferenceSession:
          """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
          inputs_queue = asyncio.Queue()
          outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
          return cls(uid, rpc_info, inputs_queue, outputs_stream)

      @staticmethod
      async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
          """Yield requests taken from ``queue``, stopping right after an empty (no uid, no tensors) request.

          Each ``queue.get()`` is bounded by ``timeout`` seconds via ``asyncio.wait_for`` (None = no limit);
          on expiry the timeout error propagates out of the generator.
          """
          while True:
              next_input_message = await asyncio.wait_for(queue.get(), timeout)
              yield next_input_message
              if not next_input_message.uid and not next_input_message.tensors:
                  break  # this message means "done sending"

      def step(self, new_hidden_states: torch.Tensor):
          """Inference step: send a chunk of input tensors and receive a chunk of outputs

          :param new_hidden_states: tensor sent to the server as the only input of this step
          :returns: the first output tensor; it is asserted to have the same shape as ``new_hidden_states``
          :raises RuntimeError: if the session has already been closed
          """
          if self.closed:
              # RuntimeError subclasses Exception, so existing ``except Exception`` handlers still catch it
              raise RuntimeError("Session is closed, cannot perform step")
          inputs = (new_hidden_states,)
          # serialize each input with the compression declared in the matching forward_schema entry
          serialized_tensors = [
              serialize_torch_tensor(tensor, proto.compression)
              for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
          ]
          outputs_serialized = RemoteExpertWorker.run_coroutine(
              self._step(runtime_pb2.ExpertRequest(uid=self.uid, tensors=serialized_tensors))
          )
          outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
          assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
          return outputs[0]

      async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
          """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
          await self._inputs_queue.put(inputs_serialized)
          self.stepped = True
          return await anext(self._outputs_stream)

      def close(self):
          """Finish a given inference session, close the underlying connection

          Safe to call repeatedly, and on a partially-initialized instance (e.g. from ``__del__`` when
          ``__init__`` did not complete).
          """
          if getattr(self, "_outputs_stream", None) is None:
              return  # already closed (or never fully opened)
          try:
              RemoteExpertWorker.run_coroutine(self._aclose_stream())
          finally:
              # mark the session closed even if shutting the stream down failed, so later close()/__del__
              # calls do not try to talk to a broken stream again
              self._outputs_stream = self._inputs_queue = None
              self.closed = True

      async def _aclose_stream(self):
          """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
          if self._outputs_stream is None:
              return  # already closed
          if self.stepped:
              await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
              # Wait for the stream's final response/end only after sending the end-of-session request above;
              # without that request no message is produced and this wait could block indefinitely.
              try:
                  await anext(self._outputs_stream)
              except StopAsyncIteration:
                  pass

      def __del__(self):
          self.close()

      def __enter__(self):
          assert not self.closed
          return self

      def __exit__(self, *exc_details):
          self.close()