remote_expert.py

from typing import Tuple, Optional

import torch
import torch.nn as nn

from ..utils import nested_flatten, DUMMY, PytorchSerializer, nested_pack, nested_compare, Connection

class RemoteExpert(nn.Module):
    """
    A simple module that runs forward/backward of an expert hosted on a remote machine.
    Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)

    Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
    Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to an error in the runtime.

    :param uid: unique expert identifier
    :param host: hostname where TesseractServer operates
    :param port: port to which TesseractServer listens
    """

    def __init__(self, uid, host='127.0.0.1', port=8080):
        super().__init__()
        self.uid, self.host, self.port = uid, host, port
        self._info = None
    def forward(self, *args, **kwargs):
        """ Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd. """
        assert len(kwargs) == len(self.info['keyword_names']), f"Keyword args should be {self.info['keyword_names']}"
        kwargs = {key: kwargs[key] for key in self.info['keyword_names']}
        # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
        forward_inputs = (args, kwargs)

        if not nested_compare(forward_inputs, self.info['forward_schema']):
            raise TypeError("Inputs do not match expert input schema. Did you pass the right number of parameters?")

        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.host, self.port, *nested_flatten(forward_inputs))
        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
        return nested_pack(flat_outputs, structure=self.info['outputs_schema'])
    @property
    def info(self):
        if self._info is None:
            connection = Connection.create(self.host, self.port)
            connection.send_raw('info', PytorchSerializer.dumps(self.uid))
            self._info = PytorchSerializer.loads(connection.recv_message()[1])
        return self._info

    def extra_repr(self):
        return f"uid={self.uid}, host={self.host}, port={self.port}"
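
# Note on the wire protocol (commentary added for clarity, inferred from the calls in this file):
# every request sent through Connection.send_raw carries a short message type ('info', 'fwd_' or 'bwd_')
# and a PytorchSerializer-pickled payload; the server replies with a single (type, message) pair whose
# message body is the pickled result (expert info dict, flattened outputs, or flattened input gradients).
# The 'info' reply is a dict that contains at least 'keyword_names', 'forward_schema' and
# 'outputs_schema', as used by RemoteExpert.forward above.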

class _RemoteModuleCall(torch.autograd.Function):
    """ Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead. """

    @staticmethod
    def forward(ctx, dummy: torch.Tensor,
                uid: str, host: str, port: int, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
        inputs = tuple(map(torch.Tensor.detach, inputs))  # detach to avoid pickling the computation graph
        ctx.uid, ctx.host, ctx.port = uid, host, port
        ctx.save_for_backward(*inputs)

        connection = Connection.create(ctx.host, ctx.port)
        connection.send_raw('fwd_', PytorchSerializer.dumps((ctx.uid, inputs)))
        rtype, msg = connection.recv_message()
        assert len(msg) != 0, "ExpertBackend.forward did not respond"
        return tuple(PytorchSerializer.loads(msg))  # flattened expert outputs

    @staticmethod
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        connection = Connection.create(ctx.host, ctx.port)
        payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
        connection.send_raw('bwd_', PytorchSerializer.dumps((ctx.uid, payload)))
        rtype, msg = connection.recv_message()
        assert len(msg) != 0, "ExpertBackend.backward did not respond"
        grad_inputs = PytorchSerializer.loads(msg)
        return (DUMMY, None, None, None, *grad_inputs)
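
if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): it assumes a TesseractServer is already
    # running on 127.0.0.1:8080 and serves an expert with uid 'expert.0' that takes a single
    # [batch_size, 1024] float tensor and returns a single tensor. Adjust uid/host/port/shapes to match
    # your own server before running.
    expert = RemoteExpert(uid='expert.0', host='127.0.0.1', port=8080)
    print(expert)  # extra_repr shows uid, host and port

    x = torch.randn(4, 1024, requires_grad=True)
    y = expert(x)       # forward pass runs remotely via the 'fwd_' call
    y.sum().backward()  # backward pass runs remotely via the 'bwd_' call
    print(x.grad.shape)  # gradient w.r.t. the inputs, computed by the remote expert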