@@ -1,21 +1,28 @@
-import heapq
-import random
-import threading
-from typing import Optional, Tuple, List, Dict
+from typing import Any, Dict, Optional, Tuple

 import torch
 import torch.nn as nn
+from torch.autograd.function import once_differentiable

 import hivemind
-from hivemind import nested_compare, nested_flatten, get_dht_time
-from hivemind.moe.client.expert import _RemoteModuleCall, DUMMY
-from hivemind.moe.server.expert_uid import ExpertUID
-from hivemind.utils import DHTExpiration
-
-
-class LoadBalancedExpert(nn.Module):
+from hivemind.moe.client.balancer import ExpertBalancer
+from hivemind.moe.client.expert import DUMMY
+from hivemind.proto import runtime_pb2
+from hivemind.utils import (
+    deserialize_torch_tensor,
+    get_logger,
+    nested_compare,
+    nested_flatten,
+    nested_pack,
+    serialize_torch_tensor,
+)
+
+logger = get_logger(__name__)
+
+
+class BalancedRemoteExpert(nn.Module):
     """
-    A torch module that dynamically assigns weights to one RemoteExpert from a pool.
+    A torch module that dynamically assigns weights to one RemoteExpert from a pool, proportionally to their throughput.
     ToDo docstring, similar to RemoteMixtureOfExperts
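+
+    A minimal usage sketch (the prefix, grid size, and input shape below are illustrative; a running DHT
+    with experts announced under ``{uid_prefix}.0.`` is assumed):
+
+    >>> dht = hivemind.DHT(initial_peers=[...], start=True)
+    >>> expert = BalancedRemoteExpert(dht=dht, uid_prefix="expert", grid_size=(1, 16))
+    >>> outputs = expert(torch.randn(2, 512))  # each call is routed to the currently least loaded expert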
     """

@@ -28,69 +35,19 @@ class LoadBalancedExpert(nn.Module):
         forward_timeout: Optional[float] = None,
         backward_timeout: Optional[float] = None,
         detect_anomalies: bool = False,
-        refresh_period: float = 30.,
-        **dht_kwargs,
+        update_period: float = 30.0,
+        backward_task_size_multiplier: float = 2.5,
+        **kwargs,
     ):
         super().__init__()
-        assert len(grid_size) == 1, "only 1d grids are supported for now"
-        self.dht, self.dht_kwargs, self.uid_prefix, self.grid_size = dht, dht_kwargs, uid_prefix, grid_size
+        if uid_prefix.endswith(".0."):
+            logger.warning(f"BalancedRemoteExpert will look for experts under prefix {uid_prefix}.0.")
+        assert len(grid_size) == 2 and grid_size[0] == 1, "only 1xN grids are supported"
+        self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
-        self.detect_anomalies = detect_anomalies
-
-        self.active_experts: Dict[ExpertUID, DHTExpiration] = {}
-        self.banned_experts: Dict[ExpertUID, DHTExpiration] = {}
-        self.expert_queue: List[Tuple[float, float, ExpertUID]] = []
+        self.backward_task_size_multiplier, self.detect_anomalies = backward_task_size_multiplier, detect_anomalies
+        self.expert_balancer = ExpertBalancer(dht, key=f"{self.uid_prefix}.0.", update_period=update_period, **kwargs)
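+        # the balancer periodically re-fetches experts under this prefix from the DHT (every update_period seconds)
+        # and, for each call, yields the expert that currently has the least estimated load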
         self._expert_info = None  # expert['info'] from one of experts in the grid
-        self.refresh_period, self.last_refresh = refresh_period, 0.0
-        self.should_refresh_experts, self.refresh_complete = threading.Event(), threading.Event()
-
-    def fetch_experts_in_background(self):
-        while True:
-            time_to_next_update = max(0.0, self.last_update + self.refresh_period - get_dht_time())
-            try:
-                self.should_refresh_experts.wait(timeout=time_to_next_update)
-                # update triggered by main thread
-            except TimeoutError:
-                pass  # update triggered by refresh_period
-
-            TODO_FETCH_MORE_EXPERTS_HERE
-
-        # state:
-        # * available_experts: Dict[uid -> (EMA, expiration)] - experts that take part in load balancing
-        # * maintain blacklist: Dict[uid -> expiration] - experts banned until expiration for a non-response
-        # * maintain a min-heap queue of (load, rng, expert) tuples
-        # * update_triggered, update_finished: threading.Event; lock: threading.Lock
-        #
-        # update experts in background, while True:
-        # * wait for 30s or for update_triggered, whichever comes first
-        # * for expert, expiration_time in fetch_experts_from_dht():
-        # * * if expert in banned and expiration_time <= self.blacklist[expert]:
-        # * * * continue  # expert is still banned
-        # * * else: add expert to min-heap, intitialize throughput
-        # * update_complete.set()
-        #
-        # on forward/backward:
-        # pass (queue, blacklist, update_triggered, update_finished) to the autograd function
-        #
-        # forward/backward autograd function
-        # while True:
-        # * while len(available experts) == 0:
-        # * * update_finished.clear()
-        # * * update_triggered.set()
-        # * * update_finished.wait()
-        # * with threading.lock:
-        # * * load, _, expert = queue.heappop_min()
-        # * * expert_throughput_ema, expert_expiration_time = get ema from dict
-        # * * task_complexity = batch_size * 1.5 if forward else 2.5  # if backward
-        # * * queue.heappush (load + task_complexity / expert_throughput_ema, new_rng, expert)
-        # * try:
-        # * * with measure_ema(start=now, batch_size=batch_size) as measured_ema:
-        # * * * outputs = call_forward_or_backward()
-        # * * expert_throughput_ema.update(measured_ema)
-        # * * return outputs  # <--------- this is the desired exit point
-        # * except DidNotRespondCorrectly:
-        # * * banned_experts[expert] = expert_expiration_time
-        # * continue  # try again

     def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
@@ -103,8 +60,6 @@ class LoadBalancedExpert(nn.Module):
         assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
         kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}

-
-
         if self._expert_info is None:
             raise NotImplementedError()
         # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
@@ -114,17 +69,91 @@ class LoadBalancedExpert(nn.Module):
         if not nested_compare(forward_inputs, self.info["forward_schema"]):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")

-        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
+        flat_inputs = list(nested_flatten(forward_inputs))
+        forward_task_size = flat_inputs[0].shape[0]
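+        # the batch size above serves as the task size for load accounting; the backward pass is weighted
+        # by backward_task_size_multiplier since it is more expensive than the forward pass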
+
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
+        flat_outputs = _BalancedRemoteModuleCall.apply(
+            DUMMY,
+            self.expert_balancer,
+            self.info,
+            self.forward_timeout,
+            self.backward_timeout,
+            forward_task_size,
+            forward_task_size * self.backward_task_size_multiplier,
+            *flat_inputs,
+        )
+
         return nested_pack(flat_outputs, structure=self.info["outputs_schema"])

     @property
     def info(self):
-        if self._expert_info is None:
-            # grab some expert to set ensemble output shape
-            proj_device = self.proj.weight.device
-            dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
-            dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.beam_search.grid_size, dim=-1)
-            dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
-            self._expert_info = dummy_experts[0].info
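+        # experts under the same prefix are assumed to share one input/output schema, so the info of any
+        # responsive expert is representative of the whole pool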
+        while self._expert_info is None:
+            try:
+                with self.expert_balancer.use_another_expert(1) as chosen_expert:
+                    self._expert_info = chosen_expert.info
+            except BaseException as e:
+                logger.error(f"Tried to get expert info from {chosen_expert} but caught {e}")
         return self._expert_info
+
+
+class _BalancedRemoteModuleCall(torch.autograd.Function):
+    """Internal autograd-friendly call of a remote module. For applications, use BalancedRemoteExpert instead."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        dummy: torch.Tensor,
+        expert_balancer: ExpertBalancer,
+        info: Dict[str, Any],
+        forward_timeout: float,
+        backward_timeout: float,
+        forward_task_size: float,
+        backward_task_size: float,
+        *inputs: torch.Tensor,
+    ) -> Tuple[torch.Tensor, ...]:
+        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
+        # detach to avoid pickling the computation graph
+        ctx.expert_balancer, ctx.info = expert_balancer, info
+        ctx.forward_timeout, ctx.backward_timeout = forward_timeout, backward_timeout
+        ctx.forward_task_size, ctx.backward_task_size = forward_task_size, backward_task_size
+        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
+        ctx.save_for_backward(*inputs)
+
+        serialized_tensors = [
+            serialize_torch_tensor(inp, proto.compression)
+            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
+        ]
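+        # route the request to whichever expert the balancer currently considers least loaded;
+        # if that expert fails, log the error and retry with another one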
+        while True:
+            try:
+                with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
+                    forward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
+                    outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
+                break
+            except KeyboardInterrupt:
+                raise
+            except BaseException as e:
+                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {e}")
+
+        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
+        return tuple(deserialized_outputs)
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
+        grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
+        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
+        backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
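+        # the backward request carries the saved forward inputs together with the output gradients,
+        # serialized with the compression settings declared in the expert's schema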
+        serialized_tensors = [
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        ]
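+        # as in forward, let the balancer pick the least loaded expert and retry on failure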
+        while True:
+            try:
+                with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
+                    backward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
+                    grad_inputs = chosen_expert.stub.backward(backward_request, timeout=ctx.backward_timeout)
+                break
+            except KeyboardInterrupt:
+                raise
+            except BaseException as e:
+                logger.error(f"Tried to call backward for expert {chosen_expert} but caught {e}")
+        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
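+        # one gradient per forward() argument: DUMMY for the dummy tensor, None for the six non-tensor
+        # arguments (balancer, info, two timeouts, two task sizes), followed by the actual input gradients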
+        return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)