Max Ryabinin před 3 roky
rodič
revize
eb06507018

+ 1 - 1
hivemind/moe/client/balancer.py

@@ -135,7 +135,7 @@ class ExpertBalancer:
             if error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                 # response was too slow, choose the next expert and mark this one as slow
                 with self.lock:
-                    new_throughput = 1.5 * max(self.queue, key=itemgetter(1))[1]
+                    new_throughput = 1 / self.throughputs[uid].samples_per_second
                     new_heap_entry = (True, new_throughput, random.random(), uid)
                     heapq.heappush(self.queue, new_heap_entry)
                     self.uid_to_queue[uid] = new_heap_entry

+ 8 - 6
hivemind/optim/performance_ema.py

@@ -54,9 +54,11 @@ class PerformanceEMA:
     def update_threadsafe(self, task_size: float):
         """measure the EMA throughput of the code that runs inside the context"""
         start_timestamp = get_dht_time()
-        yield
-        with self.lock:
-            self.update(task_size, interval=max(0.0, get_dht_time() - max(start_timestamp, self.timestamp)))
-            # note: we define interval as such to support two distinct scenarios:
-            # (1) if this is the first call to measure_threadsafe after a pause, count time from entering this context
-            # (2) if there are concurrent calls to measure_threadsafe, respect the timestamp updates from these calls
+        try:
+            yield
+        finally:
+            with self.lock:
+                self.update(task_size, interval=max(0.0, get_dht_time() - max(start_timestamp, self.timestamp)))
+                # note: we define interval as such to support two distinct scenarios:
+                # (1) if this is the first call to measure_threadsafe after a pause, count time from entering this context
+                # (2) if there are concurrent calls to measure_threadsafe, respect the timestamp updates from these calls