expert.py

from typing import Tuple, Optional

import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

from ..utils import nested_flatten, DUMMY, PytorchSerializer, nested_pack, nested_compare, Connection


class RemoteExpert(nn.Module):
    """
    A simple module that runs forward/backward of an expert hosted on a remote machine.
    Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)

    Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
    Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to an error in the runtime.

    :param uid: unique expert identifier
    :param host: hostname where TesseractServer operates
    :param port: port to which TesseractServer listens
    """

    def __init__(self, uid, host='127.0.0.1', port=8080):
        super().__init__()
        self.uid, self.host, self.port = uid, host, port
        self._info = None

    def forward(self, *args, **kwargs):
        """ Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd. """
        assert len(kwargs) == len(self.info['keyword_names']), f"Keyword args should be {self.info['keyword_names']}"
        kwargs = {key: kwargs[key] for key in self.info['keyword_names']}
        # Note: we put keyword arguments in the same order as on the server to prevent f(a=1, b=2) != f(b=2, a=1) errors
        forward_inputs = (args, kwargs)
        if not nested_compare(forward_inputs, self.info['forward_schema']):
            raise TypeError("Inputs do not match expert input schema. Did you pass the right number of parameters?")
        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.host, self.port, *nested_flatten(forward_inputs))
        # Note: we send DUMMY to prevent torch from excluding the expert from backward if no other inputs require grad
        return nested_pack(flat_outputs, structure=self.info['outputs_schema'])

    @property
    def info(self):
        """ Lazily fetch the expert's metadata (forward/output schemas, keyword names) from the server and cache it. """
        if self._info is None:
            connection = Connection.create(self.host, self.port)
            connection.send_raw('info', PytorchSerializer.dumps(self.uid))
            self._info = PytorchSerializer.loads(connection.recv_message()[1])
        return self._info

    def extra_repr(self):
        return f"uid={self.uid}, host={self.host}, port={self.port}"


class _RemoteModuleCall(torch.autograd.Function):
    """ Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead. """

    @staticmethod
    def forward(ctx, dummy: torch.Tensor,
                uid: str, host: str, port: int, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # Note: *inputs are flattened input tensors that follow the expert's info['forward_schema']
        inputs = tuple(map(torch.Tensor.detach, inputs))  # detach to avoid pickling the computation graph
        ctx.uid, ctx.host, ctx.port = uid, host, port
        ctx.save_for_backward(*inputs)

        connection = Connection.create(ctx.host, ctx.port)
        connection.send_raw('fwd_', PytorchSerializer.dumps((ctx.uid, inputs)))
        rtype, msg = connection.recv_message()
        assert len(msg) != 0, "ExpertBackend.forward did not respond"
        return tuple(PytorchSerializer.loads(msg))  # flattened expert outputs

    @staticmethod
    @once_differentiable
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        connection = Connection.create(ctx.host, ctx.port)
        payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
        connection.send_raw('bwd_', PytorchSerializer.dumps((ctx.uid, payload)))
        rtype, msg = connection.recv_message()
        assert len(msg) != 0, "ExpertBackend.backward did not respond"
        grad_inputs = PytorchSerializer.loads(msg)
        # one value per forward argument: DUMMY placeholder for dummy, None for uid/host/port, then input gradients
        return (DUMMY, None, None, None, *grad_inputs)
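

# Usage sketch, assuming a TesseractServer is already running at 127.0.0.1:8080 and hosts an expert
# with uid "expert.0" that takes a single [batch_size, 1024] tensor and returns a single tensor.
# The uid, address and shapes below are illustrative assumptions; they must match the server's
# actual configuration. Run as a module within the package (the relative imports above require it).
if __name__ == "__main__":
    expert = RemoteExpert(uid="expert.0", host="127.0.0.1", port=8080)
    x = torch.randn(4, 1024, requires_grad=True)  # shape must match the expert's forward_schema
    y = expert(x)            # runs ExpertBackend.forward on the remote server
    y.sum().backward()       # runs ExpertBackend.backward remotely and fills x.grad
    print(x.grad.shape)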