client.py
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

import hivemind
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.client.expert import DUMMY, expert_forward
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils import get_logger, nested_compare, nested_flatten, nested_pack

from load_balancer import LoadBalancer

logger = get_logger(__name__)

MAX_NODES = 99999


class BalancedRemoteExpert(nn.Module):
    """
    A torch module that dynamically routes each call to one RemoteExpert from a pool, chosen
    proportionally to the experts' measured throughput. See the usage sketch at the bottom of this file.

    :param dht: a running hivemind.DHT used to discover experts and their throughput
    :param uid_prefix: common prefix of the expert UIDs to load-balance across
    :param grid_size: expert grid dimensions; only (1, N) grids are supported
    :param forward_timeout: timeout (seconds) for a single forward call, None means no timeout
    :param backward_timeout: timeout (seconds) for a single backward call, None means no timeout
    :param update_period: interval (seconds) at which the LoadBalancer refreshes expert information
    :param backward_task_size_multiplier: a backward pass counts as this many times the forward task size
    :param kwargs: any other keyword arguments are forwarded to the LoadBalancer
    """

    def __init__(
        self,
        *,
        dht: hivemind.DHT,
        uid_prefix: str,
        grid_size: Tuple[int, ...] = (1, MAX_NODES),
        forward_timeout: Optional[float] = None,
        backward_timeout: Optional[float] = None,
        update_period: float = 30.0,
        backward_task_size_multiplier: float = 2.5,
        **kwargs,
    ):
        super().__init__()
        if uid_prefix.endswith(".0."):
            logger.warning(f"BalancedRemoteExpert will look for experts under prefix {uid_prefix}0.")
        assert len(grid_size) == 2 and grid_size[0] == 1, "only 1xN grids are supported"

        self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
        self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
        self.backward_task_size_multiplier = backward_task_size_multiplier
        self.expert_balancer = LoadBalancer(dht, key=f"{self.uid_prefix}0.", update_period=update_period, **kwargs)
        self._expert_info = None  # info from one of the experts in the grid, fetched lazily

    def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
        """
        Call one of the RemoteExperts for the specified inputs and return its output. Compatible with torch.autograd.

        :param args: positional input tensors that will be passed to the chosen expert, batch-first
        :param kwargs: extra keyword tensors that will be passed to the chosen expert, batch-first
        :returns: output of the expert that handled the request, packed into the expert's nested output structure (batch-first)
        """
        assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
        kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}

        # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
        forward_inputs = (args, kwargs)

        if not nested_compare(forward_inputs, self.info["forward_schema"]):
            raise TypeError("Inputs do not match expert input schema. Did you pass the right number of parameters?")

        flat_inputs = list(nested_flatten(forward_inputs))
        forward_task_size = flat_inputs[0].shape[0]

        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
        flat_outputs = _BalancedRemoteModuleCall.apply(
            DUMMY,
            self.expert_balancer,
            self.info,
            self.forward_timeout,
            self.backward_timeout,
            forward_task_size,
            forward_task_size * self.backward_task_size_multiplier,
            *flat_inputs,
        )
        return nested_pack(flat_outputs, structure=self.info["outputs_schema"])

    @property
    def info(self):
        # keep asking load-balanced experts until one of them responds with its info
        while self._expert_info is None:
            chosen_expert = None
            try:
                with self.expert_balancer.use_another_expert(1) as chosen_expert:
                    self._expert_info = chosen_expert.info
            except BaseException as e:
                logger.error(f"Tried to get expert info from {chosen_expert} but caught {repr(e)}")
        return self._expert_info


class _BalancedRemoteModuleCall(torch.autograd.Function):
    """Internal autograd-friendly call of a remote module. For applications, use BalancedRemoteExpert instead."""

    @staticmethod
    def forward(
        ctx,
        dummy: torch.Tensor,
        expert_balancer: LoadBalancer,
        info: Dict[str, Any],
        forward_timeout: float,
        backward_timeout: float,
        forward_task_size: float,
        backward_task_size: float,
        *inputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # Note: *inputs are flattened input tensors that follow the expert's info["forward_schema"]
        ctx.expert_balancer, ctx.info = expert_balancer, info
        ctx.forward_timeout, ctx.backward_timeout = forward_timeout, backward_timeout
        ctx.forward_task_size, ctx.backward_task_size = forward_task_size, backward_task_size
        # detach to avoid pickling the computation graph
        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
        ctx.save_for_backward(*inputs)

        serialized_tensors = [
            serialize_torch_tensor(inp, proto.compression)
            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
        ]

        # keep trying load-balanced experts until one of them answers the forward request
        chosen_expert = None
        while True:
            try:
                with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
                    deserialized_outputs = RemoteExpertWorker.run_coroutine(
                        expert_forward(chosen_expert.uid, inputs, serialized_tensors, chosen_expert.stub)
                    )
                break
            except BaseException as e:
                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")

        return tuple(deserialized_outputs)

    @staticmethod
    @once_differentiable
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        raise NotImplementedError("Backward is not yet implemented in this example")
        # The sketch below mirrors forward(): serialize the saved inputs together with grad_outputs,
        # send them to one load-balanced expert, and return its gradients w.r.t. the inputs.
        # grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
        # inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
        # backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
        # serialized_tensors = [
        #     serialize_torch_tensor(tensor, proto.compression)
        #     for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
        # ]
        # chosen_expert = None
        # while True:
        #     try:
        #         with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
        #             backward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
        #             grad_inputs = chosen_expert.stub.backward(backward_request, timeout=ctx.backward_timeout)
        #         break
        #     except BaseException as e:
        #         logger.error(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")
        # deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
        # return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)
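

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes that a hivemind server is
    # already running and serving experts under the "expert." uid prefix, that those experts accept
    # a single batch-first tensor of shape [batch, 512], and that INITIAL_PEERS is replaced with the
    # multiaddress of a reachable DHT peer. All of these are illustrative placeholders.
    INITIAL_PEERS = []  # e.g. ["/ip4/127.0.0.1/tcp/31337/p2p/<peer_id>"]
    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
    expert = BalancedRemoteExpert(dht=dht, uid_prefix="expert.")
    dummy_batch = torch.randn(4, 512)  # batch-first input matching the assumed expert schema
    outputs = expert(dummy_batch)  # the call is routed to one expert chosen by the LoadBalancer
    print(outputs)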