# load_balancer.py
  1. import heapq
  2. import random
  3. import threading
  4. from contextlib import contextmanager
  5. from typing import Dict, List, Tuple
  6. from hivemind import RemoteExpert, TimedStorage, PeerID
  7. from hivemind.dht import DHT
  8. from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
  9. from hivemind.moe.expert_uid import ExpertPrefix, ExpertUID, ExpertInfo
  10. from hivemind.utils.performance_ema import PerformanceEMA
  11. from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
  12. logger = get_logger(__name__)
  13. class LoadBalancer:
  14. def __init__(self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0,
  15. **kwargs):
  16. self.dht, self.key = dht, key
  17. self.initial_throughput, self.ema_kwargs = initial_throughput, kwargs
  18. self.experts = TimedStorage[ExpertUID, PeerID]()
  19. self.blacklist = TimedStorage[ExpertUID, type(None)]()
  20. self.throughputs: Dict[ExpertUID, PerformanceEMA] = {}
  21. self.queue: List[Tuple[float, float, ExpertUID]] = []
  22. self.uid_to_queue: Dict[ExpertUID, Tuple[float, float, ExpertUID]] = {}
  23. self.lock = threading.Lock()
  24. self.is_alive = threading.Event()
  25. self.is_alive.set()
  26. self.update_trigger, self.update_finished = threading.Event(), threading.Event()
  27. self.update_period, self.last_update = update_period, get_dht_time()
  28. self.update_thread = threading.Thread(target=self.update_experts_in_background, daemon=True)
  29. self.update_thread.start()
  30. self._p2p = RemoteExpertWorker.run_coroutine(self.dht.replicate_p2p())
  31. def update_experts_in_background(self):
  32. while self.is_alive.is_set():
  33. time_to_next_update = max(0.0, self.last_update + self.update_period - get_dht_time())
  34. try:
  35. self.update_trigger.wait(timeout=time_to_next_update)
  36. # update triggered by main thread
  37. except TimeoutError:
  38. pass # update triggered by refresh_period
  39. self.update_trigger.clear()
  40. response = self.dht.get(self.key, latest=True)
  41. if isinstance(response, ValueWithExpiration) and isinstance(response.value, dict):
  42. for index, expert_info in response.value.items():
  43. try:
  44. (expert_uid, peer_id), expiration_time = expert_info
  45. maybe_banned = self.blacklist.get(expert_uid)
  46. if maybe_banned is None or expiration_time > maybe_banned.expiration_time:
  47. self._add_expert(expert_uid, peer_id, expiration_time)
  48. else:
  49. logger.debug(f"Not adding expert {expert_uid} (blacklisted).")
  50. except Exception as e:
  51. logger.warning(f"Skipping malformed expert info {expert_info} (exc={e})")
  52. else:
  53. logger.warning(f"Could not refresh experts, dht info key contains {response}, "
  54. f"will retry in {time_to_next_update}s")
  55. if len(self.queue) == 0:
  56. logger.warning("Update routine finished, but still no experts available.")
  57. self.last_update = get_dht_time()
  58. self.update_finished.set()
  59. def _add_expert(self, uid: ExpertUID, peer_id: PeerID, expiration_time: DHTExpiration):
  60. with self.lock:
  61. self.experts.store(uid, peer_id, expiration_time)
  62. if uid not in self.uid_to_queue:
  63. logger.debug(f"Adding new expert: {uid}, expiration time = {expiration_time:.3f}.")
  64. self.throughputs[uid] = PerformanceEMA(*self.ema_kwargs, paused=True)
  65. base_load = self.queue[0][0] if len(self.queue) > 0 else 0.0
  66. heap_entry = (base_load, random.random(), uid)
  67. heapq.heappush(self.queue, heap_entry)
  68. self.uid_to_queue[uid] = heap_entry
  69. else:
  70. logger.debug(f"Refreshing existing module: {uid}, new expiration time = {expiration_time:.3f}.")
  71. def _ban_expert(self, uid: ExpertUID):
  72. with self.lock:
  73. maybe_expert = self.experts.get(uid)
  74. expiration_time = maybe_expert.expiration_time if maybe_expert else get_dht_time()
  75. self.blacklist.store(uid, None, expiration_time)
  76. self.uid_to_queue.pop(uid, None)
  77. self.throughputs.pop(uid, None)
  78. del self.experts[uid]
  79. logger.debug(f"Banned expert {uid} with expiration time = {expiration_time:.2f}.")
  80. @contextmanager
  81. def use_another_expert(self, task_size: float) -> RemoteExpert:
  82. while True:
  83. if len(self.queue) == 0:
  84. self.update_finished.clear()
  85. self.update_trigger.set()
  86. self.update_finished.wait()
  87. continue
  88. with self.lock:
  89. current_runtime, _, uid = heap_entry = heapq.heappop(self.queue)
  90. maybe_peer_id = self.experts.get(uid)
  91. if maybe_peer_id is None:
  92. # remove expired expert from queue
  93. self.uid_to_queue.pop(uid, None)
  94. self.throughputs.pop(uid, None)
  95. if self.uid_to_queue.get(uid) != heap_entry:
  96. continue # skip uids that are banned or expired
  97. if self.throughputs[uid].num_updates != 0:
  98. expected_time_taken = task_size / self.throughputs[uid].samples_per_second
  99. else:
  100. expected_time_taken = self.initial_throughput * task_size
  101. new_heap_entry = (current_runtime + expected_time_taken, random.random(), uid)
  102. heapq.heappush(self.queue, new_heap_entry)
  103. self.uid_to_queue[uid] = new_heap_entry
  104. break
  105. try:
  106. with self.throughputs[uid].update_threadsafe(task_size):
  107. logger.debug(f"Using expert {uid}, throughput = {self.throughputs[uid].samples_per_second}.")
  108. yield RemoteExpert(ExpertInfo(uid, PeerID.from_base58(maybe_peer_id.value)), self._p2p)
  109. except BaseException:
  110. self._ban_expert(uid)
  111. raise
  112. def shutdown(self):
  113. self.is_alive.clear()
  114. self.update_finished.clear()
  115. self.update_trigger.set()
  116. self.update_finished.wait()