backport PerformanceEMA from server_side_averaging (#397)

This PR adds thread-safe performance measurement to PerformanceEMA, previously introduced in #365.
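
For reference, a minimal usage sketch of the backported API (PerformanceEMA, update, update_threadsafe, and samples_per_second all come from the diff below; the sleep stands in for a hypothetical workload):

    import time
    from hivemind.optim.performance_ema import PerformanceEMA

    ema = PerformanceEMA(alpha=0.1)

    # single-threaded path: the interval is inferred from the time since the last call
    ema.update(task_size=32)

    # multi-threaded path: each thread wraps its workload in the context manager;
    # the elapsed interval is registered under an internal lock on exit
    with ema.update_threadsafe(task_size=32):
        time.sleep(0.1)  # stand-in for a real training step

    print(ema.samples_per_second)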

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 3 years ago
commit 025e095d55

+ 1 - 1
hivemind/optim/collaborative.py

@@ -263,7 +263,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_updates_accumulated += 1
-            self.performance_ema.update(num_processed=batch_size)
+            self.performance_ema.update(task_size=batch_size)
             self.should_report_progress.set()
 
         if not self.collaboration_state.ready_for_step:

+ 38 - 13
hivemind/optim/performance_ema.py

@@ -1,6 +1,7 @@
+import time
 from contextlib import contextmanager
-
-from hivemind.utils import get_dht_time
+from threading import Lock
+from typing import Optional
 
 
 class PerformanceEMA:
@@ -9,22 +10,28 @@ class PerformanceEMA:
     :param alpha: Smoothing factor in range [0, 1], [default: 0.1].
     """
 
-    def __init__(self, alpha: float = 0.1, eps: float = 1e-20):
+    def __init__(self, alpha: float = 0.1, eps: float = 1e-20, paused: bool = False):
         self.alpha, self.eps, self.num_updates = alpha, eps, 0
         self.ema_seconds_per_sample, self.samples_per_second = 0, eps
-        self.timestamp = get_dht_time()
-        self.paused = False
+        self.timestamp = time.perf_counter()
+        self.paused = paused
+        self.lock = Lock()
 
-    def update(self, num_processed: int) -> float:
+    def update(self, task_size: float, interval: Optional[float] = None) -> float:
         """
-        :param num_processed: how many items were processed since last call
+        :param task_size: how many items were processed since last call
+        :param interval: optionally provide the time delta it took to process this task
         :returns: current estimate of performance (samples per second), but at most
         """
-        assert not self.paused, "PerformanceEMA is currently paused"
-        assert num_processed > 0, f"Can't register processing {num_processed} samples"
-        self.timestamp, old_timestamp = get_dht_time(), self.timestamp
-        seconds_per_sample = max(0, self.timestamp - old_timestamp) / num_processed
-        self.ema_seconds_per_sample = self.alpha * seconds_per_sample + (1 - self.alpha) * self.ema_seconds_per_sample
+        assert task_size > 0, f"Can't register processing {task_size} samples"
+        if not self.paused:
+            self.timestamp, old_timestamp = time.perf_counter(), self.timestamp
+            interval = interval if interval is not None else self.timestamp - old_timestamp
+        else:
+            assert interval is not None, "If PerformanceEMA is paused, please specify the time interval"
+        self.ema_seconds_per_sample = (
+            self.alpha * interval / task_size + (1 - self.alpha) * self.ema_seconds_per_sample
+        )
         self.num_updates += 1
         adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates)
         self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
@@ -37,5 +44,23 @@ class PerformanceEMA:
         try:
             yield
         finally:
-            self.timestamp = get_dht_time()
+            self.timestamp = time.perf_counter()
             self.paused = was_paused
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"
+
+    @contextmanager
+    def update_threadsafe(self, task_size: float):
+        """
+        Update the EMA throughput of the code that runs inside the context manager; supports multiple concurrent threads.
+
+        :param task_size: how many items were processed since last call
+        """
+        start_timestamp = time.perf_counter()
+        yield
+        with self.lock:
+            self.update(task_size, interval=time.perf_counter() - max(start_timestamp, self.timestamp))
+            # note: we define interval as such to support two distinct scenarios:
+            # (1) if this is the first call to update_threadsafe after a pause, count time from entering this context
+            # (2) if there are concurrent calls to update_threadsafe, respect the timestamp updates from these calls
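
One detail worth noting in update above: ema_seconds_per_sample starts at zero, so the raw EMA is biased toward zero for the first few calls, and the 1 - (1 - alpha) ** num_updates divisor removes that bias (the same zero-debiasing used in Adam). A quick numeric sketch with an assumed constant measurement of 2.0 seconds per sample:

    # zero-debiasing sketch: an EMA initialized at 0 underestimates until it warms up;
    # dividing by 1 - (1 - alpha) ** num_updates recovers the unbiased value
    alpha = 0.1
    ema = 0.0
    for _ in range(3):                      # three identical measurements
        ema = alpha * 2.0 + (1 - alpha) * ema
    print(ema)                              # 0.542 -- biased toward the zero init
    print(ema / (1 - (1 - alpha) ** 3))     # 2.0   -- debiased estimate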

+ 28 - 0
tests/test_util_modules.py

@@ -3,6 +3,7 @@ import concurrent.futures
 import multiprocessing as mp
 import random
 import time
+from concurrent.futures import ThreadPoolExecutor
 
 import numpy as np
 import pytest
@@ -10,6 +11,7 @@ import torch
 
 import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.optim.performance_ema import PerformanceEMA
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
@@ -549,3 +551,29 @@ def test_batch_tensor_descriptor_msgpack():
         and tensor_descr.pin_memory == tensor_descr.pin_memory
         and tensor_descr.compression == tensor_descr.compression
     )
+
+
+@pytest.mark.parametrize("max_workers", [1, 2, 10])
+def test_performance_ema_threadsafe(
+    max_workers: int,
+    interval: float = 0.01,
+    num_updates: int = 100,
+    alpha: float = 0.05,
+    bias_power: float = 0.7,
+    tolerance: float = 0.05,
+):
+    def run_task(ema):
+        task_size = random.randint(1, 4)
+        with ema.update_threadsafe(task_size):
+            time.sleep(task_size * interval * (0.9 + 0.2 * random.random()))
+            return task_size
+
+    with ThreadPoolExecutor(max_workers) as pool:
+        ema = PerformanceEMA(alpha=alpha)
+        start_time = time.perf_counter()
+        futures = [pool.submit(run_task, ema) for i in range(num_updates)]
+        total_size = sum(future.result() for future in futures)
+        end_time = time.perf_counter()
+        target = total_size / (end_time - start_time)
+        assert ema.samples_per_second >= (1 - tolerance) * target * max_workers ** (bias_power - 1)
+        assert ema.samples_per_second <= (1 + tolerance) * target
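
A note on the final assertions: each update_threadsafe interval overlaps with the intervals measured by other threads, so the EMA tends to undercount aggregate throughput as max_workers grows; the max_workers ** (bias_power - 1) factor appears to be an empirical allowance for this. A quick sketch of the resulting acceptance window, using the test's default parameters at max_workers=10:

    # acceptance window for ema.samples_per_second relative to the measured target
    max_workers, bias_power, tolerance = 10, 0.7, 0.05
    lower = (1 - tolerance) * max_workers ** (bias_power - 1)
    upper = 1 + tolerance
    print(f"{lower:.3f}x .. {upper:.2f}x")  # ~0.476x .. 1.05x of target throughput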