justheuristic 4 жил өмнө
parent
commit
dee9854d79

+ 110 - 81
hivemind/moe/client/balanced_expert.py

@@ -1,21 +1,28 @@
-import heapq
-import random
-import threading
-from typing import Optional, Tuple, List, Dict
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
+from torch.autograd.function import once_differentiable
 
 import hivemind
-from hivemind import nested_compare, nested_flatten, get_dht_time
-from hivemind.moe.client.expert import _RemoteModuleCall, DUMMY
-from hivemind.moe.server.expert_uid import ExpertUID
-from hivemind.utils import DHTExpiration
-
-
-class LoadBalancedExpert(nn.Module):
+from hivemind.moe.client.balancer import ExpertBalancer
+from hivemind.moe.client.expert import DUMMY
+from hivemind.proto import runtime_pb2
+from hivemind.utils import (
+    deserialize_torch_tensor,
+    get_logger,
+    nested_compare,
+    nested_flatten,
+    nested_pack,
+    serialize_torch_tensor,
+)
+
+logger = get_logger(__name__)
+
+
+class BalancedRemoteExpert(nn.Module):
     """
-    A torch module that dynamically assigns weights to one RemoteExpert from a pool.
+    A torch module that dynamically assigns weights to one RemoteExpert from a pool, proportionally to their throughput.
     ToDo docstring, similar to RemoteMixtureOfExperts
     """
 
@@ -28,69 +35,19 @@ class LoadBalancedExpert(nn.Module):
         forward_timeout: Optional[float] = None,
         backward_timeout: Optional[float] = None,
         detect_anomalies: bool = False,
-        refresh_period: float = 30.,
-        **dht_kwargs,
+        update_period: float = 30.0,
+        backward_task_size_multiplier: float = 2.5,
+        **kwargs,
     ):
         super().__init__()
-        assert len(grid_size) == 1, "only 1d grids are supported for now"
-        self.dht, self.dht_kwargs, self.uid_prefix, self.grid_size = dht, dht_kwargs, uid_prefix, grid_size
+        if uid_prefix.endswith(".0."):
+            logger.warning(f"BalancedRemoteExperts will look for experts under prefix {self.uid_prefix}.0.")
+        assert len(grid_size) == 2 and grid_size[0] == 0, "only 1xN grids are supported"
+        self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
-        self.detect_anomalies = detect_anomalies
-
-        self.active_experts: Dict[ExpertUID, DHTExpiration] = {}
-        self.banned_experts: Dict[ExpertUID, DHTExpiration] = {}
-        self.expert_queue: List[Tuple[float, float, ExpertUID]] = []
+        self.backward_task_size_multiplier, self.detect_anomalies = backward_task_size_multiplier, detect_anomalies
+        self.expert_balancer = ExpertBalancer(dht, key=f"{self.uid_prefix}.0.", update_period=update_period, **kwargs)
         self._expert_info = None  # expert['info'] from one of experts in the grid
-        self.refresh_period, self.last_refresh = refresh_period, 0.0
-        self.should_refresh_experts, self.refresh_complete = threading.Event(), threading.Event()
-
-    def fetch_experts_in_background(self):
-        while True:
-            time_to_next_update = max(0.0, self.last_update + self.refresh_period - get_dht_time())
-            try:
-                self.should_refresh_experts.wait(timeout=time_to_next_update)
-                # update triggered by main thread
-            except TimeoutError:
-                pass  # update triggered by refresh_period
-
-            TODO_FETCH_MORE_EXPERTS_HERE
-
-            # state:
-            # * available_experts: Dict[uid -> (EMA, expiration)] - experts that take part in load balancing
-            # * maintain blacklist: Dict[uid -> expiration] - experts banned until expiration for a non-response
-            # * maintain a min-heap queue of (load, rng, expert) tuples
-            # * update_triggered, update_finished: threading.Event; lock: threading.Lock
-            #
-            # update experts in background, while True:
-            # * wait for 30s or for update_triggered, whichever comes first
-            # * for expert, expiration_time in fetch_experts_from_dht():
-            # * * if expert in banned and expiration_time <= self.blacklist[expert]:
-            # * * * continue # expert is still banned
-            # * * else: add expert to min-heap, intitialize throughput
-            # * update_complete.set()
-            #
-            # on forward/backward:
-            # pass (queue, blacklist, update_triggered, update_finished) to the autograd function
-            #
-            # forward/backward autograd function
-            # while True:
-            # * while len(available experts) == 0:
-            # * * update_finished.clear()
-            # * * update_triggered.set()
-            # * * update_finished.wait()
-            # * with threading.lock:
-            # * * load, _, expert = queue.heappop_min()
-            # * * expert_throughput_ema, expert_expiration_time = get ema from dict
-            # * * task_complexity = batch_size * 1.5 if forward else 2.5 # if backward
-            # * * queue.heappush (load + task_complexity / expert_throughput_ema, new_rng, expert)
-            # * try:
-            # * * with measure_ema(start=now, batch_size=batch_size) as measured_ema:
-            # * * * outputs = call_forward_or_backward()
-            # * * expert_throughput_ema.update(measured_ema)
-            # * * return outputs      # <--------- this is the desired exit point
-            # * except DidNotRespondCorrectly:
-            # * * banned_experts[expert] = expert_expiration_time
-            # * continue # try again
 
     def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
