expert.py 6.3 KB

import pickle
from concurrent.futures import Future
from queue import Queue
from threading import Thread
from typing import Any, Awaitable, Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

import hivemind
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.p2p import P2P, PeerInfo, StubBase
from hivemind.proto import runtime_pb2
from hivemind.utils import asingle, nested_compare, nested_flatten, nested_pack, switch_to_uvloop

DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert


def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHandlerStub:
    return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)


class RemoteExpert(nn.Module):
    """
    A simple module that runs forward/backward of an expert hosted on a remote machine.
    Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)

    Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
    Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to an error in the server-side runtime.

    :param uid: unique expert identifier
    """
    def __init__(self, uid, server_peer_info: PeerInfo, p2p: Optional[P2P] = None, connect: bool = True):
        super().__init__()
        self.uid, self.server_peer_info = uid, server_peer_info
        self._info = None

        if p2p is None:
            self.p2p = _RemoteModuleCall.run_coroutine(P2P.create())
        else:
            self.p2p = p2p

        if connect:
            _RemoteModuleCall.run_coroutine(self.p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))

    @property
    def stub(self) -> StubBase:
        return _get_expert_stub(self.p2p, self.server_peer_info)

    def forward(self, *args, **kwargs):
        """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
        assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
        kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
        # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
        forward_inputs = (args, kwargs)

        if not nested_compare(forward_inputs, self.info["forward_schema"]):
            raise TypeError("Inputs do not match expert input schema. Did you pass the right number of parameters?")

        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
        return nested_pack(flat_outputs, structure=self.info["outputs_schema"])

    @property
    def info(self):
        if self._info is None:
            outputs = _RemoteModuleCall.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
            self._info = pickle.loads(outputs.serialized_info)
        return self._info

    def extra_repr(self):
        return f"uid={self.uid}, server_peer_info={self.server_peer_info}"


class _RemoteModuleCall(torch.autograd.Function):
    """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""

    _task_queue: Queue = Queue()
    _event_thread: Optional[Thread] = None

    @classmethod
    def _run(cls):
        # Dedicated background event loop: consumes (coroutine, future) pairs from the task queue forever
        loop = switch_to_uvloop()

        async def receive_tasks():
            while True:
                cor, future = cls._task_queue.get()
                try:
                    result = await cor
                except Exception as e:
                    future.set_exception(e)
                    continue
                future.set_result(result)

        loop.run_until_complete(receive_tasks())

    @classmethod
    def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
        # Schedule a coroutine onto the shared background event loop (started lazily on first use);
        # block until it finishes unless return_future=True, in which case return a concurrent.futures.Future
        if cls._event_thread is None:
            cls._event_thread = Thread(target=cls._run, daemon=True)
            cls._event_thread.start()

        future = Future()
        cls._task_queue.put((coro, future))

        if return_future:
            return future

        result = future.result()
        return result
    @classmethod
    def forward(
        cls,
        ctx,
        dummy: torch.Tensor,
        uid: str,
        stub,  #: ConnectionHandlerStub,
        info: Dict[str, Any],
        *inputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
        # detach to avoid pickling the computation graph
        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
        ctx.uid, ctx.stub, ctx.info = uid, stub, info
        ctx.save_for_backward(*inputs)

        serialized_tensors = [
            serialize_torch_tensor(inp, proto.compression)
            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
        ]

        outputs = cls.run_coroutine(
            asingle(
                stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
            ),
        )

        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
        return tuple(deserialized_outputs)
    @classmethod
    @once_differentiable
    def backward(cls, ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
        backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))

        serialized_tensors = [
            serialize_torch_tensor(tensor, proto.compression)
            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
        ]

        # request gradients w.r.t. inputs from the remote expert via rpc_backward
        grad_inputs = cls.run_coroutine(
            asingle(
                ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
            ),
        )

        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
        return (DUMMY, None, None, None, *deserialized_grad_inputs)
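

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the library API): it only exercises the background
    # event-loop machinery above by creating a local P2P instance through run_coroutine.
    # Assumptions: hivemind's bundled p2p daemon is available locally, and P2P.shutdown() is the
    # usual async teardown; replace or remove this block in real deployments.
    p2p_instance = _RemoteModuleCall.run_coroutine(P2P.create())
    print("created p2p instance on the background loop:", p2p_instance)
    _RemoteModuleCall.run_coroutine(p2p_instance.shutdown())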