123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- import heapq
- import random
- import threading
- from contextlib import contextmanager
- from typing import Dict, List, Tuple
- from hivemind import RemoteExpert, TimedStorage, PeerID
- from hivemind.dht import DHT
- from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
- from hivemind.moe.expert_uid import ExpertPrefix, ExpertUID, ExpertInfo
- from hivemind.utils.performance_ema import PerformanceEMA
- from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
- logger = get_logger(__name__)
- class LoadBalancer:
- def __init__(self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0,
- **kwargs):
- self.dht, self.key = dht, key
- self.initial_throughput, self.ema_kwargs = initial_throughput, kwargs
- self.experts = TimedStorage[ExpertUID, PeerID]()
- self.blacklist = TimedStorage[ExpertUID, type(None)]()
- self.throughputs: Dict[ExpertUID, PerformanceEMA] = {}
- self.queue: List[Tuple[float, float, ExpertUID]] = []
- self.uid_to_queue: Dict[ExpertUID, Tuple[float, float, ExpertUID]] = {}
- self.lock = threading.Lock()
- self.is_alive = threading.Event()
- self.is_alive.set()
- self.update_trigger, self.update_finished = threading.Event(), threading.Event()
- self.update_period, self.last_update = update_period, get_dht_time()
- self.update_thread = threading.Thread(target=self.update_experts_in_background, daemon=True)
- self.update_thread.start()
- self._p2p = RemoteExpertWorker.run_coroutine(self.dht.replicate_p2p())
- def update_experts_in_background(self):
- while self.is_alive.is_set():
- time_to_next_update = max(0.0, self.last_update + self.update_period - get_dht_time())
- try:
- self.update_trigger.wait(timeout=time_to_next_update)
- # update triggered by main thread
- except TimeoutError:
- pass # update triggered by refresh_period
- self.update_trigger.clear()
- response = self.dht.get(self.key, latest=True)
- if isinstance(response, ValueWithExpiration) and isinstance(response.value, dict):
- for index, expert_info in response.value.items():
- try:
- (expert_uid, peer_id), expiration_time = expert_info
- maybe_banned = self.blacklist.get(expert_uid)
- if maybe_banned is None or expiration_time > maybe_banned.expiration_time:
- self._add_expert(expert_uid, peer_id, expiration_time)
- else:
- logger.debug(f"Not adding expert {expert_uid} (blacklisted).")
- except Exception as e:
- logger.warning(f"Skipping malformed expert info {expert_info} (exc={e})")
- else:
- logger.warning(f"Could not refresh experts, dht info key contains {response}, "
- f"will retry in {time_to_next_update}s")
- if len(self.queue) == 0:
- logger.warning("Update routine finished, but still no experts available.")
- self.last_update = get_dht_time()
- self.update_finished.set()
- def _add_expert(self, uid: ExpertUID, peer_id: PeerID, expiration_time: DHTExpiration):
- with self.lock:
- self.experts.store(uid, peer_id, expiration_time)
- if uid not in self.uid_to_queue:
- logger.debug(f"Adding new expert: {uid}, expiration time = {expiration_time:.3f}.")
- self.throughputs[uid] = PerformanceEMA(*self.ema_kwargs, paused=True)
- base_load = self.queue[0][0] if len(self.queue) > 0 else 0.0
- heap_entry = (base_load, random.random(), uid)
- heapq.heappush(self.queue, heap_entry)
- self.uid_to_queue[uid] = heap_entry
- else:
- logger.debug(f"Refreshing existing module: {uid}, new expiration time = {expiration_time:.3f}.")
- def _ban_expert(self, uid: ExpertUID):
- with self.lock:
- maybe_expert = self.experts.get(uid)
- expiration_time = maybe_expert.expiration_time if maybe_expert else get_dht_time()
- self.blacklist.store(uid, None, expiration_time)
- self.uid_to_queue.pop(uid, None)
- self.throughputs.pop(uid, None)
- del self.experts[uid]
- logger.debug(f"Banned expert {uid} with expiration time = {expiration_time:.2f}.")
- @contextmanager
- def use_another_expert(self, task_size: float) -> RemoteExpert:
- while True:
- if len(self.queue) == 0:
- self.update_finished.clear()
- self.update_trigger.set()
- self.update_finished.wait()
- continue
- with self.lock:
- current_runtime, _, uid = heap_entry = heapq.heappop(self.queue)
- maybe_peer_id = self.experts.get(uid)
- if maybe_peer_id is None:
- # remove expired expert from queue
- self.uid_to_queue.pop(uid, None)
- self.throughputs.pop(uid, None)
- if self.uid_to_queue.get(uid) != heap_entry:
- continue # skip uids that are banned or expired
- if self.throughputs[uid].num_updates != 0:
- expected_time_taken = task_size / self.throughputs[uid].samples_per_second
- else:
- expected_time_taken = self.initial_throughput * task_size
- new_heap_entry = (current_runtime + expected_time_taken, random.random(), uid)
- heapq.heappush(self.queue, new_heap_entry)
- self.uid_to_queue[uid] = new_heap_entry
- break
- try:
- with self.throughputs[uid].update_threadsafe(task_size):
- logger.debug(f"Using expert {uid}, throughput = {self.throughputs[uid].samples_per_second}.")
- yield RemoteExpert(ExpertInfo(uid, PeerID.from_base58(maybe_peer_id.value)), self._p2p)
- except BaseException:
- self._ban_expert(uid)
- raise
- def shutdown(self):
- self.is_alive.clear()
- self.update_finished.clear()
- self.update_trigger.set()
- self.update_finished.wait()
|