@@ -103,8 +60,6 @@ class LoadBalancedExpert(nn.Module):
         assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
         kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
 
-
-
         if self._expert_info is None:
             raise NotImplementedError()
         # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
@@ -114,17 +69,91 @@ class LoadBalancedExpert(nn.Module):
         if not nested_compare(forward_inputs, self.info["forward_schema"]):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
-        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
+        flat_inputs = nested_flatten(forward_inputs)
+        forward_task_size = flat_inputs[0].shape[0]
+
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
+        flat_outputs = _BalancedRemoteModuleCall.apply(DUMMY,
+                                                       self.uid,
+                                                       self.expert_balancer,
+                                                       self.info,
+                                                       self.forward_timeout,
+                                                       self.backward_timeout,
+                                                       forward_task_size,
+                                                       forward_task_size * self.backward_task_size_multiplier,
+                                                       *flat_inputs)
+
         return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
 
     @property
     def info(self):
-        if self._expert_info is None:
-            # grab some expert to set ensemble output shape
-            proj_device = self.proj.weight.device
-            dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
-            dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.beam_search.grid_size, dim=-1)
-            dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
-            self._expert_info = dummy_experts[0].info
+        while self._expert_info is None:
+            try:
+                with self.expert_balancer.use_another_expert(1) as chosen_expert:
+                    self._expert_info = chosen_expert.info
+            except BaseException as e:
+                logger.error(f"Tried to get expert info from {chosen_expert} but caught {e}")
         return self._expert_info
+
+
+class _BalancedRemoteModuleCall(torch.autograd.Function):
+    """Internal autograd-friendly call of a remote module. For applications, use BalancedRemoteExpert instead."""
+
+    @staticmethod
+    def forward(
+            ctx,
+            dummy: torch.Tensor,
+            expert_balancer: ExpertBalancer,
+            info: Dict[str, Any],
+            forward_timeout: float,
+            backward_timeout: float,
+            forward_task_size: float,
+            backward_task_size: float,
+            *inputs: torch.Tensor,
+            ) -> Tuple[torch.Tensor, ...]:
+        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
+        # detach to avoid pickling the computation graph
+        ctx.expert_balancer, ctx.info = expert_balancer, info
+        ctx.forward_timeout, ctx.backward_timeout = forward_timeout, backward_timeout
+        ctx.forward_task_size, ctx.backward_task_size = forward_task_size, backward_task_size
+        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
+        ctx.save_for_backward(*inputs)
+
+        serialized_tensors = [
+            serialize_torch_tensor(inp, proto.compression)
+            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
+        ]
+        forward_request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
+        while True:
+            try:
+                with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
+                    outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
+                break
+            except BaseException as e:
+                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {e}")
+                raise
+
+        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
+        return tuple(deserialized_outputs)
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
+        grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
+        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
+        backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
+        serialized_tensors = [
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        ]
+        backward_request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
+        while True:
+            try:
+                with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
+                    grad_inputs = chosen_expert.stub.forward(backward_request, timeout=ctx.backward_timeout)
+                break
+            except BaseException as e:
+                logger.error(f"Tried to call backward for expert {chosen_expert} but caught {e}")
+                raise
+        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
+        return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)

+ 15 - 8
hivemind/moe/client/balancer.py

@@ -2,14 +2,13 @@ import heapq
 import random
 import threading
 from contextlib import contextmanager
-from typing import Tuple, List, Dict
+from typing import Dict, List, Tuple
 
