expert.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import pickle
  2. from typing import Any, Dict, Optional, Tuple
  3. import torch
  4. import torch.nn as nn
  5. from torch.autograd.function import once_differentiable
  6. from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
  7. from hivemind.utils import Endpoint, nested_compare, nested_flatten, nested_pack
  8. from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
  9. from hivemind.utils.grpc import ChannelCache
  10. DUMMY = torch.empty(0, requires_grad=True) # dummy tensor that triggers autograd in RemoteExpert
  11. def _get_expert_stub(endpoint: Endpoint, *extra_options: Tuple[str, Any]):
  12. """Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache"""
  13. channel_options = (("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)) + extra_options
  14. return ChannelCache.get_stub(endpoint, runtime_grpc.ConnectionHandlerStub, aio=False, options=channel_options)
  15. class RemoteExpert(nn.Module):
  16. """
  17. A simple module that runs forward/backward of an expert hosted on a remote machine.
  18. Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)
  19. Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
  20. Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
  21. :param uid: unique expert identifier
  22. :param endpoint: network endpoint of a server that services that expert, e.g. "201.123.321.99:1337" or "[::]:8080"
  23. """
  24. def __init__(self, uid, endpoint: Endpoint):
  25. super().__init__()
  26. self.uid, self.endpoint = uid, endpoint
  27. self._info = None
  28. @property
  29. def stub(self):
  30. return _get_expert_stub(self.endpoint)
  31. def forward(self, *args, **kwargs):
  32. """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
  33. assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
  34. kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
  35. # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
  36. forward_inputs = (args, kwargs)
  37. if not nested_compare(forward_inputs, self.info["forward_schema"]):
  38. raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
  39. flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
  40. # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
  41. return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
  42. @property
  43. def info(self):
  44. if self._info is None:
  45. outputs = self.stub.info(runtime_pb2.ExpertUID(uid=self.uid))
  46. self._info = pickle.loads(outputs.serialized_info)
  47. return self._info
  48. def extra_repr(self):
  49. return f"uid={self.uid}, endpoint={self.endpoint}"
  50. class _RemoteModuleCall(torch.autograd.Function):
  51. """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
  52. @staticmethod
  53. def forward(
  54. ctx,
  55. dummy: torch.Tensor,
  56. uid: str,
  57. stub: runtime_grpc.ConnectionHandlerStub,
  58. info: Dict[str, Any],
  59. *inputs: torch.Tensor,
  60. ) -> Tuple[torch.Tensor, ...]:
  61. # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
  62. # detach to avoid pickling the computation graph
  63. inputs = tuple(tensor.cpu().detach() for tensor in inputs)
  64. ctx.uid, ctx.stub, ctx.info = uid, stub, info
  65. ctx.save_for_backward(*inputs)
  66. serialized_tensors = [
  67. serialize_torch_tensor(inp, proto.compression)
  68. for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
  69. ]
  70. outputs = stub.forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
  71. deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
  72. return tuple(deserialized_outputs)
  73. @staticmethod
  74. @once_differentiable
  75. def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
  76. grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
  77. inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
  78. backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
  79. serialized_tensors = [
  80. serialize_torch_tensor(tensor, proto.compression)
  81. for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
  82. ]
  83. grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
  84. deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
  85. return (DUMMY, None, None, None, *deserialized_grad_inputs)