remote_forward_backward.py

  1. """
  2. Utility functions that call RPC forward or backward on a single remote server
  3. """
  4. import asyncio
  5. from typing import Iterable, List, Sequence, Tuple
  6. import torch
  7. from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
  8. from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
  9. from hivemind.p2p import StubBase
  10. from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
  11. from hivemind.proto import runtime_pb2
  12. from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
  13. from hivemind.utils.streaming import split_for_streaming
  14. from src.data_structures import ModuleUID, RPCInfo
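
# Overview: run_remote_forward() and run_remote_backward() below are the public coroutines of
# this module. Both serialize their tensors in a background executor, then dispatch to either a
# single unary RPC (for small payloads) or a streaming RPC (when the total payload exceeds
# MAX_UNARY_PAYLOAD_SIZE) via the four private helpers that follow.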


async def _forward_unary(
    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
) -> List[torch.Tensor]:
    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
        timeout=timeout,
    )
    return [deserialize_torch_tensor(t) for t in outputs.tensors]


async def _backward_unary(
    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
) -> List[torch.Tensor]:
    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
        timeout=timeout,
    )
    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
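

# The streaming variants split every serialized tensor into parts of at most DEFAULT_MAX_MSG_SIZE
# bytes, send them as a stream of ExpertRequest messages, and reassemble the resulting tensors
# from the reply stream, applying `timeout` both to opening the stream and to each received message.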


async def _forward_stream(
    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
) -> List[torch.Tensor]:
    parts = (
        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
        for tensor in serialized_tensors
        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
    )
    outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), timeout)
    outputs = aiter_with_timeout(outputs, timeout)
    return await deserialize_tensor_stream(msg.tensors async for msg in outputs)


async def _backward_stream(
    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
) -> List[torch.Tensor]:
    parts = (
        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
        for tensor in serialized_tensors
        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
    )
    grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), timeout)
    grad_inputs = aiter_with_timeout(grad_inputs, timeout)
    return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs)


async def run_remote_forward(
    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, timeout: float, **kwargs
) -> Tuple[torch.Tensor, ...]:
    """
    Serializes input tensors and calls "rpc_forward" on a remote server.

    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
    but without the RemoteExpertWorker.run_coroutine() call, which would lead to a deadlock here.
    An illustrative call is sketched in the comment after this function.
    """
    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
    # detach to avoid pickling the computation graph
    assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
    kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}

    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
    forward_inputs = (inputs, kwargs)

    # Modify forward_schema to support prompts: the single-entry args_schema is repeated once per
    # input tensor, so both the hidden states and the prompts are validated against it
    args_schema, kwargs_schema = rpc_info["forward_schema"]
    # TODO: remove this assert once an arbitrary number of input tensors is supported
    assert len(args_schema) == 1 and len(inputs) == 2
    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)

    if not nested_compare(forward_inputs, forward_schema_with_prompts):
        raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")

    forward_inputs = nested_flatten(forward_inputs)
    inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)

    # Asynchronous serialization
    loop = asyncio.get_running_loop()
    serialized_tensors = await asyncio.gather(
        *(
            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
        )
    )

    # call RPC on remote server (streaming if the payload is too large for a unary call)
    size = sum(t.element_size() * t.nelement() for t in inputs)
    if size > MAX_UNARY_PAYLOAD_SIZE:
        deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
    else:
        deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, timeout, **kwargs)

    return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
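

# Usage sketch (illustrative only, not part of this module): a caller that already has a p2p stub
# and the server's rpc_info could invoke run_remote_forward roughly as follows. The names
# `hidden_states` and `prompts` are assumptions matching the two input tensors asserted above, and
# the single-output unpacking assumes the server's outputs_schema holds one tensor:
#
#     (hidden_states,) = await run_remote_forward(
#         uid, stub, rpc_info, hidden_states, prompts, timeout=30.0
#     )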


async def run_remote_backward(
    uid: ModuleUID,
    stub: StubBase,
    rpc_info: RPCInfo,
    inputs: torch.Tensor,
    grad_outputs: List[torch.Tensor],
    *extra_tensors: torch.Tensor,
    timeout: float,
    **kwargs,
) -> Sequence[torch.Tensor]:
    """
    Serializes grad outputs and calls "rpc_backward" on a remote server.

    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
    but without the RemoteExpertWorker.run_coroutine() call, which would lead to a deadlock here.
    An illustrative call is sketched in the comment after this function.
    """
    grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))

    # Modify forward_schema to support prompts: the backward payload is (forward inputs,
    # grads w.r.t. outputs, prompts), serialized according to the corresponding schemas
    args_schema, kwargs_schema = rpc_info["forward_schema"]
    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
    # TODO: generalize this
    prompts_schema = next(iter(args_schema))
    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))

    # Asynchronous serialization
    loop = asyncio.get_running_loop()
    serialized_tensors = await asyncio.gather(
        *(
            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
        )
    )

    # call RPC on remote server (streaming if the payload is too large for a unary call)
    size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
    if size > MAX_UNARY_PAYLOAD_SIZE:
        deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
    else:
        deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
    return deserialized_grad_inputs
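

# Usage sketch (illustrative only): for the backward pass, the same caller would send the saved
# forward inputs together with the gradient w.r.t. the block's output and the prompts, and would
# receive gradients w.r.t. those tensors. The variable names below are assumptions:
#
#     grad_inputs, *grad_prompts = await run_remote_backward(
#         uid, stub, rpc_info, inputs, [grad_output], prompts, timeout=30.0
#     )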