# expert_backend.py

from typing import Any, Callable, Dict, Sequence, Tuple, Union

import torch
from torch import nn

from hivemind.moe.server.task_pool import TaskPool
from hivemind.utils.logging import get_logger
from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor

logger = get_logger(__name__)


class ExpertBackend:
    """
    ExpertBackend is a wrapper around a torch module that allows it to run tasks asynchronously with Runtime.

    By default, ExpertBackend handles three types of requests:

     - forward - receive inputs and compute outputs. Concurrent requests will be batched for better GPU utilization.
     - backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and **update expert**. Also batched.
     - get_info - return expert metadata. Not batched.

    :param name: the name of this expert, used for the corresponding TaskPool names and in log messages
    :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:

     - Experts must always receive the same set of args and kwargs and produce output tensors of the same type
     - All args, kwargs and outputs must be **tensors** where the 0-th dimension represents the batch size
     - We recommend using experts that are ~invariant to the order in which they process batches
     - Using randomness (e.g. Dropout) leads to different samples at forward and backward. If you want consistency,
       you should explicitly register these random variables as model inputs or outputs.
       See hivemind.utils.custom_layers.DeterministicDropout for an example

    :param optimizer: torch optimizer to be applied on every backward call
    :param scheduler: a function to create the learning rate scheduler for the expert
    :param args_schema: description of positional arguments to expert.forward, tuple of BatchTensorDescriptor
    :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorDescriptor
    :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorDescriptor
    :param num_warmup_steps: the number of warmup steps for the LR schedule
    :param num_total_steps: the total number of steps for the LR schedule
    :param clip_grad_norm: maximum gradient norm used for clipping
    :param kwargs: extra parameters to be forwarded into TaskPool.__init__
    """

    def __init__(
        self,
        name: str,
        expert: nn.Module,
        optimizer: torch.optim.Optimizer,
        *,
        scheduler: Callable = None,
        args_schema: Tuple[BatchTensorDescriptor, ...] = None,
        kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
        outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
        num_warmup_steps: int = None,
        num_total_steps: int = None,
        clip_grad_norm: float = None,
        **kwargs,
    ):
        super().__init__()
        self.expert, self.optimizer, self.name = expert, optimizer, name

        if scheduler is None:
            self.scheduler = None
        else:
            assert optimizer is not None and num_warmup_steps is not None and num_total_steps is not None
            self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_total_steps)
        self.clip_grad_norm = clip_grad_norm

        self.args_schema = args_schema = tuple(args_schema or ())
        self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
        assert args_schema or kwargs_schema, (
            "expert must receive at least one positional or keyword input."
            " Did you forget to provide args_schema/kwargs_schema?"
        )

        if outputs_schema is None:
            # run expert once to get outputs schema
            dummy_args = tuple(sample.make_empty(DUMMY_BATCH_SIZE) for sample in args_schema)
            dummy_kwargs = {key: sample.make_empty(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
            dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
            outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)

        self.forward_schema = (self.args_schema, self.kwargs_schema)  # inputs for forward
        self.outputs_schema = outputs_schema  # outputs from forward
        self.backward_schema = (self.forward_schema, self.outputs_schema)  # inputs to backward
        self.grad_inputs_schema = self.forward_schema  # outputs from backward

        self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
        self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs)

        self.update_count = 0
        self.examples_processed = 0

    def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
        to submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``.

        Subclassing:
           This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``;

           It should return outputs that follow ``nested_flatten(self.outputs_schema)``;

        .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
        .. For now, either register all buffers as outputs or avoid stateful experts
        """
        args, kwargs = nested_pack(inputs, structure=self.forward_schema)

        if args[0].shape[0] == 0:
            raise RuntimeError("Batch should contain more than 0 samples")

        with torch.no_grad():
            outputs = self.expert(*args, **kwargs)

        # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
        return tuple(nested_flatten(outputs))
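
    # Example (sketch): as the docstring above suggests, callers submit work through the task pool instead
    # of invoking forward() directly. `flat_inputs` is assumed to already follow
    # nested_flatten(backend.forward_schema); submit_task returning a future is an assumption about TaskPool.
    #
    #   future = backend.forward_pool.submit_task(*flat_inputs)
    #   flat_outputs = future.result()  # follows nested_flatten(backend.outputs_schema)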

    def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
        to submit a request for asynchronous processing, please use ``ExpertBackend.backward_pool.submit_task``.

        Subclassing:
           This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``;

           It should return gradients w.r.t. inputs that follow ``nested_flatten(self.forward_schema)``;

           Runtime doesn't guarantee that backward will be performed in the same order and for the same data
           as forward, so we recommend a stateless backward pass that re-runs the expert forward pass inside backward.

           .. todo correct state handling (see forward)

           Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train
        """
        (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)

        with torch.enable_grad():
            # detach inputs from any external graph; re-enable grad only for floating-point tensors
            args = [
                tensor.detach().requires_grad_(True)
                if tensor.dtype in (torch.half, torch.float, torch.double)
                else tensor.detach()
                for tensor in args
            ]
            kwargs = {
                input_key: (tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach())
                for input_key, tensor in kwargs.items()
            }

            batch_size = args[0].size(0)

            # re-run the forward pass, then backpropagate the received output gradients
            outputs = self.expert(*args, **kwargs)
            assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"

            outputs_flat = tuple(nested_flatten(outputs))
            grad_outputs_flat = tuple(
                map(
                    lambda grad, out: grad.to(device=out.device, dtype=out.dtype, non_blocking=True),
                    nested_flatten(grad_outputs),
                    outputs_flat,
                )
            )
            torch.autograd.backward(
                outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False
            )
            self.apply_gradients(batch_size)

        # return grads w.r.t. inputs; non-differentiable inputs get zero-filled placeholders
        return tuple(
            x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x) for x in nested_flatten((args, kwargs))
        )
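
    # Example (sketch): a backward request carries both the original forward inputs and the gradients
    # w.r.t. outputs, flattened in backward_schema order. The tensors below are hypothetical.
    #
    #   flat_request = tuple(nested_flatten(((args, kwargs), grad_outputs)))
    #   flat_input_grads = backend.backward_pool.submit_task(*flat_request).result()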

    def apply_gradients(self, batch_size) -> None:
        """
        Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
        """
        if self.clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.expert.parameters(), self.clip_grad_norm)

        self.optimizer.step()
        self.optimizer.zero_grad()

        if self.scheduler is not None:
            self.scheduler.step()

        self.update_count += 1
        self.examples_processed += batch_size

    def get_stats(self) -> Dict:
        """
        Return current expert training statistics (total number of updates and of processed examples)
        """
        return {"updates": self.update_count, "examples_processed": self.examples_processed}

    def get_full_state(self) -> Dict:
        """
        Return the current state of the expert (including batch processing statistics)
        """
        full_state = {
            "stats": self.get_stats(),
            "model": self.expert.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": {} if self.scheduler is None else self.scheduler.state_dict(),
        }
        return full_state

    def load_full_state(self, state_dict: Dict):
        """Restore the expert state (stats, model, optimizer and scheduler) from a ``get_full_state`` snapshot"""
        if "stats" in state_dict:
            self.update_count = state_dict["stats"]["updates"]
            self.examples_processed = state_dict["stats"]["examples_processed"]
        else:
            logger.warning(f"Batch processing stats missing for expert {self.name}")

        self.expert.load_state_dict(state_dict["model"])

        if "optimizer" in state_dict:
            self.optimizer.load_state_dict(state_dict["optimizer"])
        else:
            logger.warning(f"Optimizer state missing for expert {self.name}")

        if self.scheduler is not None and "scheduler" in state_dict:
            self.scheduler.load_state_dict(state_dict["scheduler"])
        else:
            logger.warning(f"Learning rate scheduler state missing for expert {self.name}")
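
    # Example (sketch): persisting an expert across restarts using the two methods above.
    # The checkpoint path is a hypothetical placeholder.
    #
    #   torch.save(backend.get_full_state(), "expert.0.checkpoint.pt")
    #   ...
    #   backend.load_full_state(torch.load("expert.0.checkpoint.pt"))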

    def get_info(self) -> Dict[str, Any]:
        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
        return dict(
            forward_schema=self.forward_schema,
            outputs_schema=self.outputs_schema,
            keyword_names=tuple(self.kwargs_schema.keys()),
        )

    def get_pools(self) -> Sequence[TaskPool]:
        """Return all pools that should be processed by ``Runtime``"""
        return self.forward_pool, self.backward_pool
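

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module). It builds a small
    # feed-forward expert and pushes one batch through forward() and backward() directly, bypassing
    # the task pools. The hidden size, learning rate, batch size and the `max_batch_size` kwarg
    # (assumed to be accepted by TaskPool.__init__) are all arbitrary assumptions.
    hidden_dim = 16
    demo_expert = nn.Sequential(
        nn.Linear(hidden_dim, 4 * hidden_dim), nn.ReLU(), nn.Linear(4 * hidden_dim, hidden_dim)
    )
    demo_backend = ExpertBackend(
        name="demo_expert",
        expert=demo_expert,
        optimizer=torch.optim.SGD(demo_expert.parameters(), lr=0.01),
        args_schema=(BatchTensorDescriptor(hidden_dim),),
        max_batch_size=64,
    )

    batch = torch.randn(4, hidden_dim)
    (batch_outputs,) = demo_backend.forward(batch)  # flat tuple following outputs_schema
    input_grads = demo_backend.backward(batch, torch.ones_like(batch_outputs))  # also applies one optimizer step
    print(demo_backend.get_stats())  # expected: {'updates': 1, 'examples_processed': 4}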