Quellcode durchsuchen

Revert "No 1.5 multiplier"

This reverts commit eb0650701859aca6ca63ad30c614da20732c7c4f.
Max Ryabinin vor 3 Jahren
Ursprung
Commit
e17610f8cf
2 geänderte Dateien mit 7 neuen und 9 gelöschten Zeilen
  1. 1 1
      hivemind/moe/client/balancer.py
  2. 6 8
      hivemind/optim/performance_ema.py

+ 1 - 1
hivemind/moe/client/balancer.py

@@ -135,7 +135,7 @@ class ExpertBalancer:
             if error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                 # response was too slow, choose the next expert and mark this one as slow
                 with self.lock:
-                    new_throughput = 1 / self.throughputs[uid].samples_per_second
+                    new_throughput = 1.5 * max(self.queue, key=itemgetter(1))[1]
                     new_heap_entry = (True, new_throughput, random.random(), uid)
                     heapq.heappush(self.queue, new_heap_entry)
                     self.uid_to_queue[uid] = new_heap_entry

+ 6 - 8
hivemind/optim/performance_ema.py

@@ -54,11 +54,9 @@ class PerformanceEMA:
     def update_threadsafe(self, task_size: float):
         """measure the EMA throughout of the code that runs inside the context"""
         start_timestamp = get_dht_time()
-        try:
-            yield
-        finally:
-            with self.lock:
-                self.update(task_size, interval=max(0.0, get_dht_time() - max(start_timestamp, self.timestamp)))
-                # note: we define interval as such to support two distinct scenarios:
-                # (1) if this is the first call to measure_threadsafe after a pause, count time from entering this context
-                # (2) if there are concurrent calls to measure_threadsafe, respect the timestamp updates from these calls
+        yield
+        with self.lock:
+            self.update(task_size, interval=max(0.0, get_dht_time() - max(start_timestamp, self.timestamp)))
+            # note: we define interval as such to support two distinct scenarios:
+            # (1) if this is the first call to measure_threadsafe after a pause, count time from entering this context
+            # (2) if there are concurrent calls to measure_threadsafe, respect the timestamp updates from these calls