
Add Switch Transformers-like RemoteMixtureOfExperts (#228)

* Add RemoteSwitchMixtureOfExperts

* Add load balancing and test for training

* Make grid_size non-optional

* Support passing *args to batch_size

* Reformat tests/custom_networks.py

* Reformat test_custom_expert.py

* Add DelayedNopExpert

* Add allow_zero_outputs to RemoteMixtureOfExperts

* Fix exception handling in DHT, log tracebacks in async_run_coroutine

* Generate BatchTensorDescriptor from dummy inputs on server start
This ensures that outputs_schema contains correct attributes such as dtype

* Handle exceptions on start of background_server

* Add test_no_experts
Max Ryabinin 4 years ago
parent
commit
62652e1717

+ 4 - 0
docs/modules/client.rst

@@ -18,6 +18,10 @@
    :members:
    :member-order: bysource
 
+.. autoclass:: RemoteSwitchMixtureOfExperts
+   :members:
+   :member-order: bysource
+
 .. autoclass:: DecentralizedAverager
    :members:
    :member-order: bysource

+ 1 - 0
hivemind/client/__init__.py

@@ -1,4 +1,5 @@
 from hivemind.client.expert import RemoteExpert
 from hivemind.client.moe import RemoteMixtureOfExperts
+from hivemind.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.client.averaging import DecentralizedAverager
 from hivemind.client.averaging.training import TrainingAverager

+ 3 - 2
hivemind/client/beam_search.py

@@ -63,7 +63,7 @@ class MoEBeamSearcher:
          Though, this is a pathological case (e.g. only 90 experts in an oversized 100x100 grid) that should be avoided.
     """
 
-    def __init__(self, dht: DHT, uid_prefix: ExpertPrefix, grid_size: Optional[Tuple[int, ...]] = None,
+    def __init__(self, dht: DHT, uid_prefix: ExpertPrefix, grid_size: Tuple[int, ...],
                  num_workers: Optional[int] = None, negative_caching: bool = True, **kwargs):
         if not uid_prefix.endswith(UID_DELIMITER):
             uid_prefix += UID_DELIMITER
@@ -71,6 +71,7 @@ class MoEBeamSearcher:
         assert is_valid_prefix(uid_prefix), f"Prefix '{uid_prefix}' is invalid."
         self.dht = dht
         self.uid_prefix, self.grid_size = uid_prefix, grid_size
+        self.total_grid_size = sum(grid_size)
         self.negative_caching, self.num_workers, self.dht_kwargs = negative_caching, num_workers, kwargs
 
     def get_initial_beam(self, scores: Sequence[float], beam_size: int, return_future: bool = False
@@ -174,7 +175,7 @@ class MoEBeamSearcher:
         :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
-        assert (not self.grid_size or len(grid_scores) == len(self.grid_size)) and beam_size > 0
+        assert len(grid_scores) == len(self.grid_size) and beam_size > 0
        return self.dht.run_coroutine(partial(self._find_best_experts, prefix=self.uid_prefix, beam_size=beam_size,
                                              grid_scores=list(grid_scores), negative_caching=self.negative_caching,
                                              num_workers=self.num_workers), return_future)

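For reference, a minimal sketch of how MoEBeamSearcher is constructed after this change (grid_size is now mandatory); the 'ffn.' prefix and the 2x4 grid below are illustrative, not part of the diff:

import hivemind
from hivemind.client.beam_search import MoEBeamSearcher

dht = hivemind.DHT(start=True)  # a single local DHT peer, for illustration only
searcher = MoEBeamSearcher(dht, uid_prefix='ffn.', grid_size=(2, 4))
assert searcher.total_grid_size == 6  # sum of grid dimensions, used to size the gating projection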
+ 56 - 38
hivemind/client/moe.py

@@ -14,7 +14,7 @@ from hivemind.client.beam_search import MoEBeamSearcher
 from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.server.expert_uid import UID_DELIMITER
-from hivemind.utils import nested_pack, nested_flatten
+from hivemind.utils import nested_pack, nested_flatten, nested_map
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
@@ -23,7 +23,7 @@ logger = get_logger(__name__)
 
 class RemoteMixtureOfExperts(nn.Module):
     """
-    A torch module that performs mixture of experts inference with a local gating function and multiple remote experts.
+    A torch module that performs Mixture-of-Experts inference with a local gating function and multiple remote experts.
     Natively supports pytorch autograd.
 
     :note: By default, not all experts are guaranteed to perform forward pass. Moreover, not all of those who ran
@@ -38,14 +38,15 @@ class RemoteMixtureOfExperts(nn.Module):
     :param k_best: average this many highest-scoring experts to compute activations
     :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
     :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
-    :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
      Any expert that didn't manage to return output after that delay is considered unavailable
+    :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
+    :param allow_zero_outputs: whether to return zeros if no experts respond on forward pass
     """
 
     def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, uid_prefix: str, k_best: int,
                  k_min: int = 1, forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
                  backward_k_min: int = 1, backward_timeout: Optional[float] = None, detect_anomalies: bool = False,
-                 **dht_kwargs):
+                 allow_zero_outputs: bool = False, **dht_kwargs):
         super().__init__()
         self.dht = dht
         self.beam_search = MoEBeamSearcher(dht, uid_prefix, grid_size, **dht_kwargs)
@@ -53,8 +54,10 @@ class RemoteMixtureOfExperts(nn.Module):
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
         self.timeout_after_k_min = timeout_after_k_min
         self.detect_anomalies = detect_anomalies
+        self.allow_zero_outputs = allow_zero_outputs
 
-        self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
+        # jointly predict logits for all grid dimensions
+        self.proj = nn.Linear(in_features, self.beam_search.total_grid_size)
         self._expert_info = None  # expert['info'] from one of experts in the grid
 
     def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
@@ -87,7 +90,8 @@ class RemoteMixtureOfExperts(nn.Module):
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
             DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
-            self.backward_timeout, self.detect_anomalies, self.info, *nested_flatten(((input, *args), kwargs)))
+            self.backward_timeout, self.detect_anomalies, self.allow_zero_outputs, self.info,
+            *nested_flatten(((input, *args), kwargs)))
         # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
 
         expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
@@ -97,6 +101,7 @@ class RemoteMixtureOfExperts(nn.Module):
         averaged_outputs_flat = [
             (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
             for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
+
         return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
 
     def compute_expert_scores(
@@ -152,13 +157,14 @@ class _RemoteCallMany(torch.autograd.Function):
     one expert succeeds for each input. For user-friendly version of this function, use RemoteMixtureOfExperts module.
 
     Note: experts that failed during forward will be assigned zero outputs and marked as mask[i, j] = 0,
-          experts that failed during backward will be treated as constants (i.e. gradients of through them are zeros)
+          experts that failed during backward will be treated as constants (i.e. gradients through them are zeros)
     """
 
     @classmethod
     def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
                 timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
-                detect_anomalies: bool, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+                detect_anomalies: bool, allow_zero_outputs: bool, info: Dict[str, Any],
+                *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
         assert not torch.is_grad_enabled()
         num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
 
@@ -181,32 +187,42 @@ class _RemoteCallMany(torch.autograd.Function):
                 new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
                 pending_tasks[new_task] = (i, j)
 
-        alive_grid_indices, alive_flat_outputs = cls._collect_responses(
+        responded_inds, alive_flat_outputs = cls._collect_responses(
             pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
-        if len(alive_grid_indices) == 0:
-            raise TimeoutError("Forward pass: no alive experts responded within timeout.")
+        if len(responded_inds) < k_min:
+            raise TimeoutError(f"Forward pass: less than {k_min} responded within timeout.")
+
+        if not isinstance(info['outputs_schema'], tuple):
+            outputs_schema = (info['outputs_schema'],)
+        else:
+            outputs_schema = info['outputs_schema']
+        outputs = nested_map(
+            lambda descriptor: descriptor.make_empty(num_samples, max_experts, device=flat_inputs[0].device).zero_(),
+            outputs_schema)
 
         # assemble responses
-        alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
-        mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
-        mask[alive_ii, alive_jj] = True
+        if len(responded_inds) > 0 or allow_zero_outputs:
+            batch_inds, expert_inds = map(lambda x: torch.as_tensor(x, device=flat_inputs[0].device, dtype=torch.long),
+                                          list(zip(*responded_inds)) or ([], []))
 
-        alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
-        # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
+            alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
+            # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
 
-        outputs = []
-        for response_stacked in alive_flat_outputs_stacked:
-            output = torch.zeros(
-                [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
-                dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
-            output[alive_ii, alive_jj] = response_stacked
-            outputs.append(output.to(flat_inputs[0].device))
+            for output, response_stacked in zip(outputs, alive_flat_outputs_stacked):
+                output[batch_inds, expert_inds] = response_stacked.to(output.device)
+
+        else:
+            raise RuntimeError('Forward pass: 0 experts responded within timeout and allow_zero_outputs is False')
+
+        mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
+        mask[batch_inds, expert_inds] = True
 
         # save individual outputs for backward pass
-        ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
+        ctx.save_for_backward(batch_inds, expert_inds, *flat_inputs_cpu)
         ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
                                   detect_anomalies)
-        return (mask,) + tuple(outputs)
+
+        return (mask,) + outputs
 
     @classmethod
     @once_differentiable
@@ -235,35 +251,37 @@ class _RemoteCallMany(torch.autograd.Function):
         for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
                                                     inputs_per_expert, grad_outputs_per_expert):
             expert = expert_per_sample[i.item()][j.item()]
-            stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
+            stub = _get_expert_stub(expert.endpoint)
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
             tensors_serialized = [serialize_torch_tensor(tensor, proto.compression)
                                   for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
             new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
             pending_tasks[new_task] = (i, j)
 
-        backward_survivor_indices, survivor_grad_inputs = cls._collect_responses(
+        survivor_inds, survivor_grad_inputs = cls._collect_responses(
             pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
-        if len(backward_survivor_indices) == 0:
-            raise TimeoutError("Backward pass: no alive experts responded within timeout.")
+        if len(survivor_inds) < backward_k_min:
+            raise TimeoutError(f"Backward pass: less than {backward_k_min} experts responded within timeout.")
 
         # assemble responses
-        backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices) or ([], []))
+        batch_inds, expert_inds = map(lambda x: torch.as_tensor(x, dtype=torch.long),
+                                      list(zip(*survivor_inds)) or ([], []))
 
         survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
         # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]
 
-        grad_inputs = []
-        for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
+        grad_inputs = nested_map(
+            lambda descr: descr.make_empty(num_samples, device=flat_grad_outputs[0].device).zero_(),
+            list(nested_flatten(info['forward_schema'])))
+
+        for grad_input, survivor_grad_stacked in zip(grad_inputs, survivor_grad_inputs_stacked):
             grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
-                (num_samples, max_experts, *flat_inputs_cpu[i].shape[1:]),
+                (num_samples, max_experts, *grad_input.shape[1:]),
                 device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
-            grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked
-
-            # sum gradients from each expert
-            grad_inputs.append(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))
+            grad_input_per_expert[batch_inds, expert_inds] = survivor_grad_stacked
+            grad_input.copy_(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))
 
-        return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
+        return (DUMMY, None, None, None, None, None, None, None, None, None, *grad_inputs)
 
     @staticmethod
     def _collect_responses(task_to_indices: Dict[grpc.Future, Tuple[int, int]], num_samples: int, k_min: int,

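Roughly how the new allow_zero_outputs flag is used on the client side, following the constructor above; the peer address and grid layout are placeholders and assume a reachable server that hosts matching 'ffn.*' experts:

import torch
import hivemind

dht = hivemind.DHT(initial_peers=["127.0.0.1:1337"], start=True)  # placeholder peer address
dmoe = hivemind.RemoteMixtureOfExperts(
    in_features=16, grid_size=(4, 4, 4), dht=dht, uid_prefix='ffn.', k_best=3,
    allow_zero_outputs=True)  # with the default False, forward raises if no expert responds in time

out = dmoe(torch.randn(10, 16))  # averaged expert outputs, shape [10, 16] for 'ffn' experts
out.sum().backward()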
+ 175 - 0
hivemind/client/switch_moe.py

@@ -0,0 +1,175 @@
+from __future__ import annotations
+
+from typing import Tuple, List
+
+import grpc
+import torch
+
+from hivemind.client.expert import RemoteExpert, DUMMY
+from hivemind.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
+from hivemind.server.expert_uid import UID_DELIMITER
+from hivemind.utils import nested_pack, nested_flatten
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
+    """
+    A module implementing Switch Transformers [1] Mixture-of-Experts inference with remote experts.
+
+    [1] Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.
+     William Fedus, Barret Zoph, Noam Shazeer. https://arxiv.org/abs/2101.03961
+
+    :note: By default, not all experts are guaranteed to perform forward pass. Moreover, not all of those who ran
+     forward pass are guaranteed to perform backward pass. In the latter case, gradient will be averaged without
+     the missing experts
+
+    :param in_features: common input size for experts and gating function
+    :param grid_size: dimensions that form expert uid (see below)
+    :param uid_prefix: common prefix for all expert uids (must end with '.')
+    :note: expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
+    :param dht: a DHT instance used to search for best experts
+    :param k_best: average this many highest-scoring experts to compute activations
+    :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
+    :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
+     Any expert that didn't manage to return output after that delay is considered unavailable
+    :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
+    :param allow_zero_outputs: whether to return just the input if no experts respond on forward pass
+    """
+
+    def __init__(self, *, grid_size: Tuple[int, ...], utilization_alpha: float = 0.9, grid_dropout: float = 1.0,
+                 jitter_eps: float = 1e-2, k_best=1, k_min=0, backward_k_min=0, allow_zero_outputs=True, **kwargs):
+        super().__init__(grid_size=grid_size, k_best=k_best, k_min=k_min, backward_k_min=backward_k_min,
+                         allow_zero_outputs=allow_zero_outputs, **kwargs)
+
+        initial_utilization = torch.cat(
+            [torch.tensor([1 / dim_size for _ in range(dim_size)], dtype=torch.float)
+             for dim_size in grid_size],
+        )
+        self.register_buffer('grid_utilization', initial_utilization)
+        self.utilization_alpha = utilization_alpha
+        self.grid_dropout = grid_dropout
+        self.jitter_eps = jitter_eps
+
+    def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
+        if input.ndim != 2:
+            input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
+        else:
+            input_for_gating = input
+
+        # Multiplicative jitter for regularized routing
+        jitter_noise = torch.empty_like(input_for_gating).uniform_(1 - self.jitter_eps, 1 + self.jitter_eps)
+        input_for_gating *= jitter_noise
+
+        # Compute scores, find most appropriate experts with beam search
+        grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)
+
+        grid_dropout_masks = (
+            (torch.rand(size=(dim_size,), dtype=input_for_gating.dtype, device=input_for_gating.device)
+             < self.grid_dropout) for dim_size in self.beam_search.grid_size
+        )
+        grid_scores_dropout = [torch.where(dropout_mask, grid_score,
+                                           torch.full((1,), float('-inf'), device=grid_score.device,
+                                                      dtype=grid_score.dtype))
+                               for grid_score, dropout_mask in zip(grid_scores, grid_dropout_masks)]
+
+        grid_softmax = [torch.softmax(grid_score, dim=-1) for grid_score in grid_scores_dropout]
+        chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
+            [scores.detach().cpu() for scores in grid_scores_dropout], self.k_best)
+
+        if self._expert_info is None:
+            try:
+                self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
+            except grpc.RpcError as e:
+                logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
+
+        expert_mask, *expert_outputs = _RemoteCallMany.apply(
+            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
+            self.backward_timeout, self.detect_anomalies, self.allow_zero_outputs, self.info,
+            *nested_flatten(((input, *args), kwargs)))
+        # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
+
+        batch_utilization = self._compute_batch_utilization(chosen_experts, expert_mask)
+        self.grid_utilization = \
+            self.utilization_alpha * self.grid_utilization + (1 - self.utilization_alpha) * batch_utilization
+
+        # compute expert probabilities as product across grid dimensions
+        expert_probs = self.compute_expert_scores(grid_softmax, chosen_experts)
+        masked_logits = torch.full((1,), float('-inf'), device=expert_probs.device, dtype=expert_probs.dtype)
+        expert_probs = torch.where(expert_mask, expert_probs, masked_logits)
+
+        # multiply outputs by expert probabilities
+        averaged_outputs_flat = [
+            (expert_probs[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
+            for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
+
+        packed_outputs = nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
+
+        # Load balancing loss: multiply fractions of probability mass and fractions of routed examples
+        # for each grid dimension, sum across all indices for a dimension. Optimizing this leads to uniform allocation
+        balancing_loss = torch.stack([torch.mean(dim_softmax.mean(0) * dim_utilization) * (dim_size ** 2)
+                                      for dim_softmax, dim_utilization, dim_size in
+                                      zip(grid_softmax, self.grid_utilization, self.beam_search.grid_size)]).sum()
+
+        # residual connection
+        if isinstance(packed_outputs, torch.Tensor):
+            packed_outputs = packed_outputs + input
+        else:
+            packed_outputs[0] = packed_outputs[0] + input
+
+        return packed_outputs, balancing_loss
+
+    @torch.no_grad()
+    def _compute_batch_utilization(self, batch_experts, expert_mask):
+        batch_utilization = [torch.zeros((dim_size,), dtype=self.grid_utilization.dtype,
+                                         device=self.grid_utilization.device)
+                             for dim_size in self.beam_search.grid_size]
+
+        # out of chosen_experts, select those for which expert_mask is True
+        for (sample_idx, expert_idx) in expert_mask.nonzero().numpy():
+            expert = batch_experts[sample_idx][expert_idx]
+            expert_indices = expert.uid[len(self.beam_search.uid_prefix):]
+            expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
+
+            for dim_index, dim_utilization in zip(expert_indices, batch_utilization):
+                dim_utilization[dim_index] += 1
+
+        return torch.cat([
+            torch.nn.functional.normalize(dim_utilization, p=1, dim=0)
+            for dim_utilization in batch_utilization
+        ])
+
+    def compute_expert_scores(
+            self, grid_probs: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
+        """
+        Compute scores for each expert by multiplying grid probabilities, autograd-friendly
+        :param grid_probs: list of torch tensors, i-th tensor contains scores for i-th grid dimension
+        :param batch_experts: list(batch) of lists(k) of up to k experts selected for this batch
+        :returns: a tensor of scores, float32[batch_size, k]
+        :note: if some rows in batch have less than max number of experts, their scores will be padded with -inf
+        """
+        expert_counts = list(map(len, batch_experts))
+        batch_size = len(batch_experts)
+        max_num_experts = max(expert_counts)
+        total_num_experts = sum(expert_counts)
+        expert_index_in_batch = torch.arange(total_num_experts, device=grid_probs[0].device)
+        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_probs[0].device), dim=-1)[:-1]
+        flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
+        flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
+        flat_experts = [expert for row in batch_experts for expert in row]
+
+        grid_indices = torch.zeros([len(flat_experts), len(grid_probs)], dtype=torch.int64)
+        for i, expert in enumerate(flat_experts):
+            expert_indices = expert.uid[len(self.beam_search.uid_prefix):]
+            expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
+            grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
+
+        scores_per_dim = [
+            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
+            for dim_scores, dim_indices in zip(grid_probs, grid_indices.T)]
+        flat_scores = torch.prod(torch.stack(scores_per_dim, dim=0), dim=0)
+
+        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_probs[0].device)
+        scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
+        return scores

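A short usage sketch of the module added above: RemoteSwitchMixtureOfExperts returns both the routed output (with a residual connection to the input) and a load-balancing loss that is meant to be added to the task loss, as in test_switch_training further down; the 0.01 weight, the single-dimension grid and the classifier head are illustrative assumptions:

import torch
import torch.nn.functional as F
import hivemind

dht = hivemind.DHT(initial_peers=["127.0.0.1:1337"], start=True)  # placeholder peer address
switch_moe = hivemind.RemoteSwitchMixtureOfExperts(
    in_features=64, grid_size=(4,), dht=dht, uid_prefix='expert.', k_best=1, k_min=1)
head = torch.nn.Linear(64, 2)  # illustrative classifier on top of the MoE layer

x, targets = torch.randn(8, 64), torch.randint(0, 2, (8,))
output, balancing_loss = switch_moe(x)  # output already includes the residual connection
loss = F.cross_entropy(head(output), targets) + 0.01 * balancing_loss
loss.backward()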
+ 3 - 2
hivemind/dht/__init__.py

@@ -91,7 +91,7 @@ class DHT(mp.Process):
         """
         self.start()
         if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
+            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
 
     def shutdown(self) -> None:
         """ Shut down a running dht process """
@@ -186,6 +186,7 @@ class DHT(mp.Process):
             else:
                 future.set_result(await main_task)
         except BaseException as e:
+            logger.exception(f'Caught an exception when running a coroutine: {e}')
             if not future.done():
                 future.set_exception(e)
 
@@ -243,7 +244,7 @@ class DHT(mp.Process):
                                             f" Please ensure the node is connected or specify peers=... manually."))
 
     def declare_experts(self, uids, endpoint, wait: bool = True):
-        logger.warning("dht.declare_experts is scheduled for removal in 0.9.8, please use hivemind.declare_experts.",)
+        logger.warning("dht.declare_experts is scheduled for removal in 0.9.8, please use hivemind.declare_experts.")
         return hivemind.declare_experts(self, uids, endpoint, wait=wait)
 
     def get_experts(self, uids, expiration_time: Optional[DHTExpiration] = None,

+ 20 - 11
hivemind/server/__init__.py

@@ -21,7 +21,7 @@ from hivemind.server.layers import name_to_block, name_to_input
 from hivemind.server.layers import add_custom_models_from_file, schedule_name_to_scheduler
 from hivemind.server.runtime import Runtime
 from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
-from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
+from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger, BatchTensorDescriptor
 from hivemind.proto.runtime_pb2 import CompressionType
 
 logger = get_logger(__name__)
@@ -153,11 +153,11 @@ class Server(threading.Thread):
         optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
         device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
 
-        sample_input = name_to_input[expert_cls](4, hidden_dim)
+        sample_input = name_to_input[expert_cls](3, hidden_dim)
         if isinstance(sample_input, tuple):
-            args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
+            args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
         else:
-            args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)
+            args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)
 
         scheduler = schedule_name_to_scheduler[scheduler]
 
@@ -167,8 +167,6 @@
             expert = name_to_block[expert_cls](hidden_dim)
             experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert,
                                                          args_schema=args_schema,
-                                                         outputs_schema=hivemind.BatchTensorDescriptor(
-                                                             hidden_dim, compression=compression),
                                                          optimizer=optim_cls(expert.parameters()),
                                                          scheduler=scheduler,
                                                          num_warmup_steps=num_warmup_steps,
@@ -264,11 +262,15 @@ def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.End
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     pipe, runners_pipe = mp.Pipe(duplex=True)
     runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
-
     try:
         runner.start()
-        yield pipe.recv()  # once the server is ready, runner will send us a tuple(hostname, port, dht port)
-        pipe.send('SHUTDOWN')  # on exit from context, send shutdown signal
+        # once the server is ready, runner will send us either (False, exception) or (True, (server_port, dht_port))
+        start_ok, data = pipe.recv()
+        if start_ok:
+            yield data
+            pipe.send('SHUTDOWN')  # on exit from context, send shutdown signal
+        else:
+            raise RuntimeError(f"Server failed to start: {data}")
     finally:
         runner.join(timeout=shutdown_timeout)
         if runner.is_alive():
@@ -278,14 +280,21 @@ def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.End
 
 
 def _server_runner(pipe, *args, **kwargs):
-    server = Server.create(*args, start=True, **kwargs)
+    try:
+        server = Server.create(*args, start=True, **kwargs)
+    except Exception as e:
+        logger.exception(f"Encountered an exception when starting a server: {e}")
+        pipe.send((False, f'{type(e).__name__} {e}'))
+        return
+
     try:
         if server.dht is not None:
             dht_listen_on = hivemind.replace_port(server.dht.listen_on, server.dht.port)
         else:
             dht_listen_on = None
-        pipe.send((server.listen_on, dht_listen_on))
+        pipe.send((True, (server.listen_on, dht_listen_on)))
         pipe.recv()  # wait for shutdown signal
+
     finally:
         logger.info("Shutting down server...")
         server.shutdown()

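A small sketch of the reworked start-up handshake from the caller's side: the runner now sends (start_ok, payload), and a failure inside Server.create() surfaces as a RuntimeError instead of a hang; the keyword arguments below are illustrative:

from hivemind import background_server

try:
    with background_server(expert_cls='ffn', num_experts=1, hidden_dim=16, num_handlers=1,
                           no_dht=True) as (server_endpoint, dht_endpoint):
        print('server is listening on', server_endpoint)  # dht_endpoint is None with no_dht=True
except RuntimeError as exc:
    # raised by background_server when the child process fails inside Server.create()
    print('server failed to start:', exc)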
+ 2 - 5
hivemind/server/layers/__init__.py

@@ -1,12 +1,9 @@
-import torch
-
 name_to_block = {}
 name_to_input = {}
 
-from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup
-from hivemind.server.layers.custom_experts import add_custom_models_from_file, register_expert_class
-
 import hivemind.server.layers.common
 import hivemind.server.layers.dropout
+from hivemind.server.layers.custom_experts import add_custom_models_from_file, register_expert_class
+from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup
 
 schedule_name_to_scheduler = {'linear': get_linear_schedule_with_warmup, 'none': None}

+ 22 - 1
hivemind/server/layers/common.py

@@ -1,3 +1,5 @@
+import time
+
 import torch
 from torch import nn as nn
 
@@ -11,6 +13,8 @@ def gelu_fast(x):
 
 
 ffn_sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))
+
+
 @register_expert_class('ffn', ffn_sample_input)
 class FeedforwardBlock(nn.Module):
 
@@ -65,7 +69,9 @@ class TransformerEncoderLayer(nn.Module):
 
 transformer_sample_input = lambda batch_size, hid_dim: \
     (torch.empty((batch_size, 128, hid_dim)), \
-    torch.empty((batch_size, 128), dtype=torch.bool))
+     torch.empty((batch_size, 128), dtype=torch.bool))
+
+
 @register_expert_class('transformer', transformer_sample_input)
 class TunedTransformer(TransformerEncoderLayer):
 
@@ -74,6 +80,8 @@ class TunedTransformer(TransformerEncoderLayer):
 
 
 nop_sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))
+
+
 @register_expert_class('nop', nop_sample_input)
 class NopExpert(nn.Sequential):
 
@@ -83,3 +91,16 @@
 
     def forward(self, x):
         return x.clone()
+
+
+@register_expert_class('nop_delay', nop_sample_input)
+class DelayedNopExpert(nn.Sequential):
+
+    def __init__(self, hid_dim, delay=0.5):
+        super().__init__()
+        self.w = nn.Parameter(torch.zeros(0), requires_grad=True)
+        self.delay = delay
+
+    def forward(self, x):
+        time.sleep(self.delay)
+        return x.clone()

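For context, the same decorator pattern used above is how downstream code can add its own server-side experts; below is a hypothetical 'double' expert, where the sample-input lambda is what the server now feeds through the module on start to derive outputs_schema (including dtype):

import torch
from torch import nn
from hivemind.server.layers.custom_experts import register_expert_class

double_sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))


@register_expert_class('double', double_sample_input)
class DoublingExpert(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(hid_dim))

    def forward(self, x):
        return 2 * self.scale * x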
+ 3 - 3
hivemind/utils/tensor_descr.py

@@ -46,7 +46,7 @@ class TensorDescriptor(DescriptorBase):
 
 @dataclass(repr=True, frozen=True)
 class BatchTensorDescriptor(TensorDescriptor):
-    """ torch Tensor with a variable 0-th dimension, used to describe batched data """
+    """ torch.Tensor with a variable 0-th dimension, used to describe batched data """
 
     def __init__(self, *instance_size, **kwargs):  # compatibility: allow initializing with *size
         if len(instance_size) == 1 and isinstance(instance_size[0], (list, tuple, torch.Size)):
@@ -60,9 +60,9 @@ class BatchTensorDescriptor(TensorDescriptor):
                    pin_memory=safe_check_pinned(tensor),
                    compression=compression if tensor.is_floating_point() else CompressionType.NONE)
 
-    def make_empty(self, batch_size, **kwargs):
+    def make_empty(self, *batch_size, **kwargs):
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
-        return super().make_empty(size=(batch_size, *self.shape[1:]), **kwargs)
+        return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
 
 
 def safe_check_pinned(tensor: torch.Tensor) -> bool:

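A small sketch of the make_empty change above: the descriptor now accepts several leading dimensions, which is how _RemoteCallMany pre-allocates [num_samples, max_experts, ...] output buffers; the sizes below are arbitrary:

import torch
from hivemind.utils.tensor_descr import BatchTensorDescriptor

descr = BatchTensorDescriptor(16)                      # variable 0-th dimension, 16 features per sample
single = descr.make_empty(4, device='cpu').zero_()     # previous single-batch-dim usage still works
padded = descr.make_empty(4, 3, device='cpu').zero_()  # new: e.g. [num_samples=4, max_experts=3, 16]
assert single.shape == (4, 16) and padded.shape == (4, 3, 16)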
+ 11 - 6
tests/custom_networks.py

@@ -4,11 +4,13 @@ import torch.nn.functional as F
 
 from hivemind.server.layers.custom_experts import register_expert_class
 
-sample_input = lambda batch_size, hidden_dim : torch.empty((batch_size, hidden_dim))
+sample_input = lambda batch_size, hidden_dim: torch.empty((batch_size, hidden_dim))
+
+
 @register_expert_class('perceptron', sample_input)
 class MultilayerPerceptron(nn.Module):
     def __init__(self, hidden_dim, num_classes=10):
-        super(MultilayerPerceptron, self).__init__()
+        super().__init__()
         self.layer1 = nn.Linear(hidden_dim, 2 * hidden_dim)
         self.layer2 = nn.Linear(2 * hidden_dim, 2 * hidden_dim)
         self.layer3 = nn.Linear(2 * hidden_dim, num_classes)
@@ -19,14 +21,17 @@ class MultilayerPerceptron(nn.Module):
         x = self.layer3(x)
         return x
 
-multihead_sample_input = lambda batch_size, hidden_dim : \
+
+multihead_sample_input = lambda batch_size, hidden_dim: \
     (torch.empty((batch_size, hidden_dim)),
-    torch.empty((batch_size, 2 * hidden_dim)),
-    torch.empty((batch_size, 3 * hidden_dim)),)
+     torch.empty((batch_size, 2 * hidden_dim)),
+     torch.empty((batch_size, 3 * hidden_dim)),)
+
+
 @register_expert_class('multihead', multihead_sample_input)
 class MultiheadNetwork(nn.Module):
     def __init__(self, hidden_dim, num_classes=10):
-        super(MultiheadNetwork, self).__init__()
+        super().__init__()
         self.layer1 = nn.Linear(hidden_dim, num_classes)
         self.layer2 = nn.Linear(2 * hidden_dim, num_classes)
         self.layer3 = nn.Linear(3 * hidden_dim, num_classes)

+ 13 - 14
tests/test_custom_expert.py

@@ -1,19 +1,17 @@
 import os
-import pytest
-from typing import Optional
 
+import pytest
 import torch
 
-import hivemind
 from hivemind import RemoteExpert, background_server
 
+
 @pytest.mark.forked
-def test_custom_expert(port: Optional[int] = None, hid_dim=16):
+def test_custom_expert(hid_dim=16):
     with background_server(
-        expert_cls='perceptron', num_experts=2, device='cpu',
-        hidden_dim=hid_dim, num_handlers=2, no_dht=True,
-        custom_module_path=os.path.join(os.path.dirname(__file__), 'custom_networks.py')) as (server_endpoint, _):
-
+            expert_cls='perceptron', num_experts=2, device='cpu',
+            hidden_dim=hid_dim, num_handlers=2, no_dht=True,
+            custom_module_path=os.path.join(os.path.dirname(__file__), 'custom_networks.py')) as (server_endpoint, _):
         expert0 = RemoteExpert('expert.0', server_endpoint)
         expert1 = RemoteExpert('expert.1', server_endpoint)
 
@@ -28,18 +26,19 @@ def test_custom_expert(port: Optional[int] = None, hid_dim=16):
             loss = output1.sum()
             loss.backward()
 
+
 @pytest.mark.forked
-def test_multihead_expert(port: Optional[int] = None, hid_dim=16):
+def test_multihead_expert(hid_dim=16):
     with background_server(
-        expert_cls='multihead', num_experts=2, device='cpu',
-        hidden_dim=hid_dim, num_handlers=2, no_dht=True,
-        custom_module_path=os.path.join(os.path.dirname(__file__), 'custom_networks.py')) as (server_endpoint, _):
-
+            expert_cls='multihead', num_experts=2, device='cpu',
+            hidden_dim=hid_dim, num_handlers=2, no_dht=True,
+            custom_module_path=os.path.join(os.path.dirname(__file__), 'custom_networks.py')) as (server_endpoint, _):
         expert0 = RemoteExpert('expert.0', server_endpoint)
         expert1 = RemoteExpert('expert.1', server_endpoint)
 
         for batch_size in (1, 4):
-            batch = (torch.randn(batch_size, hid_dim), torch.randn(batch_size, 2 * hid_dim), torch.randn(batch_size, 3 * hid_dim))
+            batch = (torch.randn(batch_size, hid_dim), torch.randn(batch_size, 2 * hid_dim),
+                     torch.randn(batch_size, 3 * hid_dim))
 
             output0 = expert0(*batch)
             output1 = expert1(*batch)

+ 3 - 3
tests/test_dht_experts.py

@@ -77,7 +77,7 @@ def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peer
 @pytest.mark.forked
 def test_dht_single_node():
     node = hivemind.DHT(start=True, expiration=999)
-    beam_search = MoEBeamSearcher(node, 'expert.')
+    beam_search = MoEBeamSearcher(node, 'expert.', grid_size=(10,))
 
     assert all(node.declare_experts(['expert.1', 'expert.2', 'expert.3'], f"{hivemind.LOCALHOST}:1337").values())
     assert len(node.declare_experts(["ffn.1", "ffn.2"], endpoint="that_place")) == 4
@@ -104,7 +104,7 @@ def test_dht_single_node():
     assert initial_beam[2][:2] == (0.0, 'expert.3.')
 
     with pytest.raises(AssertionError):
-        beam_search = MoEBeamSearcher(node, 'expert.1.ffn')
+        beam_search = MoEBeamSearcher(node, 'expert.1.ffn', (2, 2))
 
     with pytest.raises(AssertionError):
         beam_search.get_active_successors(['e.1.2.', 'e.2', 'e.4.5.'])
@@ -147,7 +147,7 @@ async def test_negative_caching():
 
     neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
     neg_caching_peer = hivemind.DHT(initial_peers=neighbors_i, cache_locally=False, start=True)
-    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix='ffn.', negative_caching=True)
+    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix='ffn.', grid_size=(10, 10, 10), negative_caching=True)
     # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
     assert len(beam_search.get_initial_beam(scores=[.1, .2, .3, .4, .5, .6], beam_size=3)) == 2
 

+ 21 - 3
tests/test_moe.py

@@ -18,13 +18,30 @@ def test_moe():
         dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
 
         dmoe = hivemind.RemoteMixtureOfExperts(
-            in_features=16, grid_size=(32, 32, 32), dht=dht, k_best=3, uid_prefix='ffn.')
+            in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix='ffn.')
 
-        for i in range(5):
+        for i in range(3):
             out = dmoe(torch.randn(10, 16))
             out.sum().backward()
 
 
+@pytest.mark.forked
+def test_no_experts():
+    all_expert_uids = [f'expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
+                       for _ in range(10)]
+    with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='nop_delay', num_handlers=1,
+                           hidden_dim=16) as (server_endpoint, dht_endpoint):
+        dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+
+        dmoe = hivemind.RemoteSwitchMixtureOfExperts(
+            in_features=16, grid_size=(4, 4, 4), dht=dht, uid_prefix='expert.', forward_timeout=0.1,
+            backward_timeout=0.1, allow_zero_outputs=True)
+
+        for i in range(3):
+            out, balancing_loss = dmoe(torch.randn(10, 16))
+            out.sum().backward()
+
+
 @pytest.mark.forked
 def test_call_many(hidden_dim=16):
     k_min = 1
@@ -33,6 +50,7 @@ def test_call_many(hidden_dim=16):
     forward_timeout = None
     backward_timeout = None
     detect_anomalies = False
+    allow_zero_outputs = False
     atol = 1e-5
 
     with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=hidden_dim,
@@ -44,7 +62,7 @@ def test_call_many(hidden_dim=16):
 
         mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
             DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []], k_min, backward_k_min, timeout_after_k_min,
-            forward_timeout, backward_timeout, detect_anomalies, e1.info, inputs
+            forward_timeout, backward_timeout, detect_anomalies, allow_zero_outputs, e1.info, inputs
         )
         assert mask.shape == (4, 3)
         assert expert_outputs.shape == (4, 3, hidden_dim)

+ 75 - 3
tests/test_training.py

@@ -1,13 +1,14 @@
+import time
 from functools import partial
 
-import time
 import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from sklearn.datasets import load_digits
 
-from hivemind import RemoteExpert, background_server, DHT, DecentralizedSGD
+from hivemind import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts, background_server, DHT, \
+    DecentralizedSGD
 
 
 @pytest.mark.forked
@@ -22,20 +23,91 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
         expert2 = RemoteExpert('expert.1', server_endpoint)
         model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
 
-        opt = torch.optim.SGD(model.parameters(), lr=0.05)
+        opt = SGD(model.parameters(), lr=0.05)
 
         for step in range(max_steps):
+            outputs = model(X_train)
+            loss = F.cross_entropy(outputs, y_train)
+            loss.backward()
+            opt.step()
             opt.zero_grad()
 
+            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
+            if accuracy >= threshold:
+                break
+
+        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
+
+
+@pytest.mark.forked
+def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=2):
+    dataset = load_digits(n_class=2)
+    X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
+    SGD = partial(torch.optim.SGD, lr=0.05)
+
+    all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
+    with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64, num_handlers=1) \
+            as (server_endpoint, dht_endpoint):
+        dht = DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+
+        moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix='expert.', k_best=2)
+        model = nn.Sequential(moe, nn.Linear(64, 2))
+
+        opt = SGD(model.parameters(), lr=0.05)
+
+        for step in range(max_steps):
             outputs = model(X_train)
             loss = F.cross_entropy(outputs, y_train)
             loss.backward()
             opt.step()
+            opt.zero_grad()
+
+            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
+            if accuracy >= threshold:
+                break
+
+        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
+
+
+class SwitchNetwork(nn.Module):
+    def __init__(self, dht, in_features, num_classes, num_experts):
+        super().__init__()
+        self.moe = RemoteSwitchMixtureOfExperts(in_features=in_features, grid_size=(num_experts,), dht=dht,
+                                                jitter_eps=0, uid_prefix='expert.', k_best=1,
+                                                k_min=1)
+        self.linear = nn.Linear(in_features, num_classes)
+
+    def forward(self, x):
+        moe_output, balancing_loss = self.moe(x)
+        return self.linear(moe_output), balancing_loss
+
+
+@pytest.mark.forked
+def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_experts=5):
+    dataset = load_digits(n_class=2)
+    X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
+    SGD = partial(torch.optim.SGD, lr=0.05)
+
+    all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
+    with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64,
+                           num_handlers=1) as (server_endpoint, dht_endpoint):
+        dht = DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+
+        model = SwitchNetwork(dht, 64, 2, num_experts)
+        opt = SGD(model.parameters(), lr=0.05)
+
+        for step in range(max_steps):
+            outputs, balancing_loss = model(X_train)
+            loss = F.cross_entropy(outputs, y_train) + 0.01 * balancing_loss
+            loss.backward()
+            opt.step()
+            opt.zero_grad()
 
             accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
             if accuracy >= threshold:
                 break
 
+        assert model.moe.grid_utilization.min().item() > (1 / num_experts) / 2
         assert accuracy >= threshold, f"too small accuracy: {accuracy}"