Browse Source

port to newer version of hivemind

justheuristic 4 năm trước cách đây
mục cha
commit
1649274a76

+ 159 - 0
hivemind/moe/client/balanced_expert.py

@@ -0,0 +1,159 @@
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.autograd.function import once_differentiable
+
+import hivemind
+from hivemind.moe.client.balancer import ExpertBalancer
+from hivemind.moe.client.expert import DUMMY
+from hivemind.proto import runtime_pb2
+from hivemind.utils import (
+    deserialize_torch_tensor,
+    get_logger,
+    nested_compare,
+    nested_flatten,
+    nested_pack,
+    serialize_torch_tensor,
+)
+
+logger = get_logger(__name__)
+
+
+class BalancedRemoteExpert(nn.Module):
+    """
+    A torch module that dynamically assigns weights to one RemoteExpert from a pool, proportionally to their throughput.
+    ToDo docstring, similar to RemoteMixtureOfExperts
+    """
+
+    def __init__(
+        self,
+        *,
+        dht: hivemind.DHT,
+        uid_prefix: str,
+        grid_size: Tuple[int, ...],
+        forward_timeout: Optional[float] = None,
+        backward_timeout: Optional[float] = None,
+        detect_anomalies: bool = False,
+        update_period: float = 30.0,
+        backward_task_size_multiplier: float = 2.5,
+        **kwargs,
+    ):
+        super().__init__()
+        if uid_prefix.endswith(".0."):
+            logger.warning(f"BalancedRemoteExperts will look for experts under prefix {self.uid_prefix}.0.")
+        assert len(grid_size) == 2 and grid_size[0] == 0, "only 1xN grids are supported"
+        self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
+        self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
+        self.backward_task_size_multiplier, self.detect_anomalies = backward_task_size_multiplier, detect_anomalies
+        self.expert_balancer = ExpertBalancer(dht, key=f"{self.uid_prefix}.0.", update_period=update_period, **kwargs)
+        self._expert_info = None  # expert['info'] from one of experts in the grid
+
+    def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
+        """
+        Call one of the RemoteExperts for the specified inputs and return output. Compatible with pytorch.autograd.
+
+        :param args: input tensors that will be passed to each expert after input, batch-first
+        :param kwargs: extra keyword tensors that will be passed to each expert, batch-first
+        :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
+        """
+        assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
+        kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
+
+        if self._expert_info is None:
+            raise NotImplementedError()
+        # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
+
+        forward_inputs = (args, kwargs)
+
+        if not nested_compare(forward_inputs, self.info["forward_schema"]):
+            raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
+
+        flat_inputs = nested_flatten(forward_inputs)
+        forward_task_size = flat_inputs[0].shape[0]
+
+        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
+        flat_outputs = _BalancedRemoteModuleCall.apply(DUMMY,
+                                                       self.uid,
+                                                       self.expert_balancer,
+                                                       self.info,
+                                                       self.forward_timeout,
+                                                       self.backward_timeout,
+                                                       forward_task_size,
+                                                       forward_task_size * self.backward_task_size_multiplier,
+                                                       *flat_inputs)
+
+        return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
+
+    @property
+    def info(self):
+        while self._expert_info is None:
+            try:
+                with self.expert_balancer.use_another_expert(1) as chosen_expert:
+                    self._expert_info = chosen_expert.info
+            except BaseException as e:
+                logger.error(f"Tried to get expert info from {chosen_expert} but caught {e}")
+        return self._expert_info
+
+
+class _BalancedRemoteModuleCall(torch.autograd.Function):
+    """Internal autograd-friendly call of a remote module. For applications, use BalancedRemoteExpert instead."""
+
+    @staticmethod
+    def forward(
+            ctx,
+            dummy: torch.Tensor,
+            expert_balancer: ExpertBalancer,
+            info: Dict[str, Any],
+            forward_timeout: float,
+            backward_timeout: float,
+            forward_task_size: float,
+            backward_task_size: float,
+            *inputs: torch.Tensor,
+            ) -> Tuple[torch.Tensor, ...]:
+        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
+        # detach to avoid pickling the computation graph
+        ctx.expert_balancer, ctx.info = expert_balancer, info
+        ctx.forward_timeout, ctx.backward_timeout = forward_timeout, backward_timeout
+        ctx.forward_task_size, ctx.backward_task_size = forward_task_size, backward_task_size
+        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
+        ctx.save_for_backward(*inputs)
+
+        serialized_tensors = [
+            serialize_torch_tensor(inp, proto.compression)
+            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
+        ]
+        forward_request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
+        while True:
+            try:
+                with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
+                    outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
+                break
+            except BaseException as e:
+                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {e}")
+                raise
+
+        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
+        return tuple(deserialized_outputs)
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
+        grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
+        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
+        backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
+        serialized_tensors = [
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        ]
+        backward_request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
+        while True:
+            try:
+                with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
+                    grad_inputs = chosen_expert.stub.forward(backward_request, timeout=ctx.backward_timeout)
+                break
+            except BaseException as e:
+                logger.error(f"Tried to call backward for expert {chosen_expert} but caught {e}")
+                raise
+        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
+        return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)

