# balancer.py
import heapq
import random
import threading
import time
from contextlib import contextmanager
from operator import itemgetter
from typing import Dict, Iterator, List, Tuple

import grpc
from grpc._channel import _InactiveRpcError

from hivemind.dht import DHT
from hivemind.moe.client.expert import RemoteExpert
from hivemind.moe.server.expert_uid import ExpertPrefix, ExpertUID
from hivemind.optim.performance_ema import PerformanceEMA
from hivemind.utils import DHTExpiration, Endpoint, TimedStorage, ValueWithExpiration, get_dht_time, get_logger
  15. logger = get_logger(__name__)
  16. QueueItem = Tuple[bool, float, float, ExpertUID]
  17. class ExpertBalancer:
  18. def __init__(
  19. self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, sleep_timeout: float = 5.0, **kwargs
  20. ):
  21. self.dht, self.key = dht, key
  22. self.ema_kwargs = kwargs
  23. self.experts = TimedStorage[ExpertUID, Endpoint]()
  24. self.blacklist = TimedStorage[ExpertUID, type(None)]()
  25. self.throughputs: Dict[ExpertUID, PerformanceEMA] = {}
  26. # had_forward_pass, throughput, random, uid
  27. self.queue: List[QueueItem] = []
  28. self.uid_to_queue: Dict[ExpertUID, QueueItem] = {}
  29. self.lock = threading.Lock()
  30. self.is_alive = threading.Event()
  31. self.is_alive.set()
  32. self.update_trigger, self.update_finished = threading.Event(), threading.Event()
  33. self.update_period, self.last_update = update_period, get_dht_time()
  34. self.sleep_timeout = sleep_timeout
  35. self.update_thread = threading.Thread(target=self.update_experts_in_background, daemon=True)
  36. self.update_thread.start()
  37. def update_experts_in_background(self):
  38. while self.is_alive.is_set():
  39. time_to_next_update = max(0.0, self.last_update + self.update_period - get_dht_time())
  40. try:
  41. self.update_trigger.wait(timeout=time_to_next_update)
  42. # update triggered by main thread
  43. except TimeoutError:
  44. pass # update triggered by refresh_period
  45. self.update_trigger.clear()
  46. response = self.dht.get(self.key, latest=True)
  47. if isinstance(response, ValueWithExpiration) and isinstance(response.value, dict):
  48. for index, expert_info in response.value.items():
  49. try:
  50. (uid, endpoint), expiration_time = expert_info
  51. maybe_banned = self.blacklist.get(uid)
  52. if maybe_banned is None or expiration_time > maybe_banned.expiration_time:
  53. self._add_expert(uid, endpoint, expiration_time)
  54. else:
  55. logger.debug(f"Not adding expert {uid} (blacklisted).")
  56. except Exception as e:
  57. logger.warning(f"Skipping malformed expert info {expert_info} (exc={e})")
  58. else:
  59. logger.warning(
  60. f"Could not refresh experts, dht info key contains {response}, "
  61. f"will retry in {time_to_next_update}s"
  62. )
  63. if len(self.queue) == 0:
  64. logger.warning("Update routine finished, but still no experts available.")
  65. time.sleep(self.sleep_timeout)
  66. self.last_update = get_dht_time()
  67. self.update_finished.set()
  68. def _add_expert(self, uid: ExpertUID, endpoint: Endpoint, expiration_time: DHTExpiration):
  69. with self.lock:
  70. self.experts.store(uid, endpoint, expiration_time)
  71. if uid not in self.uid_to_queue:
  72. logger.debug(f"Adding new expert: {uid}, expiration time = {expiration_time:.3f}.")
  73. self.throughputs[uid] = PerformanceEMA(**self.ema_kwargs, paused=True)
  74. heap_entry = (False, 0.0, random.random(), uid)
  75. heapq.heappush(self.queue, heap_entry)
  76. self.uid_to_queue[uid] = heap_entry
  77. else:
  78. logger.debug(f"Refreshing existing expert: {uid}, new expiration time = {expiration_time:.3f}.")
  79. def _ban_expert(self, uid: ExpertUID):
  80. with self.lock:
  81. maybe_expert = self.experts.get(uid)
  82. expiration_time = maybe_expert.expiration_time if maybe_expert else get_dht_time()
  83. self.blacklist.store(uid, None, expiration_time)
  84. self.uid_to_queue.pop(uid, None)
  85. self.throughputs.pop(uid, None)
  86. del self.experts[uid]
  87. logger.debug(f"Banned expert {uid} with expiration time = {expiration_time:.2f}.")
  88. @contextmanager
  89. def use_another_expert(self, task_size: float) -> RemoteExpert:
  90. while True:
  91. if len(self.queue) == 0:
  92. self.update_finished.clear()
  93. self.update_trigger.set()
  94. self.update_finished.wait()
  95. continue
  96. with self.lock:
  97. logger.debug(f"Getting a new expert, queue state: {self.queue}")
  98. *_, uid = heap_entry = heapq.heappop(self.queue)
  99. maybe_endpoint = self.experts.get(uid)
  100. if maybe_endpoint is None:
  101. # remove expired expert from queue
  102. self.uid_to_queue.pop(uid, None)
  103. self.throughputs.pop(uid, None)
  104. if self.uid_to_queue.get(uid) != heap_entry:
  105. logger.debug(f"Skipping expert {uid} "
  106. f"(uid_to_queue={self.uid_to_queue.get(uid)}, entry={heap_entry})")
  107. continue # skip uids that are banned or expired
  108. break
  109. try:
  110. with self.throughputs[uid].update_threadsafe(task_size):
  111. logger.debug(f"Using expert {uid}, throughput = {self.throughputs[uid].samples_per_second}.")
  112. yield RemoteExpert(uid, maybe_endpoint.value)
  113. logger.debug(f"Finished using expert {uid}.")
  114. new_throughput = 1 / self.throughputs[uid].samples_per_second
  115. new_heap_entry = (True, new_throughput, random.random(), uid)
  116. with self.lock:
  117. heapq.heappush(self.queue, new_heap_entry)
  118. self.uid_to_queue[uid] = new_heap_entry
  119. except _InactiveRpcError as error:
  120. if error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
  121. # response was too slow, choose the next expert and mark this one as slow
  122. with self.lock:
  123. new_throughput = 1.5 * max(self.queue, key=itemgetter(1))[1]
  124. new_heap_entry = (True, new_throughput, random.random(), uid)
  125. heapq.heappush(self.queue, new_heap_entry)
  126. self.uid_to_queue[uid] = new_heap_entry
  127. raise
  128. else:
  129. self._ban_expert(uid)
  130. raise
  131. except BaseException:
  132. self._ban_expert(uid)
  133. raise
  134. def shutdown(self):
  135. self.is_alive.clear()
  136. self.update_finished.clear()
  137. self.update_trigger.set()
  138. self.update_finished.wait()