-from hivemind import TimedStorage, Endpoint, RemoteExpert
+from hivemind import Endpoint, RemoteExpert, TimedStorage
 from hivemind.dht import DHT
-from hivemind.moe.server.expert_uid import ExpertPrefix
-from hivemind.moe.server.expert_uid import ExpertUID
+from hivemind.moe.server.expert_uid import ExpertPrefix, ExpertUID
 from hivemind.optim.performance_ema import PerformanceEMA
-from hivemind.utils import DHTExpiration, ValueWithExpiration, get_logger, get_dht_time
+from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 
@@ -51,12 +50,15 @@ class ExpertBalancer:
                         maybe_banned = self.blacklist.get(uid)
                         if maybe_banned is None or expiration_time > maybe_banned.expiration_time:
                             self._add_expert(uid, endpoint, expiration_time)
-
+                        else:
+                            logger.debug(f"Not adding expert {uid} (blacklisted).")
                     except Exception as e:
                         logger.warning(f"Skipping malformed expert info {expert_info} (exc={e})")
             else:
                 logger.warning(f"Could not refresh experts, dht info key contains {response}, "
                                f"will retry in {time_to_next_update}s")
+            if len(self.queue) == 0:
+                logger.warning("Update routine finished, but still no experts available.")
 
             self.last_update = get_dht_time()
             self.update_finished.set()
@@ -65,11 +67,14 @@ class ExpertBalancer:
         with self.lock:
             self.experts.store(uid, endpoint, expiration_time)
             if uid not in self.uid_to_queue:
+                logger.debug(f"Adding new expert: {uid}, expiration time = {expiration_time:.3f}.")
                 self.throughputs[uid] = PerformanceEMA(*self.ema_kwargs, paused=True)
                 base_load = self.queue[0][0] if len(self.queue) > 0 else 0.0
                 heap_entry = (base_load, random.random(), uid)
                 heapq.heappush(self.queue, heap_entry)
                 self.uid_to_queue[uid] = heap_entry
+            else:
+                logger.debug(f"Refreshing existing expert: {uid}, new expiration time = {expiration_time:.3f}.")
 
     def _ban_expert(self, uid: ExpertUID):
         with self.lock:
@@ -79,9 +84,10 @@ class ExpertBalancer:
             self.uid_to_queue.pop(uid, None)
             self.throughputs.pop(uid, None)
             del self.experts[uid]
+            logger.debug(f"Banned expert {uid} with expiration time = {expiration_time:.2f}.")
 
     @contextmanager
-    def lend_expert(self, task_size: int):
+    def use_another_expert(self, task_size: float) -> RemoteExpert:
         while True:
             if len(self.queue) == 0:
                 self.update_finished.clear()
@@ -109,8 +115,9 @@ class ExpertBalancer:
                 break
         try:
             with self.throughputs[uid].update_threadsafe(task_size):
+                logger.debug(f"Using expert {uid}, throughput = {self.throughputs[uid].samples_per_second}.")
                 yield RemoteExpert(uid, maybe_endpoint.value)
-        except BaseException as e:
+        except BaseException:
             self._ban_expert(uid)
             raise
 

+ 3 - 3
hivemind/optim/performance_ema.py

@@ -17,14 +17,14 @@ class PerformanceEMA:
         self.paused = paused
         self.lock = Lock()
 
-    def update(self, task_size: int, interval: float) -> float:
+    def update(self, task_size: float, interval: float) -> float:
         """
         :param task_size: how many items were processed since last call
         :param interval: optionally provide the time delta it took to process this task
         :returns: current estimate of performance (samples per second), but at most
         """
-        assert not self.paused or interval is not None, "If PerformanceEMA is paused, one must specify time interval"
         assert task_size > 0, f"Can't register processing {task_size} samples"
+        assert not self.paused or interval is not None, "If PerformanceEMA is paused, one must specify time interval"
         if not self.paused:
             self.timestamp, old_timestamp = get_dht_time(), self.timestamp
             interval = interval if interval is not None else max(0.0, self.timestamp - old_timestamp)
@@ -48,7 +48,7 @@ class PerformanceEMA:
         return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"
 
     @contextmanager
-    def update_threadsafe(self, task_size: int):
+    def update_threadsafe(self, task_size: float):
         """measure the EMA throughout of the code that runs inside the context"""
         start_timestamp = get_dht_time()
         yield