expert_backend.py

from typing import Dict, Sequence, Any, Tuple, Union

import torch
from torch import nn

from hivemind.server.task_pool import TaskPool
from hivemind.utils import nested_flatten, nested_pack, nested_compare, nested_map, \
    BatchTensorDescriptor, DUMMY_BATCH_SIZE


class ExpertBackend(nn.Module):
  8. """
  9. ExpertBackend is a wrapper around torch module that allows it to run tasks asynchronously with Runtime
  10. By default, ExpertBackend handles three types of requests:
  11. - forward - receive inputs and compute outputs. Concurrent requests will be batched for better GPU utilization.
  12. - backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and **update expert**. Also batched.
  13. - get_info - return expert metadata. Not batched.
  14. :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
  15. - Experts must always receive the same set of args and kwargs and produce output tensors of same type
  16. - All args, kwargs and outputs must be **tensors** where 0-th dimension represents to batch size
  17. - We recommend using experts that are ~invariant to the order in which they process batches
  18. - Using randomness (e.g. Dropout) leads to different samples at forward and backward. If you want consistency,
  19. you should explicitly register these random variables as model inputs or outputs.
  20. See hivemind.utils.custom_layers.DeterministicDropout for an example
  21. :param opt: torch optimizer to be applied on every backward call
  22. :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
  23. :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
  24. :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
  25. :param kwargs: extra parameters to be forwarded into TaskPool.__init__
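
    Example (a minimal sketch; the expert, optimizer and schemas below are illustrative, not part of this module):

    >>> expert = nn.Linear(512, 512)
    >>> backend = ExpertBackend(name='expert.0', expert=expert,
    ...                         opt=torch.optim.SGD(expert.parameters(), lr=1e-3),
    ...                         args_schema=(BatchTensorDescriptor(512),),
    ...                         outputs_schema=BatchTensorDescriptor(512))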
  26. """

    def __init__(self, name: str, expert: nn.Module, opt: torch.optim.Optimizer, *,
                 args_schema: Tuple[BatchTensorDescriptor, ...] = None,
                 kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
                 outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
                 **kwargs):
        super().__init__()
        self.expert, self.opt, self.name = expert, opt, name

        self.args_schema = args_schema = tuple(args_schema or ())
        self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
        assert args_schema or kwargs_schema, "expert must receive at least one positional or keyword input." \
                                             " Did you forget to provide args_schema/kwargs_schema?"

        if outputs_schema is None:
            # run expert once to get outputs schema
            dummy_args = tuple(sample.make_empty(DUMMY_BATCH_SIZE) for sample in args_schema)
            dummy_kwargs = {key: sample.make_empty(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
            dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
            outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)

        self.outputs_schema = outputs_schema
        self.forward_schema = (self.args_schema, self.kwargs_schema)
        self.backward_schema = (self.forward_schema, self.outputs_schema)  # original inputs and grad w.r.t. outputs
        self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
        self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)

    def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Apply forward pass to an aggregated batch of requests. Used by Runtime; do not call this manually.
        To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``.

        Subclassing:
           This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``;
           It should return outputs that follow ``nested_flatten(self.outputs_schema)``;

           .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
           .. For now, either register all buffers as outputs or avoid stateful experts
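
        Example (hypothetical; assumes a ``backend`` whose pools are being processed by a Runtime, and flat
        inputs that follow ``nested_flatten(self.forward_schema)``):

        >>> future = backend.forward_pool.submit_task(*flat_inputs)
        >>> flat_outputs = future.result()  # blocks until the Runtime processes the batch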
  58. """
  59. args, kwargs = nested_pack(inputs, structure=self.forward_schema)
  60. with torch.no_grad():
  61. outputs = self.expert(*args, **kwargs)
  62. # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
  63. return tuple(nested_flatten(outputs))

    def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Apply backward pass to an aggregated batch of requests. Used by Runtime; do not call this manually.
        To submit a request for asynchronous processing, please use ``ExpertBackend.backward_pool.submit_task``.

        Subclassing:
           This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``;
           It should return gradients w.r.t. inputs that follow ``nested_flatten(self.forward_schema)``;

           Runtime doesn't guarantee that backward will be performed in the same order and for the same data
           as forward, so we recommend a stateless backward pass that re-runs the expert's forward pass inside backward.

           .. todo correct state handling (see forward)

           Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train
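
        Example (hypothetical; ``backend``, ``args``, ``kwargs`` and ``grad_outputs`` are illustrative and must
        follow the backward schema ``((args_schema, kwargs_schema), outputs_schema)``):

        >>> flat_task = tuple(nested_flatten(((args, kwargs), grad_outputs)))
        >>> grad_inputs = backend.backward_pool.submit_task(*flat_task).result()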
  75. """
  76. (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
  77. with torch.enable_grad():
  78. args = [tensor.detach().requires_grad_(True) if tensor.dtype in (torch.half, torch.float, torch.double)
  79. else tensor.detach() for tensor in args]
  80. kwargs = {input_key: (tensor.detach().requires_grad_(True)
  81. if tensor.dtype in (torch.half, torch.float, torch.double)
  82. else tensor.detach())
  83. for input_key, tensor in kwargs.items()}
  84. outputs = self.expert(*args, **kwargs)
  85. assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"
  86. outputs_flat = tuple(nested_flatten(outputs))
  87. grad_outputs_flat = tuple(map(
  88. lambda grad, out: grad.to(device=out.device, dtype=out.dtype, non_blocking=True),
  89. nested_flatten(grad_outputs), outputs_flat))
  90. torch.autograd.backward(outputs_flat, grad_tensors=grad_outputs_flat,
  91. create_graph=False, retain_graph=False)
  92. self.apply_gradients()
  93. return tuple(x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x)
  94. for x in nested_flatten((args, kwargs)))

    def apply_gradients(self) -> None:
        """
        Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
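
        Example (a hypothetical subclass sketch; gradient clipping is not part of this module):

        >>> class ClippedExpertBackend(ExpertBackend):
        ...     def apply_gradients(self) -> None:
        ...         torch.nn.utils.clip_grad_norm_(self.expert.parameters(), max_norm=1.0)
        ...         super().apply_gradients()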
  98. """
  99. self.opt.step()
  100. self.opt.zero_grad()

    def get_info(self) -> Dict[str, Any]:
        """ Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration. """
        return dict(forward_schema=self.forward_schema, outputs_schema=self.outputs_schema,
                    keyword_names=tuple(self.kwargs_schema.keys()))

    def get_pools(self) -> Sequence[TaskPool]:
        """ Return all pools that should be processed by ``Runtime``. """
        return self.forward_pool, self.backward_pool