expert_backend.py

from typing import Dict, Sequence, Any, Tuple, Union

import torch
from torch import nn

from .task_pool import TaskPool
from ..utils import nested_flatten, nested_pack, nested_compare, BatchTensorProto, DUMMY_BATCH_SIZE, nested_map


class ExpertBackend(nn.Module):
    def __init__(self, name: str, expert: nn.Module, opt: torch.optim.Optimizer, *,
                 args_schema: Tuple[BatchTensorProto, ...] = None,
                 kwargs_schema: Dict[str, BatchTensorProto] = None,
                 outputs_schema: Union[BatchTensorProto, Tuple[BatchTensorProto, ...]] = None,
                 **kwargs):
  12. """
  13. ExpertBackend implements how a given expert processes tasks.
  14. By default, there are two tasks:
  15. * forward receives inputs and produces outputs
  16. * backward receives gradients w.r.t. outputs, computes gradients w.r.t. inputs and trains the expert
  17. All incoming tasks are grouped by type (forward/backward) and sent into the corresponding pool,
  18. where tasks are grouped into minibatches and prepared for processing on device;
  19. The results are dispatched to task authors with SharedFuture.set_result.
  20. :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
  21. * Experts must always receive the same set of *args and **kwargs and produce output tensors of same type
  22. * All *args, **kwargs and outputs must be *tensors* where 0-th dimension represents to batch size
  23. * We recommend using experts that are ~invariant to the order in which they process batches
  24. :param opt: torch optimizer to be applied on every backward call
  25. :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
  26. :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
  27. :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
  28. :param kwargs: extra parameters to be forwarded into TaskPool.__init__
  29. """
        super().__init__()
        self.expert, self.opt, self.name = expert, opt, name

        self.args_schema = args_schema = tuple(args_schema or ())
        self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
        assert args_schema or kwargs_schema, "expert must receive at least one positional or keyword input." \
                                             " Did you forget to provide args_schema/kwargs_schema?"

        if outputs_schema is None:
            # run expert once to get outputs schema
            dummy_args = tuple(sample.make_empty(DUMMY_BATCH_SIZE) for sample in args_schema)
            dummy_kwargs = {key: sample.make_empty(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
            dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
            outputs_schema = nested_map(BatchTensorProto.from_tensor, dummy_outputs)

        self.outputs_schema = outputs_schema
        self.forward_schema = (self.args_schema, self.kwargs_schema)
        self.backward_schema = (self.forward_schema, self.outputs_schema)  # original inputs and grad w.r.t. outputs
        self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
        self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
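
    # Illustrative note (not in the original source): for args_schema=(proto,) and kwargs_schema={'mask': mask_proto},
    # forward_schema == ((proto,), {'mask': mask_proto}) and backward_schema nests that forward schema together with
    # outputs_schema, so a backward call receives the original inputs followed by gradients w.r.t. each output.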

    def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        args, kwargs = nested_pack(inputs, structure=self.forward_schema)

        with torch.no_grad():
            outputs = self.expert(*args, **kwargs)

        # Note: TaskPool requires function to accept and return a **list** of values, we pack/unpack it on client side
        return tuple(nested_flatten(outputs))
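
    # Illustrative note (not in the original source): with args_schema=(proto_a, proto_b) and
    # kwargs_schema={'mask': proto_m}, the pool calls forward(a_batch, b_batch, mask_batch) and
    # nested_pack reassembles that flat tuple into ((a_batch, b_batch), {'mask': mask_batch}).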

    def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)

        with torch.enable_grad():
            # detach inputs from any upstream graph and mark them as leaves so that .grad is populated
            args = [tensor.detach().requires_grad_(True) for tensor in args]
            kwargs = {input_key: tensor.detach().requires_grad_(True) for input_key, tensor in kwargs.items()}

            outputs = self.expert(*args, **kwargs)
            assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"

            outputs_flat = tuple(nested_flatten(outputs))
            # cast incoming gradients to each output's device/dtype before backprop
            grad_outputs_flat = tuple(map(
                lambda grad, out: grad.to(device=out.device, dtype=out.dtype, non_blocking=True),
                nested_flatten(grad_outputs), outputs_flat))
            torch.autograd.backward(outputs_flat, grad_tensors=grad_outputs_flat,
                                    create_graph=False, retain_graph=False)
            self.apply_gradients()

        # inputs that received no gradient (e.g. unused by the expert) get zeros instead of None
        return tuple(x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x)
                     for x in nested_flatten((args, kwargs)))

    def apply_gradients(self) -> None:
        self.opt.step()
        self.opt.zero_grad()

    def get_pools(self) -> Sequence[TaskPool]:
        return self.forward_pool, self.backward_pool

    def get_info(self) -> Dict[str, Any]:
        return dict(forward_schema=self.forward_schema, outputs_schema=self.outputs_schema,
                    keyword_names=tuple(self.kwargs_schema.keys()))
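

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of wrapping a feed-forward expert and driving it through the flattened
# forward/backward interface. The BatchTensorProto constructor shown here (BatchTensorProto(512)
# for a [batch_size, 512] float tensor) and the max_batch_size kwarg are assumptions; check the
# actual ..utils and TaskPool definitions before copying this verbatim.
#
#   expert = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
#   backend = ExpertBackend(name='expert.0', expert=expert,
#                           opt=torch.optim.Adam(expert.parameters()),
#                           args_schema=(BatchTensorProto(512),),
#                           max_batch_size=2048)  # extra kwargs are forwarded to TaskPool.__init__
#
#   x = torch.randn(4, 512)
#   (y,) = backend.forward(x)                             # flat tensors in, flat tensors out
#   (grad_x,) = backend.backward(x, torch.ones_like(y))   # also applies one optimizer step
#
# In a running server these calls are normally issued by the pools from backend.get_pools(),
# not invoked directly as above.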