+ 128 - 0
hivemind/moe/client/balancer.py

@@ -0,0 +1,128 @@
+import heapq
+import random
+import threading
+from contextlib import contextmanager
+from typing import Dict, List, Tuple
+
+from hivemind import Endpoint, RemoteExpert, TimedStorage
+from hivemind.dht import DHT
+from hivemind.moe.server.expert_uid import ExpertPrefix, ExpertUID
+from hivemind.optim.performance_ema import PerformanceEMA
+from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
+
+logger = get_logger(__name__)
+
+
+class ExpertBalancer:
+    def __init__(self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0,
+                 **kwargs):
+        self.dht, self.key = dht, key
+        self.initial_throughput, self.ema_kwargs = initial_throughput, kwargs
+        self.experts = TimedStorage[ExpertUID, Endpoint]()
+        self.blacklist = TimedStorage[ExpertUID, type(None)]()
+        self.throughputs: Dict[ExpertUID, PerformanceEMA] = {}
+        self.queue: List[Tuple[float, float, ExpertUID]] = []
+        self.uid_to_queue: Dict[ExpertUID, Tuple[float, float, ExpertUID]] = {}
+        self.lock = threading.Lock()
+        self.is_alive = threading.Event()
+        self.is_alive.set()
+        self.update_trigger, self.update_finished = threading.Event(), threading.Event()
+        self.update_period, self.last_update = update_period, get_dht_time()
+        self.update_thread = threading.Thread(target=self.update_experts_in_background, daemon=True)
+        self.update_thread.start()
+
+    def update_experts_in_background(self):
+        while self.is_alive.is_set():
+            time_to_next_update = max(0.0, self.last_update + self.update_period - get_dht_time())
+            try:
+                self.update_trigger.wait(timeout=time_to_next_update)
+                # update triggered by main thread
+            except TimeoutError:
+                pass  # update triggered by refresh_period
+
+            self.update_trigger.clear()
+            response = self.dht.get(self.key, latest=True)
+            if isinstance(response, ValueWithExpiration) and isinstance(response.value, dict):
+                for index, expert_info in response.value.items():
+                    try:
+                        (uid, endpoint), expiration_time = expert_info
+
+                        maybe_banned = self.blacklist.get(uid)
+                        if maybe_banned is None or expiration_time > maybe_banned.expiration_time:
+                            self._add_expert(uid, endpoint, expiration_time)
+                        else:
+                            logger.debug(f"Not adding expert {uid} (blacklisted).")
+                    except Exception as e:
+                        logger.warning(f"Skipping malformed expert info {expert_info} (exc={e})")
+            else:
+                logger.warning(f"Could not refresh experts, dht info key contains {response}, "
+                               f"will retry in {time_to_next_update}s")
+            if len(self.queue) == 0:
+                logger.warning("Update routine finished, but still no experts available.")
+
+            self.last_update = get_dht_time()
+            self.update_finished.set()
+
+    def _add_expert(self, uid: ExpertUID, endpoint: Endpoint, expiration_time: DHTExpiration):
+        with self.lock:
+            self.experts.store(uid, endpoint, expiration_time)
+            if uid not in self.uid_to_queue:
+                logger.debug(f"Adding new expert: {uid}, expiration time = {expiration_time:.3f}.")
+                self.throughputs[uid] = PerformanceEMA(*self.ema_kwargs, paused=True)
+                base_load = self.queue[0][0] if len(self.queue) > 0 else 0.0
+                heap_entry = (base_load, random.random(), uid)
+                heapq.heappush(self.queue, heap_entry)
+                self.uid_to_queue[uid] = heap_entry
+            else:
+                logger.debug(f"Refreshing existing expert: {uid}, new expiration time = {expiration_time:.3f}.")
+
+    def _ban_expert(self, uid: ExpertUID):
+        with self.lock:
+            maybe_expert = self.experts.get(uid)
+            expiration_time = maybe_expert.expiration_time if maybe_expert else get_dht_time()
+            self.blacklist.store(uid, None, expiration_time)
+            self.uid_to_queue.pop(uid, None)
+            self.throughputs.pop(uid, None)
+            del self.experts[uid]
+            logger.debug(f"Banned expert {uid} with expiration time = {expiration_time:.2f}.")
+
+    @contextmanager
+    def use_another_expert(self, task_size: float) -> RemoteExpert:
+        while True:
+            if len(self.queue) == 0:
+                self.update_finished.clear()
+                self.update_trigger.set()
+                self.update_finished.wait()
+                continue
+
+            with self.lock:
+                current_runtime, _, uid = heap_entry = heapq.heappop(self.queue)
+                maybe_endpoint = self.experts.get(uid)
+                if maybe_endpoint is None:
+                    # remove expired expert from queue
+                    self.uid_to_queue.pop(uid, None)
+                    self.throughputs.pop(uid, None)
+                if self.uid_to_queue.get(uid) != heap_entry:
+                    continue  # skip uids that are banned or expired
+
+                if self.throughputs[uid].num_updates != 0:
+                    expected_time_taken = task_size / self.throughputs[uid].samples_per_second
+                else:
+                    expected_time_taken = self.initial_throughput * task_size
+                new_heap_entry = (current_runtime + expected_time_taken, random.random(), uid)
+                heapq.heappush(self.queue, new_heap_entry)
+                self.uid_to_queue[uid] = new_heap_entry
+                break
+        try:
+            with self.throughputs[uid].update_threadsafe(task_size):
+                logger.debug(f"Using expert {uid}, throughput = {self.throughputs[uid].samples_per_second}.")
+                yield RemoteExpert(uid, maybe_endpoint.value)
+        except BaseException:
+            self._ban_expert(uid)
+            raise
+
+    def shutdown(self):
+        self.is_alive.clear()
+        self.update_finished.clear()
+        self.update_trigger.set()
+        self.update_finished.wait()

