balancer.py

import heapq
import random
import threading
import time
from contextlib import contextmanager
from typing import Dict, List, Tuple

import grpc
from grpc._channel import _InactiveRpcError

from hivemind.dht import DHT
from hivemind.moe.client.expert import RemoteExpert
from hivemind.moe.server.expert_uid import ExpertPrefix, ExpertUID
from hivemind.optim.performance_ema import PerformanceEMA
from hivemind.utils import Endpoint, TimedStorage, DHTExpiration, ValueWithExpiration, get_dht_time, get_logger

logger = get_logger(__name__)


class ExpertBalancer:
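    """
    Load balancer that spreads tasks across remote experts announced in a DHT.

    Each known expert has an entry ``(estimated_total_runtime, random_tiebreaker, uid)`` in a
    min-heap. Picking an expert pops the entry with the smallest estimated runtime, adds the
    expected duration of the new task (derived from a per-expert throughput EMA), and pushes
    the entry back. Experts whose calls fail are blacklisted and are only re-added once the
    DHT publishes a fresher record with a later expiration time.
    """
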
    def __init__(
        self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0,
        sleep_timeout: float = 5.0, **kwargs
    ):
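        """
        :param dht: a running DHT instance used to look up expert infos under ``key``
        :param key: the DHT key (expert prefix) whose value lists available experts
        :param update_period: interval, in seconds, between periodic refreshes of the expert list
          (a refresh can also be triggered earlier by the main thread)
        :param initial_throughput: prior used to estimate a task's runtime before measurements exist
        :param sleep_timeout: if a refresh still yields no experts, wait this long before retrying
        :param kwargs: extra keyword arguments forwarded to each expert's PerformanceEMA
        """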
        self.dht, self.key = dht, key
        self.initial_throughput, self.ema_kwargs = initial_throughput, kwargs
        self.experts = TimedStorage[ExpertUID, Endpoint]()
        self.blacklist = TimedStorage[ExpertUID, type(None)]()
        self.throughputs: Dict[ExpertUID, PerformanceEMA] = {}
        self.queue: List[Tuple[float, float, ExpertUID]] = []
        self.uid_to_queue: Dict[ExpertUID, Tuple[float, float, ExpertUID]] = {}
        self.lock = threading.Lock()
        self.is_alive = threading.Event()
        self.is_alive.set()
        self.update_trigger, self.update_finished = threading.Event(), threading.Event()
        self.update_period, self.last_update = update_period, get_dht_time()
        self.sleep_timeout = sleep_timeout
        self.update_thread = threading.Thread(target=self.update_experts_in_background, daemon=True)
        self.update_thread.start()

    def update_experts_in_background(self):
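        """Background loop: re-read expert infos from the DHT whenever triggered by the main
        thread or when ``update_period`` elapses, adding every non-blacklisted expert."""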
        while self.is_alive.is_set():
            time_to_next_update = max(0.0, self.last_update + self.update_period - get_dht_time())
            # wait until the main thread triggers an update or until update_period elapses;
            # Event.wait returns False on timeout instead of raising, so both cases fall through
            self.update_trigger.wait(timeout=time_to_next_update)
            self.update_trigger.clear()
            response = self.dht.get(self.key, latest=True)
            if isinstance(response, ValueWithExpiration) and isinstance(response.value, dict):
                for index, expert_info in response.value.items():
                    try:
                        (uid, endpoint), expiration_time = expert_info
                        maybe_banned = self.blacklist.get(uid)
                        if maybe_banned is None or expiration_time > maybe_banned.expiration_time:
                            self._add_expert(uid, endpoint, expiration_time)
                        else:
                            logger.debug(f"Not adding expert {uid} (blacklisted).")
                    except Exception as e:
                        logger.warning(f"Skipping malformed expert info {expert_info} (exc={e})")
            else:
                logger.warning(
                    f"Could not refresh experts, dht info key contains {response}, "
                    f"will retry in {time_to_next_update}s"
                )
            if len(self.queue) == 0:
                logger.warning("Update routine finished, but still no experts available.")
                time.sleep(self.sleep_timeout)
            self.last_update = get_dht_time()
            self.update_finished.set()

    def _add_expert(self, uid: ExpertUID, endpoint: Endpoint, expiration_time: DHTExpiration):
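        """Store the expert's endpoint and, if it is not queued yet, insert it near the current
        minimum load so that new experts get measured soon."""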
        with self.lock:
            self.experts.store(uid, endpoint, expiration_time)
            if uid not in self.uid_to_queue:
                logger.debug(f"Adding new expert: {uid}, expiration time = {expiration_time:.3f}.")
                self.throughputs[uid] = PerformanceEMA(**self.ema_kwargs, paused=True)
                # ensure that added experts are evaluated by placing them near the top of the queue
                base_load = self.queue[0][0] if len(self.queue) > 0 else 0.0
                heap_entry = (base_load, random.random(), uid)
                heapq.heappush(self.queue, heap_entry)
                self.uid_to_queue[uid] = heap_entry
            else:
                logger.debug(f"Refreshing existing expert: {uid}, new expiration time = {expiration_time:.3f}.")

    def _ban_expert(self, uid: ExpertUID):
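        """Blacklist an expert until its current DHT record expires and drop its queue entry
        and throughput statistics."""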
        with self.lock:
            maybe_expert = self.experts.get(uid)
            expiration_time = maybe_expert.expiration_time if maybe_expert else get_dht_time()
            self.blacklist.store(uid, None, expiration_time)
            self.uid_to_queue.pop(uid, None)
            self.throughputs.pop(uid, None)
            del self.experts[uid]
            logger.debug(f"Banned expert {uid} with expiration time = {expiration_time:.2f}.")

    @contextmanager
    def use_another_expert(self, task_size: float) -> RemoteExpert:
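        """Pick the expert with the smallest estimated load, charge it with this task, and yield
        it as a RemoteExpert. If the call fails, the expert is banned and the exception is
        re-raised, except for gRPC DEADLINE_EXCEEDED, which is re-raised without banning so the
        caller may retry with another expert."""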
        while True:
            if len(self.queue) == 0:
                self.update_finished.clear()
                self.update_trigger.set()
                self.update_finished.wait()
                continue
            with self.lock:
                current_runtime, _, uid = heap_entry = heapq.heappop(self.queue)
                maybe_endpoint = self.experts.get(uid)
                if maybe_endpoint is None:
                    # remove expired expert from queue
                    self.uid_to_queue.pop(uid, None)
                    self.throughputs.pop(uid, None)
                if self.uid_to_queue.get(uid) != heap_entry:
                    logger.debug(f"Skipping expert {uid} "
                                 f"(uid_to_queue={self.uid_to_queue.get(uid)}, entry={heap_entry})")
                    continue  # skip uids that are banned or expired
                if self.throughputs[uid].num_updates != 0:
                    expected_time_taken = task_size / self.throughputs[uid].samples_per_second
                else:
                    expected_time_taken = self.initial_throughput * task_size
                new_heap_entry = (current_runtime + expected_time_taken, random.random(), uid)
                heapq.heappush(self.queue, new_heap_entry)
                self.uid_to_queue[uid] = new_heap_entry
                break
        try:
            with self.throughputs[uid].update_threadsafe(task_size):
                logger.debug(f"Using expert {uid}, throughput = {self.throughputs[uid].samples_per_second}.")
                yield RemoteExpert(uid, maybe_endpoint.value)
                logger.debug(f"Finished using expert {uid}.")
        except _InactiveRpcError as error:
            if error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                # response was too slow, choose the next expert
                raise
            else:
                self._ban_expert(uid)
                raise
        except BaseException:
            self._ban_expert(uid)
            raise

    def shutdown(self):
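        """Stop the background update thread: wake it for a final pass and wait for it to finish."""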
        self.is_alive.clear()
        self.update_finished.clear()
        self.update_trigger.set()
        self.update_finished.wait()
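

# Illustrative usage sketch (not part of the original module): how a client might pick experts
# with this balancer. The peer address, the "ffn_expert." prefix, batch_size, and inputs below
# are placeholder assumptions; adapt them to your hivemind deployment.
#
#     dht = DHT(initial_peers=["..."], start=True)
#     balancer = ExpertBalancer(dht, key="ffn_expert.", update_period=30.0)
#     with balancer.use_another_expert(task_size=float(batch_size)) as expert:
#         outputs = expert(inputs)  # RemoteExpert runs the forward pass on the chosen server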