justheuristic 4 lat temu
rodzic
commit
c52e99ecc2
2 zmienionych plików z 79 dodań i 144 usunięć
  1. 50 144
      hivemind/optim/performance_ema.py
  2. 29 0
      tests/test_util_modules.py

+ 50 - 144
hivemind/optim/performance_ema.py

@@ -1,153 +1,59 @@
-from typing import Any, Dict, Optional, Tuple
+from contextlib import contextmanager
+from threading import Lock
 
-import torch
-import torch.nn as nn
-from torch.autograd.function import once_differentiable
+from hivemind.utils import get_dht_time
 
-import hivemind
-from hivemind.moe.client.balancer import ExpertBalancer
-from hivemind.moe.client.expert import DUMMY
-from hivemind.proto import runtime_pb2
-from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.utils import get_logger, nested_compare, nested_flatten, nested_pack
 
-logger = get_logger(__name__)
-
-
-class BalancedRemoteExpert(nn.Module):
+class PerformanceEMA:
     """
-    A torch module that dynamically assigns weights to one RemoteExpert from a pool, proportionally to their throughput.
-    ToDo docstring, similar to RemoteMixtureOfExperts
+    A running estimate of performance (operations/sec) using adjusted exponential moving average
+    :param alpha: Smoothing factor in range [0, 1], [default: 0.1].
     """
 
-    def __init__(
-        self,
-        *,
-        dht: hivemind.DHT,
-        uid_prefix: str,
-        grid_size: Tuple[int, ...],
-        forward_timeout: Optional[float] = None,
-        backward_timeout: Optional[float] = None,
-        detect_anomalies: bool = False,
-        update_period: float = 30.0,
-        backward_task_size_multiplier: float = 2.5,
-        **kwargs,
-    ):
-        super().__init__()
-        if uid_prefix.endswith(".0."):
-            logger.warning(f"BalancedRemoteExperts will look for experts under prefix {self.uid_prefix}.0.")
-        assert len(grid_size) == 2 and grid_size[0] == 0, "only 1xN grids are supported"
-        self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
-        self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
-        self.backward_task_size_multiplier, self.detect_anomalies = backward_task_size_multiplier, detect_anomalies
-        self.expert_balancer = ExpertBalancer(dht, key=f"{self.uid_prefix}.0.", update_period=update_period, **kwargs)
-        self._expert_info = None  # expert['info'] from one of experts in the grid
+    def __init__(self, alpha: float = 0.1, eps: float = 1e-20, paused: bool = False):
+        self.alpha, self.eps, self.num_updates = alpha, eps, 0
+        self.ema_seconds_per_sample, self.samples_per_second = 0, eps
+        self.timestamp = get_dht_time()
+        self.paused = paused
+        self.lock = Lock()
 
-    def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
+    def update(self, task_size: float, interval: float) -> float:
         """
-        Call one of the RemoteExperts for the specified inputs and return output. Compatible with pytorch.autograd.
-
-        :param args: input tensors that will be passed to each expert after input, batch-first
-        :param kwargs: extra keyword tensors that will be passed to each expert, batch-first
-        :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
+        :param task_size: how many items were processed since last call
+        :param interval: optionally provide the time delta it took to process this task
+        :returns: current estimate of performance (samples per second), but at most
         """
