performance_ema.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. from contextlib import contextmanager
  2. from hivemind.utils import get_dht_time
  3. class PerformanceEMA:
  4. """
  5. A running estimate of performance (operations/sec) using adjusted exponential moving average
  6. :param alpha: Smoothing factor in range [0, 1], [default: 0.1].
  7. """
  8. def __init__(self, alpha: float = 0.1, eps: float = 1e-20):
  9. self.alpha, self.eps, self.num_updates = alpha, eps, 0
  10. self.ema_seconds_per_sample, self.samples_per_second = 0, eps
  11. self.timestamp = get_dht_time()
  12. self.paused = False
  13. def update(self, num_processed: int) -> float:
  14. """
  15. :param num_processed: how many items were processed since last call
  16. :returns: current estimate of performance (samples per second), but at most
  17. """
  18. assert not self.paused, "PerformanceEMA is currently paused"
  19. assert num_processed > 0, f"Can't register processing {num_processed} samples"
  20. self.timestamp, old_timestamp = get_dht_time(), self.timestamp
  21. seconds_per_sample = max(0, self.timestamp - old_timestamp) / num_processed
  22. self.ema_seconds_per_sample = self.alpha * seconds_per_sample + (1 - self.alpha) * self.ema_seconds_per_sample
  23. self.num_updates += 1
  24. adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates)
  25. self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
  26. return self.samples_per_second
  27. @contextmanager
  28. def pause(self):
  29. """While inside this context, EMA will not count the time passed towards the performance estimate"""
  30. self.paused, was_paused = True, self.paused
  31. try:
  32. yield
  33. finally:
  34. self.timestamp = get_dht_time()
  35. self.paused = was_paused