expert_backend.py

from typing import Dict, Sequence, Any, Tuple, Union

import torch
from torch import nn

from .task_pool import TaskPool
from ..utils import nested_flatten, nested_pack, nested_compare, BatchTensorProto, DUMMY_BATCH_SIZE, nested_map

class ExpertBackend(nn.Module):
    """
    ExpertBackend is a wrapper around a torch module that allows it to run tasks asynchronously with TesseractRuntime.

    By default, ExpertBackend handles three types of requests:

     - forward - receive inputs and compute outputs. Concurrent requests will be batched for better GPU utilization.
     - backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and **update the expert**. Also batched.
     - get_info - return expert metadata. Not batched.

    A minimal usage sketch can be found at the bottom of this file.

    :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:

     - Experts must always receive the same set of \*args and \*\*kwargs and produce output tensors of the same type
     - All \*args, \*\*kwargs and outputs must be **tensors** whose 0-th dimension represents the batch size
     - We recommend using experts that are ~invariant to the order in which they process batches

    :param opt: torch optimizer to be applied on every backward call
    :param args_schema: description of positional arguments to expert.forward, a tuple of BatchTensorProto
    :param kwargs_schema: description of keyword arguments to expert.forward, a dict of BatchTensorProto
    :param outputs_schema: description of outputs from expert.forward, a nested structure of BatchTensorProto
    :param kwargs: extra parameters to be forwarded into TaskPool.__init__
    """

    def __init__(self, name: str, expert: nn.Module, opt: torch.optim.Optimizer, *,
                 args_schema: Tuple[BatchTensorProto, ...] = None,
                 kwargs_schema: Dict[str, BatchTensorProto] = None,
                 outputs_schema: Union[BatchTensorProto, Tuple[BatchTensorProto, ...]] = None,
                 **kwargs):
        super().__init__()
        self.expert, self.opt, self.name = expert, opt, name

        self.args_schema = args_schema = tuple(args_schema or ())
        self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
        assert args_schema or kwargs_schema, "expert must receive at least one positional or keyword input." \
                                             " Did you forget to provide args_schema/kwargs_schema?"

        if outputs_schema is None:
            # run expert once to get outputs schema
            dummy_args = tuple(sample.make_empty(DUMMY_BATCH_SIZE) for sample in args_schema)
            dummy_kwargs = {key: sample.make_empty(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
            dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
            outputs_schema = nested_map(BatchTensorProto.from_tensor, dummy_outputs)

        self.outputs_schema = outputs_schema
        self.forward_schema = (self.args_schema, self.kwargs_schema)
        self.backward_schema = (self.forward_schema, self.outputs_schema)  # original inputs and grad w.r.t. outputs
        self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
        self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
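
    # Illustration (hypothetical expert, not part of this module): for an expert invoked as
    # ``expert(x, mask=mask)`` with one positional and one keyword input, the schemas above become
    #   forward_schema  == ((x_proto,), {'mask': mask_proto})
    #   backward_schema == (((x_proto,), {'mask': mask_proto}), outputs_proto)
    # i.e. a backward request carries both the original inputs and the gradients w.r.t. outputs.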

    def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Apply forward pass to an aggregated batch of requests. Used by TesseractRuntime; do not call this manually.
        To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``.

        Subclassing:
            This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``;
            it should return outputs that follow ``nested_flatten(self.outputs_schema)``
            (a short sketch of this flat format follows the method body).

            We recommend stateless experts. If the expert must keep state (e.g. batch normalization statistics),
            make sure it does not depend on the order in which batches are processed,
            for instance by disabling ``track_running_stats``.
        """
        args, kwargs = nested_pack(inputs, structure=self.forward_schema)

        with torch.no_grad():
            outputs = self.expert(*args, **kwargs)

        # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
        return tuple(nested_flatten(outputs))
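
    # Sketch of the flat wire format used by forward (assuming ``nested_flatten``/``nested_pack`` are
    # inverse operations over nested tuples and dicts, as they are used above). For the hypothetical
    # ``expert(x, mask=mask)`` example:
    #   flat_inputs = tuple(nested_flatten(((x,), {'mask': mask})))          # -> (x, mask), client side
    #   args, kwargs = nested_pack(flat_inputs, structure=forward_schema)    # -> (x,), {'mask': mask}, backend side
    # and the flat tuple returned by forward() is packed back into outputs_schema on the client.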

    def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Apply backward pass to an aggregated batch of requests. Used by TesseractRuntime; do not call this manually.
        To submit a request for asynchronous processing, please use ``ExpertBackend.backward_pool.submit_task``.

        Subclassing:
            This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``;
            it should return gradients w.r.t. inputs that follow ``nested_flatten(self.forward_schema)``.

            TesseractRuntime doesn't guarantee that backward will be performed in the same order and for the same data
            as forward, so we recommend a stateless backward pass that re-runs the expert's forward pass inside backward
            (see the sketch after this method). The same caveat applies to any internal state or randomness the expert uses.

            Please make sure to call ``ExpertBackend.apply_gradients`` **within** this method; otherwise the expert will not train.
        """
        (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)

        with torch.enable_grad():
            args = [tensor.detach().requires_grad_(True) for tensor in args]
            kwargs = {input_key: tensor.detach().requires_grad_(True) for input_key, tensor in kwargs.items()}

            outputs = self.expert(*args, **kwargs)
            assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"

            outputs_flat = tuple(nested_flatten(outputs))
            grad_outputs_flat = tuple(map(
                lambda grad, out: grad.to(device=out.device, dtype=out.dtype, non_blocking=True),
                nested_flatten(grad_outputs), outputs_flat))
            torch.autograd.backward(outputs_flat, grad_tensors=grad_outputs_flat,
                                    create_graph=False, retain_graph=False)
            self.apply_gradients()

        return tuple(x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x)
                     for x in nested_flatten((args, kwargs)))
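
    # Sketch of the statelessness recommendation above (assumption: the expert uses standard
    # torch.nn.BatchNorm* layers; other kinds of state need their own handling). One way to make
    # such an expert order-invariant is to stop tracking running statistics before wrapping it:
    #
    #   for module in expert.modules():
    #       if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
    #           module.track_running_stats = False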

    def apply_gradients(self) -> None:
        """
        Train the expert for a single step. This method is called by ``ExpertBackend.backward`` after computing gradients.
        """
        self.opt.step()
        self.opt.zero_grad()

    def get_info(self) -> Dict[str, Any]:
        """ Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration. """
        return dict(forward_schema=self.forward_schema, outputs_schema=self.outputs_schema,
                    keyword_names=tuple(self.kwargs_schema.keys()))

    def get_pools(self) -> Sequence[TaskPool]:
        """ Return all pools that should be processed by ``TesseractRuntime``. """
        return self.forward_pool, self.backward_pool
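

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only; the expert, optimizer and hidden size below are
    # arbitrary choices for the example, not part of the library). Because the imports above are
    # relative, run this file as a module within its package rather than as a standalone script.
    # The input schema is built with BatchTensorProto.from_tensor, the same helper __init__ uses
    # to infer outputs_schema.
    hidden_dim = 512
    ffn = nn.Sequential(nn.Linear(hidden_dim, 4 * hidden_dim), nn.ReLU(), nn.Linear(4 * hidden_dim, hidden_dim))
    backend = ExpertBackend(
        name='ffn_expert', expert=ffn, opt=torch.optim.Adam(ffn.parameters()),
        args_schema=(BatchTensorProto.from_tensor(torch.randn(DUMMY_BATCH_SIZE, hidden_dim)),))

    print(backend.get_info())   # schemas and keyword names, as exposed to clients
    print(backend.get_pools())  # task pools that a TesseractRuntime would process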