Max Ryabinin 3 года назад
Родитель
Commit
11848f6bc1
1 изменённый файл: 24 добавлено и 19 удалено
  1. 24 19
      hivemind/moe/client/balancer.py

+ 24 - 19
hivemind/moe/client/balancer.py

@@ -1,9 +1,10 @@
 import heapq
 import random
 import threading
+import time
 from contextlib import contextmanager
+from operator import itemgetter
 from typing import Dict, List, Tuple
-import time
 
 import grpc
 from grpc._channel import _InactiveRpcError
@@ -16,19 +17,21 @@ from hivemind.utils import Endpoint, TimedStorage, DHTExpiration, ValueWithExpir
 
 logger = get_logger(__name__)
 
+QueueItem = Tuple[bool, float, float, ExpertUID]
+
 
 class ExpertBalancer:
     def __init__(
-        self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0,
-        sleep_timeout: float = 5.0, **kwargs
+        self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, sleep_timeout: float = 5.0, **kwargs
     ):
         self.dht, self.key = dht, key
-        self.initial_throughput, self.ema_kwargs = initial_throughput, kwargs
+        self.ema_kwargs = kwargs
         self.experts = TimedStorage[ExpertUID, Endpoint]()
         self.blacklist = TimedStorage[ExpertUID, type(None)]()
         self.throughputs: Dict[ExpertUID, PerformanceEMA] = {}
-        self.queue: List[Tuple[float, float, ExpertUID]] = []
-        self.uid_to_queue: Dict[ExpertUID, Tuple[float, float, ExpertUID]] = {}
+        # had_forward_pass, throughput, random, uid
+        self.queue: List[QueueItem] = []
+        self.uid_to_queue: Dict[ExpertUID, QueueItem] = {}
         self.lock = threading.Lock()
         self.is_alive = threading.Event()
         self.is_alive.set()
@@ -79,9 +82,7 @@ class ExpertBalancer:
             if uid not in self.uid_to_queue:
                 logger.debug(f"Adding new expert: {uid}, expiration time = {expiration_time:.3f}.")
                 self.throughputs[uid] = PerformanceEMA(**self.ema_kwargs, paused=True)
-                # ensure that added experts are evaluated by placing them near the top of the queue
-                base_load = self.queue[0][0] if len(self.queue) > 0 else 0.0
-                heap_entry = (base_load, random.random(), uid)
+                heap_entry = (False, 0.0, random.random(), uid)
                 heapq.heappush(self.queue, heap_entry)
                 self.uid_to_queue[uid] = heap_entry
             else:
@@ -108,7 +109,7 @@ class ExpertBalancer:
 
             with self.lock:
                 logger.debug(f"Getting a new expert, queue state: {self.queue}")
-                current_runtime, _, uid = heap_entry = heapq.heappop(self.queue)
+                *_, uid = heap_entry = heapq.heappop(self.queue)
                 maybe_endpoint = self.experts.get(uid)
                 if maybe_endpoint is None:
                     # remove expired expert from queue
@@ -118,23 +119,27 @@ class ExpertBalancer:
                     logger.debug(f"Skipping expert {uid} "
                                  f"(uid_to_queue={self.uid_to_queue.get(uid)}, entry={heap_entry})")
                     continue  # skip uids that are banned or expired
-
-                if self.throughputs[uid].num_updates != 0:
-                    expected_time_taken = task_size / self.throughputs[uid].samples_per_second
-                else:
-                    expected_time_taken = self.initial_throughput * task_size
-                new_heap_entry = (current_runtime + expected_time_taken, random.random(), uid)
-                heapq.heappush(self.queue, new_heap_entry)
-                self.uid_to_queue[uid] = new_heap_entry
                 break
         try:
             with self.throughputs[uid].update_threadsafe(task_size):
                 logger.debug(f"Using expert {uid}, throughput = {self.throughputs[uid].samples_per_second}.")
                 yield RemoteExpert(uid, maybe_endpoint.value)
                 logger.debug(f"Finished using expert {uid}.")
+
+            new_throughput = 1 / self.throughputs[uid].samples_per_second
+            new_heap_entry = (True, new_throughput, random.random(), uid)
+            with self.lock:
+                heapq.heappush(self.queue, new_heap_entry)
+                self.uid_to_queue[uid] = new_heap_entry
         except _InactiveRpcError as error:
             if error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
-                # response was too slow, choose the next expert
+                # response was too slow, choose the next expert and mark this one as slow
+                with self.lock:
+                    new_throughput = max(self.queue, key=itemgetter(1))[1]
+                    new_heap_entry = (True, new_throughput, random.random(), uid)
+                    heapq.heappush(self.queue, new_heap_entry)
+                    self.uid_to_queue[uid] = new_heap_entry
+
                 raise
             else:
                 self._ban_expert(uid)