
Revert "Test bilevel queue"

This reverts commit 11848f6bc194c416bd13e80cc3ebb1e034e5c267.
Max Ryabinin 3 years ago
parent
commit
77919315c3
1 changed file with 19 additions and 24 deletions
  1. hivemind/moe/client/balancer.py (+19 −24)

+ 19 - 24
hivemind/moe/client/balancer.py

@@ -1,10 +1,9 @@
 import heapq
 import random
 import threading
-import time
 from contextlib import contextmanager
-from operator import itemgetter
 from typing import Dict, List, Tuple
+import time
 
 import grpc
 from grpc._channel import _InactiveRpcError
@@ -17,21 +16,19 @@ from hivemind.utils import Endpoint, TimedStorage, DHTExpiration, ValueWithExpir
 
 logger = get_logger(__name__)
 
-QueueItem = Tuple[bool, float, float, ExpertUID]
-
 
 class ExpertBalancer:
     def __init__(
-        self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, sleep_timeout: float = 5.0, **kwargs
+        self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0,
+        sleep_timeout: float = 5.0, **kwargs
     ):
         self.dht, self.key = dht, key
-        self.ema_kwargs = kwargs
+        self.initial_throughput, self.ema_kwargs = initial_throughput, kwargs
         self.experts = TimedStorage[ExpertUID, Endpoint]()
         self.blacklist = TimedStorage[ExpertUID, type(None)]()
         self.throughputs: Dict[ExpertUID, PerformanceEMA] = {}
-        # had_forward_pass, throughput, random, uid
-        self.queue: List[QueueItem] = []
-        self.uid_to_queue: Dict[ExpertUID, QueueItem] = {}
+        self.queue: List[Tuple[float, float, ExpertUID]] = []
+        self.uid_to_queue: Dict[ExpertUID, Tuple[float, float, ExpertUID]] = {}
         self.lock = threading.Lock()
         self.is_alive = threading.Event()
         self.is_alive.set()
@@ -82,7 +79,9 @@ class ExpertBalancer:
             if uid not in self.uid_to_queue:
                 logger.debug(f"Adding new expert: {uid}, expiration time = {expiration_time:.3f}.")
                 self.throughputs[uid] = PerformanceEMA(**self.ema_kwargs, paused=True)
-                heap_entry = (False, 0.0, random.random(), uid)
+                # ensure that added experts are evaluated by placing them near the top of the queue
+                base_load = self.queue[0][0] if len(self.queue) > 0 else 0.0
+                heap_entry = (base_load, random.random(), uid)
                 heapq.heappush(self.queue, heap_entry)
                 self.uid_to_queue[uid] = heap_entry
             else:
@@ -109,7 +108,7 @@ class ExpertBalancer:
 
             with self.lock:
                 logger.debug(f"Getting a new expert, queue state: {self.queue}")
-                *_, uid = heap_entry = heapq.heappop(self.queue)
+                current_runtime, _, uid = heap_entry = heapq.heappop(self.queue)
                 maybe_endpoint = self.experts.get(uid)
                 if maybe_endpoint is None:
                     # remove expired expert from queue
@@ -119,27 +118,23 @@ class ExpertBalancer:
                     logger.debug(f"Skipping expert {uid} "
                                  f"(uid_to_queue={self.uid_to_queue.get(uid)}, entry={heap_entry})")
                     continue  # skip uids that are banned or expired
+
+                if self.throughputs[uid].num_updates != 0:
+                    expected_time_taken = task_size / self.throughputs[uid].samples_per_second
+                else:
+                    expected_time_taken = self.initial_throughput * task_size
+                new_heap_entry = (current_runtime + expected_time_taken, random.random(), uid)
+                heapq.heappush(self.queue, new_heap_entry)
+                self.uid_to_queue[uid] = new_heap_entry
                 break
         try:
             with self.throughputs[uid].update_threadsafe(task_size):
                 logger.debug(f"Using expert {uid}, throughput = {self.throughputs[uid].samples_per_second}.")
                 yield RemoteExpert(uid, maybe_endpoint.value)
                 logger.debug(f"Finished using expert {uid}.")
-
-            new_throughput = 1 / self.throughputs[uid].samples_per_second
-            new_heap_entry = (True, new_throughput, random.random(), uid)
-            with self.lock:
-                heapq.heappush(self.queue, new_heap_entry)
-                self.uid_to_queue[uid] = new_heap_entry
         except _InactiveRpcError as error:
             if error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
-                # response was too slow, choose the next expert and mark this one as slow
-                with self.lock:
-                    new_throughput = max(self.queue, key=itemgetter(1))[1]
-                    new_heap_entry = (True, new_throughput, random.random(), uid)
-                    heapq.heappush(self.queue, new_heap_entry)
-                    self.uid_to_queue[uid] = new_heap_entry
-
+                # response was too slow, choose the next expert
                 raise
             else:
                 self._ban_expert(uid)
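
For context on what this revert restores: instead of the bilevel (had_forward_pass, throughput) ordering, the queue goes back to a single heap keyed by cumulative expected runtime. Each pick pops the least-loaded expert, estimates how long the task should take from its measured throughput (falling back to initial_throughput before the first measurement), and pushes the entry back with the increased load before the task runs, so concurrent callers fan out across experts instead of all choosing the same minimum. Below is a minimal, self-contained sketch of that scheduling idea only; SimpleThroughputBalancer, ema_alpha, and the plain EMA update are illustrative assumptions, not hivemind's PerformanceEMA or ExpertBalancer API. (The sketch also treats initial_throughput as a default samples-per-second rate and divides by it, which differs slightly from the restored code's initial_throughput * task_size fallback.)

import heapq
import random
import time
from contextlib import contextmanager
from typing import Dict, List, Tuple


class SimpleThroughputBalancer:
    """Toy balancer: always pick the worker with the smallest cumulative expected runtime."""

    def __init__(self, initial_throughput: float = 1.0, ema_alpha: float = 0.2):
        self.initial_throughput = initial_throughput  # assumed samples/sec before any measurements
        self.ema_alpha = ema_alpha
        self.throughputs: Dict[str, float] = {}  # EMA of measured samples/sec per worker
        self.queue: List[Tuple[float, float, str]] = []  # (expected_load, tiebreak, uid)

    def add_worker(self, uid: str) -> None:
        # start new workers at the current minimum load so they are tried soon
        base_load = self.queue[0][0] if self.queue else 0.0
        heapq.heappush(self.queue, (base_load, random.random(), uid))

    @contextmanager
    def use_worker(self, task_size: float):
        # pop the least-loaded worker and immediately push it back with the task's
        # expected duration added, so concurrent callers spread over other workers
        load, _, uid = heapq.heappop(self.queue)
        throughput = self.throughputs.get(uid, self.initial_throughput)
        expected_time = task_size / throughput
        heapq.heappush(self.queue, (load + expected_time, random.random(), uid))

        start = time.perf_counter()
        yield uid
        elapsed = max(time.perf_counter() - start, 1e-9)
        sample = task_size / elapsed
        previous = self.throughputs.get(uid, sample)
        self.throughputs[uid] = (1 - self.ema_alpha) * previous + self.ema_alpha * sample


# usage: the faster a worker runs, the higher its EMA throughput and the more often it is picked
balancer = SimpleThroughputBalancer()
for uid in ("expert.0", "expert.1"):
    balancer.add_worker(uid)
with balancer.use_worker(task_size=1.0) as uid:
    time.sleep(0.01)  # stand-in for the actual remote call to `uid`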