-        assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
-        kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
-
-        if self._expert_info is None:
-            raise NotImplementedError()
-        # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
-
-        forward_inputs = (args, kwargs)
-
-        if not nested_compare(forward_inputs, self.info["forward_schema"]):
-            raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
-
-        flat_inputs = nested_flatten(forward_inputs)
-        forward_task_size = flat_inputs[0].shape[0]
-
-        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
-        flat_outputs = _BalancedRemoteModuleCall.apply(DUMMY,
-                                                       self.uid,
-                                                       self.expert_balancer,
-                                                       self.info,
-                                                       self.forward_timeout,
-                                                       self.backward_timeout,
-                                                       forward_task_size,
-                                                       forward_task_size * self.backward_task_size_multiplier,
-                                                       *flat_inputs)
-
-        return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
-
-    @property
-    def info(self):
-        while self._expert_info is None:
-            try:
-                with self.expert_balancer.use_another_expert(1) as chosen_expert:
-                    self._expert_info = chosen_expert.info
-            except BaseException as e:
-                logger.error(f"Tried to get expert info from {chosen_expert} but caught {e}")
-        return self._expert_info
-
-
-class _BalancedRemoteModuleCall(torch.autograd.Function):
-    """Internal autograd-friendly call of a remote module. For applications, use BalancedRemoteExpert instead."""
-
-    @staticmethod
-    def forward(
-            ctx,
-            dummy: torch.Tensor,
-            expert_balancer: ExpertBalancer,
-            info: Dict[str, Any],
-            forward_timeout: float,
-            backward_timeout: float,
-            forward_task_size: float,
-            backward_task_size: float,
-            *inputs: torch.Tensor,
-            ) -> Tuple[torch.Tensor, ...]:
-        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
-        # detach to avoid pickling the computation graph
-        ctx.expert_balancer, ctx.info = expert_balancer, info
-        ctx.forward_timeout, ctx.backward_timeout = forward_timeout, backward_timeout
-        ctx.forward_task_size, ctx.backward_task_size = forward_task_size, backward_task_size
-        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
-        ctx.save_for_backward(*inputs)
-
-        serialized_tensors = [
-            serialize_torch_tensor(inp, proto.compression)
-            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
-        ]
-        forward_request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
-        while True:
-            try:
-                with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
-                    outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
-                break
-            except BaseException as e:
-                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {e}")
-                raise
-
-        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
-        return tuple(deserialized_outputs)
-
-    @staticmethod
-    @once_differentiable
-    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
-        grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
-        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
-        backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-        serialized_tensors = [
-            serialize_torch_tensor(tensor, proto.compression)
-            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        ]
-        backward_request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
-        while True:
-            try:
-                with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
-                    grad_inputs = chosen_expert.stub.forward(backward_request, timeout=ctx.backward_timeout)
-                break
-            except BaseException as e:
-                logger.error(f"Tried to call backward for expert {chosen_expert} but caught {e}")
-                raise
-        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
-        return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)
+        assert task_size > 0, f"Can't register processing {task_size} samples"
+        assert not self.paused or interval is not None, "If PerformanceEMA is paused, one must specify time interval"
+        if not self.paused:
+            self.timestamp, old_timestamp = get_dht_time(), self.timestamp
+            interval = interval if interval is not None else max(0.0, self.timestamp - old_timestamp)
+        self.ema_seconds_per_sample = self.alpha * interval / task_size + (1 - self.alpha) * self.ema_seconds_per_sample
+        self.num_updates += 1
+        adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates)
+        self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
+        return self.samples_per_second
+
+    @contextmanager
+    def pause(self):
+        """While inside this context, EMA will not count the time passed towards the performance estimate"""
+        self.paused, was_paused = True, self.paused
+        try:
+            yield
+        finally:
+            self.timestamp = get_dht_time()
+            self.paused = was_paused
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"
+
+    @contextmanager
+    def update_threadsafe(self, task_size: float):
+        """measure the EMA throughout of the code that runs inside the context"""
+        start_timestamp = get_dht_time()
+        yield
+        with self.lock:
+            self.update(task_size, interval=max(0.0, get_dht_time() - max(start_timestamp, self.timestamp)))
+            # note: we define interval as such to support two distinct scenarios:
+            # (1) if this is the first call to measure_threadsafe after a pause, count time from entering this context
+            # (2) if there are concurrent calls to measure_threadsafe, respect the timestamp updates from these calls

+ 29 - 0
tests/test_util_modules.py

@@ -3,6 +3,7 @@ import concurrent.futures
 import multiprocessing as mp
 import random
 import time
+from concurrent.futures import ThreadPoolExecutor
 
 import numpy as np
 import pytest
@@ -14,6 +15,7 @@ from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
+from hivemind.optim.performance_ema import PerformanceEMA
 from hivemind.utils.asyncio import (
     achain,
     aenumerate,
@@ -521,3 +523,30 @@ async def test_cancel_and_wait():
     await asyncio.sleep(0.05)
     assert not await cancel_and_wait(task_with_result)
     assert not await cancel_and_wait(task_with_error)
+
+
+@pytest.mark.parametrize("max_workers", [1, 2, 10])
+def test_performance_ema_threadsafe(
+    max_workers: int = 2,
+    interval: float = 0.01,
+    num_updates: int = 100,
+    alpha: float = 0.05,
+    bias_power: float = 0.7,
+    tolerance: float = 0.05,
+):
+
+    def run_task(ema):
+        task_size = random.randint(1, 4)
+        with ema.update_threadsafe(task_size):
+            time.sleep(task_size * interval * (0.9 + 0.2 * random.random()))
+            return task_size
+
+    with ThreadPoolExecutor(max_workers) as pool:
+        ema = PerformanceEMA(alpha=alpha)
+        start_time = time.perf_counter()
+        futures = [pool.submit(run_task, ema) for i in range(num_updates)]
+        total_size = sum(future.result() for future in futures)
+        end_time = time.perf_counter()
+        target = total_size / (end_time - start_time)
+        assert ema.samples_per_second >= (1 - tolerance) * target * max_workers ** (bias_power - 1)
+        assert ema.samples_per_second <= (1 + tolerance) * target