Răsfoiți Sursa

ema with tests

justheuristic 4 ani în urmă
părinte
comite
24e1735a17

+ 1 - 1
hivemind/optim/collaborative.py

@@ -231,7 +231,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_steps_accumulated += 1
-            self.performance_ema.update(num_processed=batch_size)
+            self.performance_ema.update(task_size=batch_size)
             self.should_report_progress.set()
 
         if not self.collaboration_state.ready_for_step:

+ 24 - 9
hivemind/optim/performance_ema.py

@@ -1,4 +1,5 @@
 from contextlib import contextmanager
+from threading import Lock
 
 from hivemind.utils import get_dht_time
 
@@ -9,22 +10,25 @@ class PerformanceEMA:
     :param alpha: Smoothing factor in range [0, 1], [default: 0.1].
     """
 
-    def __init__(self, alpha: float = 0.1, eps: float = 1e-20):
+    def __init__(self, alpha: float = 0.1, eps: float = 1e-20, paused: bool = False):
         self.alpha, self.eps, self.num_updates = alpha, eps, 0
         self.ema_seconds_per_sample, self.samples_per_second = 0, eps
         self.timestamp = get_dht_time()
-        self.paused = False
+        self.paused = paused
+        self.lock = Lock()
 
-    def update(self, num_processed: int) -> float:
+    def update(self, task_size: int, interval: float = None) -> float:
         """
-        :param num_processed: how many items were processed since last call
+        :param task_size: how many items were processed since last call
+        :param interval: optionally provide the time delta it took to process this task
         :returns: current estimate of performance (samples per second), but at most 1 / eps
         """
-        assert not self.paused, "PerformanceEMA is currently paused"
-        assert num_processed > 0, f"Can't register processing {num_processed} samples"
-        self.timestamp, old_timestamp = get_dht_time(), self.timestamp
-        seconds_per_sample = max(0, self.timestamp - old_timestamp) / num_processed
-        self.ema_seconds_per_sample = self.alpha * seconds_per_sample + (1 - self.alpha) * self.ema_seconds_per_sample
+        assert not self.paused or interval is not None, "If PerformanceEMA is paused, one must specify time interval"
+        assert task_size > 0, f"Can't register processing {task_size} samples"
+        if not self.paused:
+            self.timestamp, old_timestamp = get_dht_time(), self.timestamp
+            interval = interval if interval is not None else max(0.0, self.timestamp - old_timestamp)
+        self.ema_seconds_per_sample = self.alpha * interval / task_size + (1 - self.alpha) * self.ema_seconds_per_sample
         self.num_updates += 1
         adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates)
         self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
@@ -39,3 +43,14 @@ class PerformanceEMA:
         finally:
             self.timestamp = get_dht_time()
             self.paused = was_paused
+
+    @contextmanager
+    def update_threadsafe(self, task_size: int):
+        """measure the EMA throughput of the code that runs inside the context"""
+        start_timestamp = get_dht_time()
+        yield
+        with self.lock:
+            self.update(task_size, interval=max(0.0, get_dht_time() - max(start_timestamp, self.timestamp)))
+            # note: we define interval as such to support two distinct scenarios:
+            # (1) if this is the first call to update_threadsafe after a pause, count time from entering this context
+            # (2) if there are concurrent calls to update_threadsafe, respect the timestamp updates from these calls

+ 29 - 0
tests/test_util_modules.py

@@ -3,12 +3,14 @@ import concurrent.futures
 import multiprocessing as mp
 import random
 import time
+from concurrent.futures import ThreadPoolExecutor
 
 import numpy as np
 import pytest
 import torch
 
 import hivemind
+from hivemind.optim.performance_ema import PerformanceEMA
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
@@ -571,3 +573,30 @@ async def test_cancel_and_wait():
     await asyncio.sleep(0.05)
     assert not await cancel_and_wait(task_with_result)
     assert not await cancel_and_wait(task_with_error)
+
+
+@pytest.mark.parametrize("max_workers", [1, 2, 10])
+def test_performance_ema_threadsafe(
+    interval: float = 0.01,
+    max_workers: int = 2,
+    num_updates: int = 100,
+    alpha: float = 0.05,
+    min_scale_power: float = 0.7,
+    max_scale: float = 1.05,
+):
+
+    def run_task(ema):
+        task_size = random.randint(1, 4)
+        with ema.update_threadsafe(task_size):
+            time.sleep(task_size * interval * (0.9 + 0.2 * random.random()))
+            return task_size
+
+    with ThreadPoolExecutor(max_workers) as pool:
+        ema = PerformanceEMA(alpha=alpha)
+        start_time = time.perf_counter()
+        futures = [pool.submit(run_task, ema) for i in range(num_updates)]
+        total_size = sum(future.result() for future in futures)
+        end_time = time.perf_counter()
+        target = total_size / (end_time - start_time)
+        assert ema.samples_per_second >= target * max_workers ** (min_scale_power - 1)
+        assert ema.samples_per_second <= target * max_scale