from __future__ import annotations

import os
from concurrent.futures import Future
from dataclasses import dataclass
from queue import Queue
from threading import Thread
from typing import Any, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

import hivemind
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.dht import DHT
from hivemind.p2p import P2P, PeerInfo, StubBase
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils import (
    MSGPackSerializer,
    amap_in_executor,
    iter_as_aiter,
    nested_compare,
    nested_flatten,
    nested_pack,
    switch_to_uvloop,
)
from hivemind.utils.mpfuture import MPFuture
from hivemind.utils.streaming import gather_from_streaming, split_for_streaming

DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert


def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHandlerStub:
    return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)


@dataclass(frozen=True)
class RemoteExpertInfo:
    uid: str
    peer_info: PeerInfo


class RemoteExpert(nn.Module):
    """
    A simple module that runs forward/backward passes of an expert hosted on a remote machine.
    Works seamlessly with pytorch autograd (this is essentially a simple RPC function).

    Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
    Sending inputs with wrong shapes can make RemoteExpert freeze indefinitely due to an error in the server runtime.

    :param expert_info: a RemoteExpertInfo with the expert's unique identifier (uid) and the hosting server's PeerInfo
    :param p2p: the P2P instance used to communicate with that server
    """

    def __init__(self, expert_info: RemoteExpertInfo, p2p: P2P):
        super().__init__()
        self._info, self.p2p = expert_info, p2p
        self._rpc_info = None

    @property
    def uid(self):
        return self._info.uid

    @property
    def server_peer_info(self):
        return self._info.peer_info

    @property
    def stub(self) -> StubBase:
        return _get_expert_stub(self.p2p, self.server_peer_info)

    def forward(self, *args, **kwargs):
        """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
        assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
        kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
        # Note: we put keyword arguments in the same order as on the server to prevent f(a=1, b=2) != f(b=2, a=1) errors

        forward_inputs = (args, kwargs)
        if not nested_compare(forward_inputs, self.info["forward_schema"]):
            raise TypeError("Inputs do not match expert input schema. Did you pass the right number of parameters?")

        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
        # Note: we send DUMMY to prevent torch from excluding the expert from backward if no other inputs require grad
        return nested_pack(flat_outputs, structure=self.info["outputs_schema"])

    @property
    def info(self):
        if self._rpc_info is None:
            outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
            self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
        return self._rpc_info

    def extra_repr(self):
        return f"uid={self.uid}, server_peer_info={self.server_peer_info}"


class RemoteExpertWorker:
    """Local thread for managing async tasks related to RemoteExpert"""

    _task_queue: Queue = Queue()
    _event_thread: Optional[Thread] = None
    _pid: int = -1

    @classmethod
    def _run(cls):
        """Background thread target: run queued coroutines on a dedicated uvloop event loop."""
        loop = switch_to_uvloop()

        async def receive_tasks():
            while True:
                cor, future = cls._task_queue.get()
                try:
                    result = await cor
                except Exception as e:
                    future.set_exception(e)
                    continue
                if not future.cancelled():
                    future.set_result(result)

        loop.run_until_complete(receive_tasks())

    @classmethod
    def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
        """Schedule a coroutine on the worker's event loop; block for its result unless return_future is True."""
        if cls._event_thread is None or cls._pid != os.getpid():
            # (re)start the worker thread if it was never started or if we are in a forked child process
            cls._pid = os.getpid()
            cls._event_thread = Thread(target=cls._run, daemon=True)
            cls._event_thread.start()

        future = Future()
        cls._task_queue.put((coro, future))

        if return_future:
            return future

        result = future.result()
        return result

    @classmethod
    def _spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
        experts: List[Optional[RemoteExpert]] = []
        for i in infos:
            if i is not None:
                experts.append(RemoteExpert(i, p2p))
            else:
                experts.append(None)
        return experts

    @classmethod
    def spawn_experts(
        cls, infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
    ) -> Union[List[Optional[RemoteExpert]], Future]:
        """Create RemoteExpert instances for the given infos, replicating the DHT's P2P connection as needed."""
        if return_future:

            async def _unpack(infos_future: MPFuture, dht: DHT):
                p2p = await dht.replicate_p2p()
                return cls._spawn_experts(await infos_future, p2p)

            return cls.run_coroutine(_unpack(infos, dht), return_future)

        p2p = cls.run_coroutine(dht.replicate_p2p())
        return cls._spawn_experts(infos, p2p)

    @classmethod
    def batch_spawn_experts(
        cls,
        infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
        dht: DHT,
        return_future: bool = False,
    ) -> Union[List[List[Optional[RemoteExpert]]], Future]:
        """Like spawn_experts, but for multiple sequences of expert infos at once."""
        if return_future:

            async def _unpack(infos_future: MPFuture, dht: DHT):
                p2p = await dht.replicate_p2p()
                return [cls._spawn_experts(i, p2p) for i in await infos_future]

            return cls.run_coroutine(_unpack(infos, dht), return_future)

        return [cls.spawn_experts(exps, dht) for exps in infos]
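

# Usage sketch (illustrative only, never called by this module): RemoteExpertWorker is how synchronous code drives
# the async p2p stubs in this file. Given a running hivemind DHT and a list of RemoteExpertInfo records obtained
# elsewhere (e.g. from expert discovery), experts can be materialized like this:
def _example_spawn_experts(dht: DHT, infos: List[Optional[RemoteExpertInfo]]) -> List[Optional[RemoteExpert]]:
    # run_coroutine executes the awaitable on the worker's background event loop and blocks for the result
    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
    assert isinstance(p2p, P2P)
    # spawn_experts performs the same replicate_p2p() call internally; the explicit call above is only illustrative
    return RemoteExpertWorker.spawn_experts(infos, dht)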


async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
    """Call rpc_backward_stream, splitting serialized tensors into chunks that fit into a single message."""
    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))

    grad_inputs = await stub.rpc_backward_stream(
        amap_in_executor(
            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
            iter_as_aiter(split),
        ),
    )
    return await gather_from_streaming(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)


async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
    """Call rpc_backward with all serialized tensors in a single request."""
    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
    )
    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]


async def expert_backward(
    uid: str, inputs_and_grads: Sequence[torch.Tensor], compressions: Iterable, stub
) -> List[torch.Tensor]:
    """Serialize inputs and output gradients, then run the expert's backward pass via unary or streaming RPC."""
    serialized_tensors = (
        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs_and_grads, compressions)
    )
    size = sum(t.element_size() * t.nelement() for t in inputs_and_grads)
    if size >= DEFAULT_MAX_MSG_SIZE:
        return await _backward_stream(uid, serialized_tensors, stub)
    else:
        return await _backward_unary(uid, serialized_tensors, stub)


async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
    """Call rpc_forward_stream, splitting serialized tensors into chunks that fit into a single message."""
    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))

    outputs = await stub.rpc_forward_stream(
        amap_in_executor(
            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
            iter_as_aiter(split),
        ),
    )
    return await gather_from_streaming(outputs, lambda r: r.tensors, deserialize_torch_tensor)


async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
    """Call rpc_forward with all serialized tensors in a single request."""
    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
    )
    return [deserialize_torch_tensor(t) for t in outputs.tensors]


async def expert_forward(uid: str, inputs: Sequence[torch.Tensor], compressions: Iterable, stub) -> List[torch.Tensor]:
    """Serialize inputs, then run the expert's forward pass via unary or streaming RPC depending on payload size."""
    serialized_tensors = (
        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs, compressions)
    )
    size = sum(t.element_size() * t.nelement() for t in inputs)
    if size >= DEFAULT_MAX_MSG_SIZE:
        return await _forward_stream(uid, serialized_tensors, stub)
    else:
        return await _forward_unary(uid, serialized_tensors, stub)
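

# Sketch of the routing rule used by expert_forward / expert_backward above (illustrative helper, not used elsewhere):
# payloads whose raw tensor size reaches DEFAULT_MAX_MSG_SIZE go through the chunked *_stream RPCs, smaller payloads
# through a single unary request.
def _would_use_streaming(tensors: Sequence[torch.Tensor]) -> bool:
    """Return True if expert_forward / expert_backward would pick the streaming RPC for these tensors."""
    total_bytes = sum(t.element_size() * t.nelement() for t in tensors)
    return total_bytes >= DEFAULT_MAX_MSG_SIZE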


class _RemoteModuleCall(torch.autograd.Function):
    """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""

    @classmethod
    def forward(
        cls,
        ctx,
        dummy: torch.Tensor,
        uid: str,
        stub,  #: ConnectionHandlerStub,
        info: Dict[str, Any],
        *inputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
        # detach to avoid pickling the computation graph
        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
        ctx.uid, ctx.stub, ctx.info = uid, stub, info
        ctx.save_for_backward(*inputs)

        deserialized_outputs = RemoteExpertWorker.run_coroutine(
            expert_forward(uid, inputs, (p.compression for p in nested_flatten(info["forward_schema"])), stub)
        )

        return tuple(deserialized_outputs)

    @classmethod
    @once_differentiable
    def backward(cls, ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
        backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))

        deserialized_grad_inputs = RemoteExpertWorker.run_coroutine(
            expert_backward(ctx.uid, inputs_and_grad_outputs, (p.compression for p in backward_schema), ctx.stub)
        )
        return (DUMMY, None, None, None, *deserialized_grad_inputs)
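

# End-to-end autograd sketch (illustrative only, never called by this module): gradients flow back to the remote
# server through _RemoteModuleCall.backward. DUMMY (requires_grad=True) keeps the expert in the autograd graph even
# when no real input requires grad. The input shape is a placeholder and a single output tensor is assumed.
def _example_remote_backward(expert: RemoteExpert) -> None:
    x = torch.randn(2, 512, requires_grad=True)  # placeholder shape; must match the expert's forward_schema
    y = expert(x)  # forward RPC via expert_forward()
    y.sum().backward()  # backward RPC via expert_backward(); the gradient w.r.t. x lands in x.grad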