justheuristic, 4 years ago
commit dee9854d79

+ 110 - 81
hivemind/moe/client/balanced_expert.py

@@ -1,21 +1,28 @@
-import heapq
-import random
-import threading
-from typing import Optional, Tuple, List, Dict
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
+from torch.autograd.function import once_differentiable
 
 import hivemind
-from hivemind import nested_compare, nested_flatten, get_dht_time
-from hivemind.moe.client.expert import _RemoteModuleCall, DUMMY
-from hivemind.moe.server.expert_uid import ExpertUID
-from hivemind.utils import DHTExpiration
-
-
-class LoadBalancedExpert(nn.Module):
+from hivemind.moe.client.balancer import ExpertBalancer
+from hivemind.moe.client.expert import DUMMY
+from hivemind.proto import runtime_pb2
+from hivemind.utils import (
+    deserialize_torch_tensor,
+    get_logger,
+    nested_compare,
+    nested_flatten,
+    nested_pack,
+    serialize_torch_tensor,
+)
+
+logger = get_logger(__name__)
+
+
+class BalancedRemoteExpert(nn.Module):
     """
     """
-    A torch module that dynamically assigns weights to one RemoteExpert from a pool.
+    A torch module that dynamically assigns weights to one RemoteExpert from a pool, proportionally to their throughput.
     ToDo docstring, similar to RemoteMixtureOfExperts
     ToDo docstring, similar to RemoteMixtureOfExperts
     """
     """
 
 
@@ -28,69 +35,19 @@ class LoadBalancedExpert(nn.Module):
         forward_timeout: Optional[float] = None,
         backward_timeout: Optional[float] = None,
         detect_anomalies: bool = False,
-        refresh_period: float = 30.,
-        **dht_kwargs,
+        update_period: float = 30.0,
+        backward_task_size_multiplier: float = 2.5,
+        **kwargs,
     ):
         super().__init__()
-        assert len(grid_size) == 1, "only 1d grids are supported for now"
-        self.dht, self.dht_kwargs, self.uid_prefix, self.grid_size = dht, dht_kwargs, uid_prefix, grid_size
+        if uid_prefix.endswith(".0."):
+            logger.warning(f"BalancedRemoteExperts will look for experts under prefix {uid_prefix}.0.")
+        assert len(grid_size) == 2 and grid_size[0] == 1, "only 1xN grids are supported"
+        self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
-        self.detect_anomalies = detect_anomalies
-
-        self.active_experts: Dict[ExpertUID, DHTExpiration] = {}
-        self.banned_experts: Dict[ExpertUID, DHTExpiration] = {}
-        self.expert_queue: List[Tuple[float, float, ExpertUID]] = []
+        self.backward_task_size_multiplier, self.detect_anomalies = backward_task_size_multiplier, detect_anomalies
+        self.expert_balancer = ExpertBalancer(dht, key=f"{self.uid_prefix}.0.", update_period=update_period, **kwargs)
         self._expert_info = None  # expert['info'] from one of experts in the grid
-        self.refresh_period, self.last_refresh = refresh_period, 0.0
-        self.should_refresh_experts, self.refresh_complete = threading.Event(), threading.Event()
-
-    def fetch_experts_in_background(self):
-        while True:
-            time_to_next_update = max(0.0, self.last_update + self.refresh_period - get_dht_time())
-            try:
-                self.should_refresh_experts.wait(timeout=time_to_next_update)
-                # update triggered by main thread
-            except TimeoutError:
-                pass  # update triggered by refresh_period
-
-            TODO_FETCH_MORE_EXPERTS_HERE
-
-            # state:
-            # * available_experts: Dict[uid -> (EMA, expiration)] - experts that take part in load balancing
-            # * maintain blacklist: Dict[uid -> expiration] - experts banned until expiration for a non-response
-            # * maintain a min-heap queue of (load, rng, expert) tuples
-            # * update_triggered, update_finished: threading.Event; lock: threading.Lock
-            #
-            # update experts in background, while True:
-            # * wait for 30s or for update_triggered, whichever comes first
-            # * for expert, expiration_time in fetch_experts_from_dht():
-            # * * if expert in banned and expiration_time <= self.blacklist[expert]:
-            # * * * continue # expert is still banned
-            # * * else: add expert to min-heap, intitialize throughput
-            # * update_complete.set()
-            #
-            # on forward/backward:
-            # pass (queue, blacklist, update_triggered, update_finished) to the autograd function
-            #
-            # forward/backward autograd function
-            # while True:
-            # * while len(available experts) == 0:
-            # * * update_finished.clear()
-            # * * update_triggered.set()
-            # * * update_finished.wait()
-            # * with threading.lock:
-            # * * load, _, expert = queue.heappop_min()
-            # * * expert_throughput_ema, expert_expiration_time = get ema from dict
-            # * * task_complexity = batch_size * 1.5 if forward else 2.5 # if backward
-            # * * queue.heappush (load + task_complexity / expert_throughput_ema, new_rng, expert)
-            # * try:
-            # * * with measure_ema(start=now, batch_size=batch_size) as measured_ema:
-            # * * * outputs = call_forward_or_backward()
-            # * * expert_throughput_ema.update(measured_ema)
-            # * * return outputs      # <--------- this is the desired exit point
-            # * except DidNotRespondCorrectly:
-            # * * banned_experts[expert] = expert_expiration_time
-            # * continue # try again
 
     def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
@@ -103,8 +60,6 @@ class LoadBalancedExpert(nn.Module):
         assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
         kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
 
-
-
         if self._expert_info is None:
             raise NotImplementedError()
         # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
@@ -114,17 +69,91 @@ class LoadBalancedExpert(nn.Module):
         if not nested_compare(forward_inputs, self.info["forward_schema"]):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
-        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
+        flat_inputs = list(nested_flatten(forward_inputs))
+        forward_task_size = flat_inputs[0].shape[0]
+
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
+        flat_outputs = _BalancedRemoteModuleCall.apply(DUMMY,
+                                                       self.expert_balancer,
+                                                       self.info,
+                                                       self.forward_timeout,
+                                                       self.backward_timeout,
+                                                       forward_task_size,
+                                                       forward_task_size * self.backward_task_size_multiplier,
+                                                       *flat_inputs)
+
         return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
 
     @property
     def info(self):
-        if self._expert_info is None:
-            # grab some expert to set ensemble output shape
-            proj_device = self.proj.weight.device
-            dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
-            dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.beam_search.grid_size, dim=-1)
-            dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
-            self._expert_info = dummy_experts[0].info
+        while self._expert_info is None:
+            try:
+                with self.expert_balancer.use_another_expert(1) as chosen_expert:
+                    self._expert_info = chosen_expert.info
+            except BaseException as e:
+                logger.error(f"Tried to get expert info from {chosen_expert} but caught {e}")
         return self._expert_info
+
+
+class _BalancedRemoteModuleCall(torch.autograd.Function):
+    """Internal autograd-friendly call of a remote module. For applications, use BalancedRemoteExpert instead."""
+
+    @staticmethod
+    def forward(
+            ctx,
+            dummy: torch.Tensor,
+            expert_balancer: ExpertBalancer,
+            info: Dict[str, Any],
+            forward_timeout: float,
+            backward_timeout: float,
+            forward_task_size: float,
+            backward_task_size: float,
+            *inputs: torch.Tensor,
+            ) -> Tuple[torch.Tensor, ...]:
+        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
+        # detach to avoid pickling the computation graph
+        ctx.expert_balancer, ctx.info = expert_balancer, info
+        ctx.forward_timeout, ctx.backward_timeout = forward_timeout, backward_timeout
+        ctx.forward_task_size, ctx.backward_task_size = forward_task_size, backward_task_size
+        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
+        ctx.save_for_backward(*inputs)
+
+        serialized_tensors = [
+            serialize_torch_tensor(inp, proto.compression)
+            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
+        ]
+        while True:
+            try:
+                with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
+                    forward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
+                    outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
+                break
+            except BaseException as e:
+                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {e}")
+                raise
+
+        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
+        return tuple(deserialized_outputs)
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
+        grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
+        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
+        backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
+        serialized_tensors = [
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        ]
+        while True:
+            try:
+                with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
+                    backward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
+                    grad_inputs = chosen_expert.stub.backward(backward_request, timeout=ctx.backward_timeout)
+                break
+            except BaseException as e:
+                logger.error(f"Tried to call backward for expert {chosen_expert} but caught {e}")
+                raise
+        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
+        return (DUMMY, None, None, None, None, None, *deserialized_grad_inputs)

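
For orientation, the snippet below sketches how the new BalancedRemoteExpert is meant to be used. It is not part of this commit; the DHT setup, expert prefix, grid size, and tensor shapes are illustrative assumptions that merely follow the constructor signature added above.

    # Hypothetical usage sketch (not from this commit). Assumes servers announce experts
    # under the prefix that ExpertBalancer will watch ("expert.0.*" here).
    import torch
    import hivemind
    from hivemind.moe.client.balanced_expert import BalancedRemoteExpert

    dht = hivemind.DHT(start=True)  # a fresh single-node DHT for illustration; in practice pass initial_peers
    expert = BalancedRemoteExpert(
        dht=dht,
        uid_prefix="expert",
        grid_size=(1, 16),                   # 1xN grid, matching the assert in __init__
        forward_timeout=10.0,
        backward_timeout=20.0,
        update_period=30.0,                  # how often ExpertBalancer re-reads the DHT
        backward_task_size_multiplier=2.5,   # backward counts as 2.5x the forward load
    )

    x = torch.randn(32, 1024, requires_grad=True)  # assumed batch and hidden sizes
    y = expert(x)           # forward is routed to the least-loaded live expert
    y.sum().backward()      # backward is dispatched through the balancer as well
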
+ 15 - 8
hivemind/moe/client/balancer.py

@@ -2,14 +2,13 @@ import heapq
 import random
 import threading
 from contextlib import contextmanager
-from typing import Tuple, List, Dict
+from typing import Dict, List, Tuple
 
-from hivemind import TimedStorage, Endpoint, RemoteExpert
+from hivemind import Endpoint, RemoteExpert, TimedStorage
 from hivemind.dht import DHT
-from hivemind.moe.server.expert_uid import ExpertPrefix
-from hivemind.moe.server.expert_uid import ExpertUID
+from hivemind.moe.server.expert_uid import ExpertPrefix, ExpertUID
 from hivemind.optim.performance_ema import PerformanceEMA
-from hivemind.utils import DHTExpiration, ValueWithExpiration, get_logger, get_dht_time
+from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 
@@ -51,12 +50,15 @@ class ExpertBalancer:
                         maybe_banned = self.blacklist.get(uid)
                         if maybe_banned is None or expiration_time > maybe_banned.expiration_time:
                             self._add_expert(uid, endpoint, expiration_time)
-
+                        else:
+                            logger.debug(f"Not adding expert {uid} (blacklisted).")
                     except Exception as e:
                         logger.warning(f"Skipping malformed expert info {expert_info} (exc={e})")
             else:
                 logger.warning(f"Could not refresh experts, dht info key contains {response}, "
                                f"will retry in {time_to_next_update}s")
+            if len(self.queue) == 0:
+                logger.warning("Update routine finished, but still no experts available.")
 
             self.last_update = get_dht_time()
             self.update_finished.set()
@@ -65,11 +67,14 @@ class ExpertBalancer:
         with self.lock:
             self.experts.store(uid, endpoint, expiration_time)
             if uid not in self.uid_to_queue:
+                logger.debug(f"Adding new expert: {uid}, expiration time = {expiration_time:.3f}.")
                 self.throughputs[uid] = PerformanceEMA(**self.ema_kwargs, paused=True)
                 base_load = self.queue[0][0] if len(self.queue) > 0 else 0.0
                 heap_entry = (base_load, random.random(), uid)
                 heapq.heappush(self.queue, heap_entry)
                 self.uid_to_queue[uid] = heap_entry
+            else:
+                logger.debug(f"Refreshing existing expert: {uid}, new expiration time = {expiration_time:.3f}.")
 
     def _ban_expert(self, uid: ExpertUID):
         with self.lock:
@@ -79,9 +84,10 @@ class ExpertBalancer:
             self.uid_to_queue.pop(uid, None)
             self.throughputs.pop(uid, None)
             del self.experts[uid]
+            logger.debug(f"Banned expert {uid} with expiration time = {expiration_time:.2f}.")
 
     @contextmanager
-    def lend_expert(self, task_size: int):
+    def use_another_expert(self, task_size: float) -> RemoteExpert:
         while True:
             if len(self.queue) == 0:
                 self.update_finished.clear()
@@ -109,8 +115,9 @@ class ExpertBalancer:
                 break
         try:
             with self.throughputs[uid].update_threadsafe(task_size):
+                logger.debug(f"Using expert {uid}, throughput = {self.throughputs[uid].samples_per_second}.")
                 yield RemoteExpert(uid, maybe_endpoint.value)
-        except BaseException as e:
+        except BaseException:
             self._ban_expert(uid)
             raise
 

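The removed TODO comments in balanced_expert.py describe the balancing rule that ExpertBalancer now implements: keep a min-heap of (load, tiebreak, uid) entries, pop the least-loaded expert, and push it back with its load increased by task_size divided by its measured throughput, so faster experts receive proportionally more work. Below is a self-contained toy sketch of that rule; it is a simplification for illustration, not the actual ExpertBalancer code (the real class also handles banning, stale heap entries, and thread safety).

    import heapq
    import random

    queue = []        # min-heap of (estimated_load, random_tiebreak, uid)
    throughput = {}   # uid -> measured samples/sec (plain floats instead of PerformanceEMA)

    def add_expert(uid, samples_per_second):
        base_load = queue[0][0] if queue else 0.0   # new experts join at the current minimum load
        throughput[uid] = samples_per_second
        heapq.heappush(queue, (base_load, random.random(), uid))

    def pick_expert(task_size):
        load, _, uid = heapq.heappop(queue)
        load += task_size / max(throughput[uid], 1e-9)   # slower experts accumulate load faster
        heapq.heappush(queue, (load, random.random(), uid))
        return uid

    add_expert("expert.0.0", samples_per_second=100.0)
    add_expert("expert.0.1", samples_per_second=300.0)
    picks = [pick_expert(task_size=32) for _ in range(8)]
    # expert.0.1 is roughly 3x faster, so it should appear roughly 3x as often in `picks`
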
+ 3 - 3
hivemind/optim/performance_ema.py

@@ -17,14 +17,14 @@ class PerformanceEMA:
         self.paused = paused
         self.lock = Lock()
 
-    def update(self, task_size: int, interval: float) -> float:
+    def update(self, task_size: float, interval: float) -> float:
         """
         """
         :param task_size: how many items were processed since last call
         :param task_size: how many items were processed since last call
         :param interval: optionally provide the time delta it took to process this task
         :param interval: optionally provide the time delta it took to process this task
         :returns: current estimate of performance (samples per second), but at most
         :returns: current estimate of performance (samples per second), but at most
         """
         """
-        assert not self.paused or interval is not None, "If PerformanceEMA is paused, one must specify time interval"
         assert task_size > 0, f"Can't register processing {task_size} samples"
+        assert not self.paused or interval is not None, "If PerformanceEMA is paused, one must specify time interval"
         if not self.paused:
             self.timestamp, old_timestamp = get_dht_time(), self.timestamp
             interval = interval if interval is not None else max(0.0, self.timestamp - old_timestamp)
@@ -48,7 +48,7 @@ class PerformanceEMA:
         return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"
         return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"
 
 
     @contextmanager
     @contextmanager
-    def update_threadsafe(self, task_size: int):
+    def update_threadsafe(self, task_size: float):
         """measure the EMA throughout of the code that runs inside the context"""
         """measure the EMA throughout of the code that runs inside the context"""
         start_timestamp = get_dht_time()
         start_timestamp = get_dht_time()
         yield
         yield
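
With task_size now a float, fractional work units (such as batch_size * 2.5 for a backward pass) can be registered in the EMA. The snippet below is a small usage sketch of update_threadsafe, the same context manager the balancer wraps around each remote call; the constructor argument and timings are illustrative assumptions, not part of this commit.

    import time
    from hivemind.optim.performance_ema import PerformanceEMA

    ema = PerformanceEMA(alpha=0.1)   # assumption: default-style smoothing factor
    for _ in range(5):
        with ema.update_threadsafe(task_size=32 * 2.5):   # backward-sized task, as in the balancer
            time.sleep(0.01)                              # stand-in for the actual remote expert call
    print(f"estimated throughput: {ema.samples_per_second:.1f} samples/s")
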