+ 1 - 1
hivemind/optim/collaborative.py

@@ -231,7 +231,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_steps_accumulated += 1
-            self.performance_ema.update(num_processed=batch_size)
+            self.performance_ema.update(task_size=batch_size)
             self.should_report_progress.set()
 
         if not self.collaboration_state.ready_for_step:

+ 144 - 32
hivemind/optim/performance_ema.py

@@ -1,41 +1,153 @@
-from contextlib import contextmanager
+from typing import Any, Dict, Optional, Tuple
 
-from hivemind.utils import get_dht_time
+import torch
+import torch.nn as nn
+from torch.autograd.function import once_differentiable
 
+import hivemind
+from hivemind.moe.client.balancer import ExpertBalancer
+from hivemind.moe.client.expert import DUMMY
+from hivemind.proto import runtime_pb2
+from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils import get_logger, nested_compare, nested_flatten, nested_pack
 
-class PerformanceEMA:
+logger = get_logger(__name__)
+
+
+class BalancedRemoteExpert(nn.Module):
     """
-    A running estimate of performance (operations/sec) using adjusted exponential moving average
-    :param alpha: Smoothing factor in range [0, 1], [default: 0.1].
+    A torch module that dynamically assigns weights to one RemoteExpert from a pool, proportionally to their throughput.
+    ToDo docstring, similar to RemoteMixtureOfExperts
     """
 
-    def __init__(self, alpha: float = 0.1, eps: float = 1e-20):
-        self.alpha, self.eps, self.num_updates = alpha, eps, 0
-        self.ema_seconds_per_sample, self.samples_per_second = 0, eps
-        self.timestamp = get_dht_time()
-        self.paused = False
+    def __init__(
+        self,
+        *,
+        dht: hivemind.DHT,
+        uid_prefix: str,
+        grid_size: Tuple[int, ...],
+        forward_timeout: Optional[float] = None,
+        backward_timeout: Optional[float] = None,
+        detect_anomalies: bool = False,
+        update_period: float = 30.0,
+        backward_task_size_multiplier: float = 2.5,
+        **kwargs,
+    ):
+        super().__init__()
+        if uid_prefix.endswith(".0."):
+            logger.warning(f"BalancedRemoteExperts will look for experts under prefix {self.uid_prefix}.0.")
+        assert len(grid_size) == 2 and grid_size[0] == 0, "only 1xN grids are supported"
+        self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
+        self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
+        self.backward_task_size_multiplier, self.detect_anomalies = backward_task_size_multiplier, detect_anomalies
+        self.expert_balancer = ExpertBalancer(dht, key=f"{self.uid_prefix}.0.", update_period=update_period, **kwargs)
+        self._expert_info = None  # expert['info'] from one of experts in the grid
 
