balanced_expert.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from typing import Any, Dict, Optional, Tuple
  2. import torch
  3. import torch.nn as nn
  4. from torch.autograd.function import once_differentiable
  5. import hivemind
  6. from hivemind.moe.client.balancer import ExpertBalancer
  7. from hivemind.moe.client.expert import DUMMY
  8. from hivemind.proto import runtime_pb2
  9. from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
  10. from hivemind.utils import get_logger, nested_compare, nested_flatten, nested_pack
  11. logger = get_logger(__name__)
  12. class BalancedRemoteExpert(nn.Module):
  13. """
  14. A torch module that dynamically assigns weights to one RemoteExpert from a pool, proportionally to their throughput.
  15. ToDo docstring, similar to RemoteMixtureOfExperts
  16. """
  17. def __init__(
  18. self,
  19. *,
  20. dht: hivemind.DHT,
  21. uid_prefix: str,
  22. grid_size: Tuple[int, ...],
  23. forward_timeout: Optional[float] = None,
  24. backward_timeout: Optional[float] = None,
  25. update_period: float = 30.0,
  26. backward_task_size_multiplier: float = 2.5,
  27. **kwargs,
  28. ):
  29. super().__init__()
  30. if uid_prefix.endswith(".0."):
  31. logger.warning(f"BalancedRemoteExperts will look for experts under prefix {self.uid_prefix}0.")
  32. assert len(grid_size) == 2 and grid_size[0] == 1, "only 1xN grids are supported"
  33. self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
  34. self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
  35. self.backward_task_size_multiplier = backward_task_size_multiplier
  36. self.expert_balancer = ExpertBalancer(dht, key=f"{self.uid_prefix}0.", update_period=update_period, **kwargs)
  37. self._expert_info = None # expert['info'] from one of experts in the grid
  38. def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
  39. """
  40. Call one of the RemoteExperts for the specified inputs and return output. Compatible with pytorch.autograd.
  41. :param args: input tensors that will be passed to each expert after input, batch-first
  42. :param kwargs: extra keyword tensors that will be passed to each expert, batch-first
  43. :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
  44. """
  45. assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
  46. kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
  47. if self._expert_info is None:
  48. raise NotImplementedError()
  49. # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
  50. forward_inputs = (args, kwargs)
  51. if not nested_compare(forward_inputs, self.info["forward_schema"]):
  52. raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
  53. flat_inputs = list(nested_flatten(forward_inputs))
  54. forward_task_size = flat_inputs[0].shape[0]
  55. # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
  56. flat_outputs = _BalancedRemoteModuleCall.apply(DUMMY,
  57. self.expert_balancer,
  58. self.info,
  59. self.forward_timeout,
  60. self.backward_timeout,
  61. forward_task_size,
  62. forward_task_size * self.backward_task_size_multiplier,
  63. *flat_inputs)
  64. return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
  65. @property
  66. def info(self):
  67. while self._expert_info is None:
  68. try:
  69. with self.expert_balancer.use_another_expert(1) as chosen_expert:
  70. self._expert_info = chosen_expert.info
  71. except BaseException as e:
  72. logger.error(f"Tried to get expert info from {chosen_expert} but caught {repr(e)}")
  73. return self._expert_info
  74. class _BalancedRemoteModuleCall(torch.autograd.Function):
  75. """Internal autograd-friendly call of a remote module. For applications, use BalancedRemoteExpert instead."""
  76. @staticmethod
  77. def forward(
  78. ctx,
  79. dummy: torch.Tensor,
  80. expert_balancer: ExpertBalancer,
  81. info: Dict[str, Any],
  82. forward_timeout: float,
  83. backward_timeout: float,
  84. forward_task_size: float,
  85. backward_task_size: float,
  86. *inputs: torch.Tensor,
  87. ) -> Tuple[torch.Tensor, ...]:
  88. # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
  89. # detach to avoid pickling the computation graph
  90. ctx.expert_balancer, ctx.info = expert_balancer, info
  91. ctx.forward_timeout, ctx.backward_timeout = forward_timeout, backward_timeout
  92. ctx.forward_task_size, ctx.backward_task_size = forward_task_size, backward_task_size
  93. inputs = tuple(tensor.cpu().detach() for tensor in inputs)
  94. ctx.save_for_backward(*inputs)
  95. serialized_tensors = [
  96. serialize_torch_tensor(inp, proto.compression)
  97. for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
  98. ]
  99. while True:
  100. try:
  101. with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
  102. forward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
  103. outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
  104. break
  105. except BaseException as e:
  106. logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
  107. deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
  108. return tuple(deserialized_outputs)
  109. @staticmethod
  110. @once_differentiable
  111. def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
  112. grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
  113. inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
  114. backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
  115. serialized_tensors = [
  116. serialize_torch_tensor(tensor, proto.compression)
  117. for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
  118. ]
  119. while True:
  120. try:
  121. with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
  122. backward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
  123. grad_inputs = chosen_expert.stub.forward(backward_request, timeout=ctx.backward_timeout)
  124. break
  125. except BaseException as e:
  126. logger.error(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")
  127. deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
  128. return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)