@@ -1,153 +1,60 @@
-from typing import Any, Dict, Optional, Tuple
+from contextlib import contextmanager
+from threading import Lock
+from typing import Optional
 
-import torch
-import torch.nn as nn
-from torch.autograd.function import once_differentiable
+from hivemind.utils import get_dht_time
 
-import hivemind
-from hivemind.moe.client.balancer import ExpertBalancer
-from hivemind.moe.client.expert import DUMMY
-from hivemind.proto import runtime_pb2
-from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.utils import get_logger, nested_compare, nested_flatten, nested_pack
 
-logger = get_logger(__name__)
-
-
-class BalancedRemoteExpert(nn.Module):
+class PerformanceEMA:
     """
-    A torch module that dynamically assigns weights to one RemoteExpert from a pool, proportionally to their throughput.
-    ToDo docstring, similar to RemoteMixtureOfExperts
+    A running estimate of performance (operations/sec) using adjusted exponential moving average
+    :param alpha: Smoothing factor in range [0, 1], [default: 0.1].
     """
 
-    def __init__(
-        self,
-        *,
-        dht: hivemind.DHT,
-        uid_prefix: str,
-        grid_size: Tuple[int, ...],
-        forward_timeout: Optional[float] = None,
-        backward_timeout: Optional[float] = None,
-        detect_anomalies: bool = False,
-        update_period: float = 30.0,
-        backward_task_size_multiplier: float = 2.5,
-        **kwargs,
-    ):
-        super().__init__()
-        if uid_prefix.endswith(".0."):
-            logger.warning(f"BalancedRemoteExperts will look for experts under prefix {self.uid_prefix}.0.")
-        assert len(grid_size) == 2 and grid_size[0] == 0, "only 1xN grids are supported"
-        self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
-        self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
-        self.backward_task_size_multiplier, self.detect_anomalies = backward_task_size_multiplier, detect_anomalies
-        self.expert_balancer = ExpertBalancer(dht, key=f"{self.uid_prefix}.0.", update_period=update_period, **kwargs)
-        self._expert_info = None  # expert['info'] from one of experts in the grid
+    def __init__(self, alpha: float = 0.1, eps: float = 1e-20, paused: bool = False):
+        self.alpha, self.eps, self.num_updates = alpha, eps, 0
+        self.ema_seconds_per_sample, self.samples_per_second = 0, eps
+        self.timestamp = get_dht_time()
+        self.paused = paused
+        self.lock = Lock()
 
-    def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
+    def update(self, task_size: float, interval: Optional[float] = None) -> float:
"""
|
|
|
- Call one of the RemoteExperts for the specified inputs and return output. Compatible with pytorch.autograd.
|
|
|
-
|
|
|
- :param args: input tensors that will be passed to each expert after input, batch-first
|
|
|
- :param kwargs: extra keyword tensors that will be passed to each expert, batch-first
|
|
|
- :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
|
|
|
+ :param task_size: how many items were processed since last call
|
|
|
+ :param interval: optionally provide the time delta it took to process this task
|
|
|
+        :returns: current estimate of performance (samples per second), but at most 1 / eps
"""
|
|
|
- assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
|
|
|
- kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
|
|
|
-
|
|
|
- if self._expert_info is None:
|
|
|
- raise NotImplementedError()
|
|
|
- # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
|
|
|
-
|
|
|
- forward_inputs = (args, kwargs)
|
|
|
-
|
|
|
- if not nested_compare(forward_inputs, self.info["forward_schema"]):
|
|
|
- raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
|
|
|
-
|
|
|
- flat_inputs = nested_flatten(forward_inputs)
|
|
|
- forward_task_size = flat_inputs[0].shape[0]
|
|
|
-
|
|
|
- # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
|
|
|
- flat_outputs = _BalancedRemoteModuleCall.apply(DUMMY,
|
|
|
- self.uid,
|
|
|
- self.expert_balancer,
|
|
|
- self.info,
|
|
|
- self.forward_timeout,
|
|
|
- self.backward_timeout,
|
|
|
- forward_task_size,
|
|
|
- forward_task_size * self.backward_task_size_multiplier,
|
|
|
- *flat_inputs)
|
|
|
-
|
|
|
- return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
|
|
|
-
|
|
|
- @property
|
|
|
- def info(self):
|
|
|
- while self._expert_info is None:
|
|
|
- try:
|
|
|
- with self.expert_balancer.use_another_expert(1) as chosen_expert:
|
|
|
- self._expert_info = chosen_expert.info
|
|
|
- except BaseException as e:
|
|
|
- logger.error(f"Tried to get expert info from {chosen_expert} but caught {e}")
|
|
|
- return self._expert_info
|
|
|
-
|
|
|
-
|
|
|
-class _BalancedRemoteModuleCall(torch.autograd.Function):
|
|
|
- """Internal autograd-friendly call of a remote module. For applications, use BalancedRemoteExpert instead."""
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def forward(
|
|
|
- ctx,
|
|
|
- dummy: torch.Tensor,
|
|
|
- expert_balancer: ExpertBalancer,
|
|
|
- info: Dict[str, Any],
|
|
|
- forward_timeout: float,
|
|
|
- backward_timeout: float,
|
|
|
- forward_task_size: float,
|
|
|
- backward_task_size: float,
|
|
|
- *inputs: torch.Tensor,
|
|
|
- ) -> Tuple[torch.Tensor, ...]:
|
|
|
- # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
|
|
|
- # detach to avoid pickling the computation graph
|
|
|
- ctx.expert_balancer, ctx.info = expert_balancer, info
|
|
|
- ctx.forward_timeout, ctx.backward_timeout = forward_timeout, backward_timeout
|
|
|
- ctx.forward_task_size, ctx.backward_task_size = forward_task_size, backward_task_size
|
|
|
- inputs = tuple(tensor.cpu().detach() for tensor in inputs)
|
|
|
- ctx.save_for_backward(*inputs)
|
|
|
-
|
|
|
- serialized_tensors = [
|
|
|
- serialize_torch_tensor(inp, proto.compression)
|
|
|
- for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
|
|
|
- ]
|
|
|
- forward_request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
|
|
|
- while True:
|
|
|
- try:
|
|
|
- with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
|
|
|
- outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
|
|
|
- break
|
|
|
- except BaseException as e:
|
|
|
- logger.error(f"Tried to call forward for expert {chosen_expert} but caught {e}")
|
|
|
- raise
|
|
|
-
|
|
|
- deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
|
|
|
- return tuple(deserialized_outputs)
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- @once_differentiable
|
|
|
- def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
|
|
|
- grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
|
|
|
- inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
|
|
|
- backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
|
|
|
- serialized_tensors = [
|
|
|
- serialize_torch_tensor(tensor, proto.compression)
|
|
|
- for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
|
|
|
- ]
|
|
|
- backward_request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
|
|
|
- while True:
|
|
|
- try:
|
|
|
- with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
|
|
|
- grad_inputs = chosen_expert.stub.forward(backward_request, timeout=ctx.backward_timeout)
|
|
|
- break
|
|
|
- except BaseException as e:
|
|
|
- logger.error(f"Tried to call backward for expert {chosen_expert} but caught {e}")
|
|
|
- raise
|
|
|
- deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
|
|
|
- return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)
|
|
|
+        assert task_size > 0, f"Can't register processing {task_size} samples"
+        assert not self.paused or interval is not None, "If PerformanceEMA is paused, one must specify time interval"
+        if not self.paused:
+            self.timestamp, old_timestamp = get_dht_time(), self.timestamp
+            interval = interval if interval is not None else max(0.0, self.timestamp - old_timestamp)
+        self.ema_seconds_per_sample = self.alpha * interval / task_size + (1 - self.alpha) * self.ema_seconds_per_sample
+        self.num_updates += 1
+        adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates)
+        self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
+        return self.samples_per_second
+
+    @contextmanager
+    def pause(self):
+        """While inside this context, EMA will not count the time passed towards the performance estimate"""
+        self.paused, was_paused = True, self.paused
+        try:
+            yield
+        finally:
+            self.timestamp = get_dht_time()
+            self.paused = was_paused
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"
+
+    @contextmanager
+    def update_threadsafe(self, task_size: float):
+ """measure the EMA throughout of the code that runs inside the context"""
|
|
|
+        start_timestamp = get_dht_time()
+        yield
+        with self.lock:
+            self.update(task_size, interval=max(0.0, get_dht_time() - max(start_timestamp, self.timestamp)))
+            # note: we define interval this way to support two distinct scenarios:
+            # (1) if this is the first call to update_threadsafe after a pause, count time from entering this context
+            # (2) if there are concurrent calls to update_threadsafe, respect the timestamp updates from those calls
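
For reference, a minimal usage sketch of the new PerformanceEMA (not part of the patch). It assumes the class is importable as hivemind.utils.performance_ema.PerformanceEMA, which this hunk does not show; the batch size and sleep durations are purely illustrative. Note that update() divides the running average by 1 - (1 - alpha) ** num_updates (Adam-style bias correction), so the zero-initialized EMA does not drag early estimates toward zero.

    import time

    from hivemind.utils.performance_ema import PerformanceEMA

    ema = PerformanceEMA(alpha=0.1)

    # measure throughput of the code inside the context (32 samples per iteration)
    for _ in range(10):
        with ema.update_threadsafe(task_size=32):
            time.sleep(0.05)  # stand-in for processing one batch of 32 samples
    print(ema)  # -> PerformanceEMA(ema=..., num_updates=10)

    # time spent inside pause() is excluded from the estimate
    with ema.pause():
        time.sleep(1.0)

    # explicit updates are also possible; interval defaults to the time since the previous update
    ema.update(task_size=32)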