|
@@ -1,10 +1,9 @@
|
|
import heapq
|
|
import heapq
|
|
import random
|
|
import random
|
|
import threading
|
|
import threading
|
|
-import time
|
|
|
|
from contextlib import contextmanager
|
|
from contextlib import contextmanager
|
|
-from operator import itemgetter
|
|
|
|
from typing import Dict, List, Tuple
|
|
from typing import Dict, List, Tuple
|
|
|
|
+import time
|
|
|
|
|
|
import grpc
|
|
import grpc
|
|
from grpc._channel import _InactiveRpcError
|
|
from grpc._channel import _InactiveRpcError
|
|
@@ -17,21 +16,19 @@ from hivemind.utils import Endpoint, TimedStorage, DHTExpiration, ValueWithExpir
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
-QueueItem = Tuple[bool, float, float, ExpertUID]
|
|
|
|
-
|
|
|
|
|
|
|
|
class ExpertBalancer:
|
|
class ExpertBalancer:
|
|
def __init__(
|
|
def __init__(
|
|
- self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, sleep_timeout: float = 5.0, **kwargs
|
|
|
|
|
|
+ self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0,
|
|
|
|
+ sleep_timeout: float = 5.0, **kwargs
|
|
):
|
|
):
|
|
self.dht, self.key = dht, key
|
|
self.dht, self.key = dht, key
|
|
- self.ema_kwargs = kwargs
|
|
|
|
|
|
+ self.initial_throughput, self.ema_kwargs = initial_throughput, kwargs
|
|
self.experts = TimedStorage[ExpertUID, Endpoint]()
|
|
self.experts = TimedStorage[ExpertUID, Endpoint]()
|
|
self.blacklist = TimedStorage[ExpertUID, type(None)]()
|
|
self.blacklist = TimedStorage[ExpertUID, type(None)]()
|
|
self.throughputs: Dict[ExpertUID, PerformanceEMA] = {}
|
|
self.throughputs: Dict[ExpertUID, PerformanceEMA] = {}
|
|
- # had_forward_pass, throughput, random, uid
|
|
|
|
- self.queue: List[QueueItem] = []
|
|
|
|
- self.uid_to_queue: Dict[ExpertUID, QueueItem] = {}
|
|
|
|
|
|
+ self.queue: List[Tuple[float, float, ExpertUID]] = []
|
|
|
|
+ self.uid_to_queue: Dict[ExpertUID, Tuple[float, float, ExpertUID]] = {}
|
|
self.lock = threading.Lock()
|
|
self.lock = threading.Lock()
|
|
self.is_alive = threading.Event()
|
|
self.is_alive = threading.Event()
|
|
self.is_alive.set()
|
|
self.is_alive.set()
|
|
@@ -82,7 +79,9 @@ class ExpertBalancer:
|
|
if uid not in self.uid_to_queue:
|
|
if uid not in self.uid_to_queue:
|
|
logger.debug(f"Adding new expert: {uid}, expiration time = {expiration_time:.3f}.")
|
|
logger.debug(f"Adding new expert: {uid}, expiration time = {expiration_time:.3f}.")
|
|
self.throughputs[uid] = PerformanceEMA(**self.ema_kwargs, paused=True)
|
|
self.throughputs[uid] = PerformanceEMA(**self.ema_kwargs, paused=True)
|
|
- heap_entry = (False, 0.0, random.random(), uid)
|
|
|
|
|
|
+ # ensure that added experts are evaluated by placing them near the top of the queue
|
|
|
|
+ base_load = self.queue[0][0] if len(self.queue) > 0 else 0.0
|
|
|
|
+ heap_entry = (base_load, random.random(), uid)
|
|
heapq.heappush(self.queue, heap_entry)
|
|
heapq.heappush(self.queue, heap_entry)
|
|
self.uid_to_queue[uid] = heap_entry
|
|
self.uid_to_queue[uid] = heap_entry
|
|
else:
|
|
else:
|
|
@@ -109,7 +108,7 @@ class ExpertBalancer:
|
|
|
|
|
|
with self.lock:
|
|
with self.lock:
|
|
logger.debug(f"Getting a new expert, queue state: {self.queue}")
|
|
logger.debug(f"Getting a new expert, queue state: {self.queue}")
|
|
- *_, uid = heap_entry = heapq.heappop(self.queue)
|
|
|
|
|
|
+ current_runtime, _, uid = heap_entry = heapq.heappop(self.queue)
|
|
maybe_endpoint = self.experts.get(uid)
|
|
maybe_endpoint = self.experts.get(uid)
|
|
if maybe_endpoint is None:
|
|
if maybe_endpoint is None:
|
|
# remove expired expert from queue
|
|
# remove expired expert from queue
|
|
@@ -119,27 +118,23 @@ class ExpertBalancer:
|
|
logger.debug(f"Skipping expert {uid} "
|
|
logger.debug(f"Skipping expert {uid} "
|
|
f"(uid_to_queue={self.uid_to_queue.get(uid)}, entry={heap_entry})")
|
|
f"(uid_to_queue={self.uid_to_queue.get(uid)}, entry={heap_entry})")
|
|
continue # skip uids that are banned or expired
|
|
continue # skip uids that are banned or expired
|
|
|
|
+
|
|
|
|
+ if self.throughputs[uid].num_updates != 0:
|
|
|
|
+ expected_time_taken = task_size / self.throughputs[uid].samples_per_second
|
|
|
|
+ else:
|
|
|
|
+ expected_time_taken = self.initial_throughput * task_size
|
|
|
|
+ new_heap_entry = (current_runtime + expected_time_taken, random.random(), uid)
|
|
|
|
+ heapq.heappush(self.queue, new_heap_entry)
|
|
|
|
+ self.uid_to_queue[uid] = new_heap_entry
|
|
break
|
|
break
|
|
try:
|
|
try:
|
|
with self.throughputs[uid].update_threadsafe(task_size):
|
|
with self.throughputs[uid].update_threadsafe(task_size):
|
|
logger.debug(f"Using expert {uid}, throughput = {self.throughputs[uid].samples_per_second}.")
|
|
logger.debug(f"Using expert {uid}, throughput = {self.throughputs[uid].samples_per_second}.")
|
|
yield RemoteExpert(uid, maybe_endpoint.value)
|
|
yield RemoteExpert(uid, maybe_endpoint.value)
|
|
logger.debug(f"Finished using expert {uid}.")
|
|
logger.debug(f"Finished using expert {uid}.")
|
|
-
|
|
|
|
- new_throughput = 1 / self.throughputs[uid].samples_per_second
|
|
|
|
- new_heap_entry = (True, new_throughput, random.random(), uid)
|
|
|
|
- with self.lock:
|
|
|
|
- heapq.heappush(self.queue, new_heap_entry)
|
|
|
|
- self.uid_to_queue[uid] = new_heap_entry
|
|
|
|
except _InactiveRpcError as error:
|
|
except _InactiveRpcError as error:
|
|
if error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
|
|
if error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
|
|
- # response was too slow, choose the next expert and mark this one as slow
|
|
|
|
- with self.lock:
|
|
|
|
- new_throughput = max(self.queue, key=itemgetter(1))[1]
|
|
|
|
- new_heap_entry = (True, new_throughput, random.random(), uid)
|
|
|
|
- heapq.heappush(self.queue, new_heap_entry)
|
|
|
|
- self.uid_to_queue[uid] = new_heap_entry
|
|
|
|
-
|
|
|
|
|
|
+ # response was too slow, choose the next expert
|
|
raise
|
|
raise
|
|
else:
|
|
else:
|
|
self._ban_expert(uid)
|
|
self._ban_expert(uid)
|