
load balancer that works

justheuristic committed 4 years ago
commit cba8939c66
2 changed files with 128 additions and 3 deletions:
  1. hivemind/moe/client/balancer.py (+121, -0)
  2. hivemind/utils/timed_storage.py (+7, -3)

hivemind/moe/client/balancer.py (+121, -0)

@@ -0,0 +1,121 @@
+import heapq
+import random
+import threading
+from contextlib import contextmanager
+from typing import Tuple, List, Dict
+
+from hivemind import TimedStorage, Endpoint, RemoteExpert
+from hivemind.dht import DHT
+from hivemind.moe.server.expert_uid import ExpertPrefix
+from hivemind.moe.server.expert_uid import ExpertUID
+from hivemind.optim.performance_ema import PerformanceEMA
+from hivemind.utils import DHTExpiration, ValueWithExpiration, get_logger, get_dht_time
+
+logger = get_logger(__name__)
+
+
+class ExpertBalancer:
+    def __init__(self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0,
+                 **kwargs):
+        self.dht, self.key = dht, key
+        self.initial_throughput, self.ema_kwargs = initial_throughput, kwargs
+        self.experts = TimedStorage[ExpertUID, Endpoint]()
+        self.blacklist = TimedStorage[ExpertUID, type(None)]()
+        self.throughputs: Dict[ExpertUID, PerformanceEMA] = {}
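+        # each heap entry is (estimated cumulative runtime, random tiebreak, uid); uid_to_queue
+        # maps each uid to its latest entry so that stale heap items can be skipped lazily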
+        self.queue: List[Tuple[float, float, ExpertUID]] = []
+        self.uid_to_queue: Dict[ExpertUID, Tuple[float, float, ExpertUID]] = {}
+        self.lock = threading.Lock()
+        self.is_alive = threading.Event()
+        self.is_alive.set()
+        self.update_trigger, self.update_finished = threading.Event(), threading.Event()
+        self.update_period, self.last_update = update_period, get_dht_time()
+        self.update_thread = threading.Thread(target=self.update_experts_in_background, daemon=True)
+        self.update_thread.start()
+
+    def update_experts_in_background(self):
+        while self.is_alive.is_set():
+            time_to_next_update = max(0.0, self.last_update + self.update_period - get_dht_time())
+            # Event.wait never raises TimeoutError: it returns True if the main thread
+            # triggered an update and False if the update_period timeout expired;
+            # either way, proceed to refresh the expert list
+            self.update_trigger.wait(timeout=time_to_next_update)
+
+            self.update_trigger.clear()
+            response = self.dht.get(self.key, latest=True)
+            if isinstance(response, ValueWithExpiration) and isinstance(response.value, dict):
+                for index, expert_info in response.value.items():
+                    try:
+                        (uid, endpoint), expiration_time = expert_info
+
+                        maybe_banned = self.blacklist.get(uid)
+                        if maybe_banned is None or expiration_time > maybe_banned.expiration_time:
+                            self._add_expert(uid, endpoint, expiration_time)
+
+                    except Exception as e:
+                        logger.warning(f"Skipping malformed expert info {expert_info} (exc={e})")
+            else:
+                logger.warning(f"Could not refresh experts, dht info key contains {response}, "
+                               f"will retry in {time_to_next_update}s")
+
+            self.last_update = get_dht_time()
+            self.update_finished.set()
+
+    def _add_expert(self, uid: ExpertUID, endpoint: Endpoint, expiration_time: DHTExpiration):
+        with self.lock:
+            self.experts.store(uid, endpoint, expiration_time)
+            if uid not in self.uid_to_queue:
+                self.throughputs[uid] = PerformanceEMA(**self.ema_kwargs, paused=True)
+                base_load = self.queue[0][0] if len(self.queue) > 0 else 0.0
+                heap_entry = (base_load, random.random(), uid)
+                heapq.heappush(self.queue, heap_entry)
+                self.uid_to_queue[uid] = heap_entry
+
+    def _ban_expert(self, uid: ExpertUID):
+        with self.lock:
+            maybe_expert = self.experts.get(uid)
+            expiration_time = maybe_expert.expiration_time if maybe_expert else get_dht_time()
+            self.blacklist.store(uid, None, expiration_time)
+            self.uid_to_queue.pop(uid, None)
+            self.throughputs.pop(uid, None)
+            del self.experts[uid]
+
+    @contextmanager
+    def lend_expert(self, task_size: int):
+        while True:
+            if len(self.queue) == 0:
+                self.update_finished.clear()
+                self.update_trigger.set()
+                self.update_finished.wait()
+                continue
+
+            with self.lock:
+                current_runtime, _, uid = heap_entry = heapq.heappop(self.queue)
+                maybe_endpoint = self.experts.get(uid)
+                if maybe_endpoint is None:
+                    # remove expired expert from queue
+                    self.uid_to_queue.pop(uid, None)
+                    self.throughputs.pop(uid, None)
+                if self.uid_to_queue.get(uid) != heap_entry:
+                    continue  # skip uids that are banned or expired
+
+                if self.throughputs[uid].num_updates != 0:
+                    expected_time_taken = task_size / self.throughputs[uid].samples_per_second
+                else:
+                    expected_time_taken = self.initial_throughput * task_size
+                new_heap_entry = (current_runtime + expected_time_taken, random.random(), uid)
+                heapq.heappush(self.queue, new_heap_entry)
+                self.uid_to_queue[uid] = new_heap_entry
+                break
+        try:
+            with self.throughputs[uid].update_threadsafe(task_size):
+                yield RemoteExpert(uid, maybe_endpoint.value)
+        except BaseException:
+            self._ban_expert(uid)
+            raise
+
+    def shutdown(self):
+        self.is_alive.clear()
+        self.update_finished.clear()
+        self.update_trigger.set()
+        self.update_finished.wait()
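
For reference, a minimal usage sketch of the new balancer (not part of the commit). It assumes a running DHT in which experts have already been declared under the given prefix; the "expert." key and the task_size value are placeholders.

import hivemind
from hivemind.moe.client.balancer import ExpertBalancer

dht = hivemind.DHT(start=True)  # assumes experts were declared in this DHT under the "expert." prefix
balancer = ExpertBalancer(dht, key="expert.", update_period=30.0)

# lend_expert picks the expert with the lowest estimated load, updates its throughput EMA
# while the block runs, and bans the expert if the block raises
with balancer.lend_expert(task_size=4096) as expert:
    ...  # e.g. outputs = expert(inputs)

balancer.shutdown()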

hivemind/utils/timed_storage.py (+7, -3)

@@ -113,6 +113,12 @@ class TimedStorage(Generic[KeyType, ValueType]):
         self.key_to_heap.clear()
         self.expiration_heap.clear()
 
+    def discard(self, key: KeyType):
+        """If storage contains key, drop the corresponding item, otherwise do nothing"""
+        if key in self.key_to_heap:
+            del self.data[key], self.key_to_heap[key]
+        # note: key may linger in self.expiration_heap, but it will not be used and will eventually be pruned by self._remove_outdated()
+
     def __contains__(self, key: KeyType):
         self._remove_outdated()
         return key in self.data
@@ -122,9 +128,7 @@ class TimedStorage(Generic[KeyType, ValueType]):
         return len(self.data)
 
     def __delitem__(self, key: KeyType):
-        if key in self.key_to_heap:
-            del self.data[key], self.key_to_heap[key]
-        # note: key may still be in self.expiration_heap, but it will not be used and eventually ._remove_outdated()
+        self.discard(key)
 
     def __bool__(self):
         return bool(self.data)
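
A short illustration of the refactored deletion path in TimedStorage (the key, value, and expiration below are arbitrary):

from hivemind.utils.timed_storage import TimedStorage, get_dht_time

storage = TimedStorage[str, int]()
storage.store("alpha", 42, expiration_time=get_dht_time() + 60)

del storage["alpha"]      # __delitem__ now simply delegates to discard()
storage.discard("alpha")  # discarding a missing key is a no-op
assert "alpha" not in storage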