-    def update(self, num_processed: int) -> float:
+    def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
-        :param num_processed: how many items were processed since last call
-        :returns: current estimate of performance (samples per second), but at most
+        Call one of the RemoteExperts for the specified inputs and return output. Compatible with pytorch.autograd.
+
+        :param args: input tensors that will be passed to each expert after input, batch-first
+        :param kwargs: extra keyword tensors that will be passed to each expert, batch-first
+        :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
         """
-        assert not self.paused, "PerformanceEMA is currently paused"
-        assert num_processed > 0, f"Can't register processing {num_processed} samples"
-        self.timestamp, old_timestamp = get_dht_time(), self.timestamp
-        seconds_per_sample = max(0, self.timestamp - old_timestamp) / num_processed
-        self.ema_seconds_per_sample = self.alpha * seconds_per_sample + (1 - self.alpha) * self.ema_seconds_per_sample
-        self.num_updates += 1
-        adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates)
-        self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
-        return self.samples_per_second
-
-    @contextmanager
-    def pause(self):
-        """While inside this context, EMA will not count the time passed towards the performance estimate"""
-        self.paused, was_paused = True, self.paused
-        try:
-            yield
-        finally:
-            self.timestamp = get_dht_time()
-            self.paused = was_paused
+        assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
+        kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
+
+        if self._expert_info is None:
+            raise NotImplementedError()
+        # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
+
+        forward_inputs = (args, kwargs)
+
+        if not nested_compare(forward_inputs, self.info["forward_schema"]):
+            raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
+
+        flat_inputs = nested_flatten(forward_inputs)
+        forward_task_size = flat_inputs[0].shape[0]
+
+        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
+        flat_outputs = _BalancedRemoteModuleCall.apply(DUMMY,
+                                                       self.uid,
+                                                       self.expert_balancer,
+                                                       self.info,
+                                                       self.forward_timeout,
+                                                       self.backward_timeout,
+                                                       forward_task_size,
+                                                       forward_task_size * self.backward_task_size_multiplier,
+                                                       *flat_inputs)
+
+        return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
+
+    @property
+    def info(self):
+        while self._expert_info is None:
+            try:
+                with self.expert_balancer.use_another_expert(1) as chosen_expert:
+                    self._expert_info = chosen_expert.info
+            except BaseException as e:
+                logger.error(f"Tried to get expert info from {chosen_expert} but caught {e}")
+        return self._expert_info
+
+
+class _BalancedRemoteModuleCall(torch.autograd.Function):
+    """Internal autograd-friendly call of a remote module. For applications, use BalancedRemoteExpert instead."""
+
+    @staticmethod
+    def forward(
+            ctx,
+            dummy: torch.Tensor,
+            expert_balancer: ExpertBalancer,
+            info: Dict[str, Any],
+            forward_timeout: float,
+            backward_timeout: float,
+            forward_task_size: float,
+            backward_task_size: float,
+            *inputs: torch.Tensor,
+            ) -> Tuple[torch.Tensor, ...]:
+        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
+        # detach to avoid pickling the computation graph
+        ctx.expert_balancer, ctx.info = expert_balancer, info
+        ctx.forward_timeout, ctx.backward_timeout = forward_timeout, backward_timeout
+        ctx.forward_task_size, ctx.backward_task_size = forward_task_size, backward_task_size
+        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
+        ctx.save_for_backward(*inputs)
+
+        serialized_tensors = [
+            serialize_torch_tensor(inp, proto.compression)
+            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
+        ]
+        forward_request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
+        while True:
+            try:
+                with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
+                    outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
+                break
+            except BaseException as e:
+                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {e}")
+                raise
+
+        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
+        return tuple(deserialized_outputs)
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
+        grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
+        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
+        backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
+        serialized_tensors = [
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        ]
+        backward_request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
+        while True:
+            try:
+                with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
+                    grad_inputs = chosen_expert.stub.forward(backward_request, timeout=ctx.backward_timeout)
+                break
+            except BaseException as e:
+                logger.error(f"Tried to call backward for expert {chosen_expert} but caught {e}")
+                raise
+        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
+        return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)

+ 7 - 3
hivemind/utils/timed_storage.py

@@ -113,6 +113,12 @@ class TimedStorage(Generic[KeyType, ValueType]):
         self.key_to_heap.clear()
         self.expiration_heap.clear()
 
+    def discard(self, key: KeyType):
+        """If storage contains key, drop the corresponding item, otherwise do nothing"""
+        if key in self.key_to_heap:
+            del self.data[key], self.key_to_heap[key]
+        # note: key may still be in self.expiration_heap, but it will not be used and eventually ._remove_outdated()
+
     def __contains__(self, key: KeyType):
         self._remove_outdated()
         return key in self.data
@@ -122,9 +128,7 @@ class TimedStorage(Generic[KeyType, ValueType]):
         return len(self.data)
 
     def __delitem__(self, key: KeyType):
-        if key in self.key_to_heap:
-            del self.data[key], self.key_to_heap[key]
-        # note: key may still be in self.expiration_heap, but it will not be used and eventually ._remove_outdated()
+        self.discard(key)
 
     def __bool__(self):
         return bool(self.data)