
Add Switch Transformers-like RemoteMixtureOfExperts (#228)

* Add RemoteSwitchMixtureOfExperts

* Add load balancing and test for training

* Make grid_size non-optional

* Support passing *args as the batch size to BatchTensorDescriptor.make_empty

* Reformat tests/custom_networks.py

* Reformat test_custom_expert.py

* Add DelayedNopExpert

* Add allow_zero_outputs to RemoteMixtureOfExperts

* Fix exception handling in DHT, log tracebacks in async_run_coroutine

* Generate BatchTensorDescriptor from dummy inputs on server start
This ensures that outputs_schema contains correct attributes such as dtype

* Handle exceptions on start of background_server

* Add test_no_experts
Max Ryabinin, 4 years ago
commit 62652e1717
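
A minimal usage sketch of the new module, mirroring tests/test_moe.py and tests/test_training.py below; the peer address, grid size, and the 0.01 loss weight are illustrative assumptions, not fixed by this commit:

    import torch
    import hivemind

    dht = hivemind.DHT(initial_peers=["127.0.0.1:1337"], start=True)  # hypothetical running peer
    dmoe = hivemind.RemoteSwitchMixtureOfExperts(
        in_features=64, grid_size=(4, 4, 4), dht=dht, uid_prefix='expert.',
        k_best=1, allow_zero_outputs=True)

    outputs, balancing_loss = dmoe(torch.randn(8, 64))  # forward returns (outputs, load balancing loss)
    loss = outputs.sum() + 0.01 * balancing_loss        # the auxiliary loss weight is a free choice
    loss.backward()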

+ 4 - 0
docs/modules/client.rst

@@ -18,6 +18,10 @@
    :members:
    :member-order: bysource
 
+.. autoclass:: RemoteSwitchMixtureOfExperts
+   :members:
+   :member-order: bysource
+
 .. autoclass:: DecentralizedAverager
    :members:
    :member-order: bysource

+ 1 - 0
hivemind/client/__init__.py

@@ -1,4 +1,5 @@
 from hivemind.client.expert import RemoteExpert
 from hivemind.client.moe import RemoteMixtureOfExperts
+from hivemind.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.client.averaging import DecentralizedAverager
 from hivemind.client.averaging.training import TrainingAverager

+ 3 - 2
hivemind/client/beam_search.py

@@ -63,7 +63,7 @@ class MoEBeamSearcher:
          Though, this is a pathological case (e.g. only 90 experts in an oversized 100x100 grid) that should be avoided.
     """
 
-    def __init__(self, dht: DHT, uid_prefix: ExpertPrefix, grid_size: Optional[Tuple[int, ...]] = None,
+    def __init__(self, dht: DHT, uid_prefix: ExpertPrefix, grid_size: Tuple[int, ...],
                  num_workers: Optional[int] = None, negative_caching: bool = True, **kwargs):
         if not uid_prefix.endswith(UID_DELIMITER):
             uid_prefix += UID_DELIMITER
@@ -71,6 +71,7 @@ class MoEBeamSearcher:
         assert is_valid_prefix(uid_prefix), f"Prefix '{uid_prefix}' is invalid."
         self.dht = dht
         self.uid_prefix, self.grid_size = uid_prefix, grid_size
+        self.total_grid_size = sum(grid_size)
         self.negative_caching, self.num_workers, self.dht_kwargs = negative_caching, num_workers, kwargs
 
     def get_initial_beam(self, scores: Sequence[float], beam_size: int, return_future: bool = False
@@ -174,7 +175,7 @@ class MoEBeamSearcher:
         :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
-        assert (not self.grid_size or len(grid_scores) == len(self.grid_size)) and beam_size > 0
+        assert len(grid_scores) == len(self.grid_size) and beam_size > 0
         return self.dht.run_coroutine(partial(self._find_best_experts, prefix=self.uid_prefix, beam_size=beam_size,
                                               grid_scores=list(grid_scores), negative_caching=self.negative_caching,
                                               num_workers=self.num_workers), return_future)
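
With grid_size now a required argument, a brief construction sketch (the prefix and dimensions are illustrative; MoEBeamSearcher and DHT come from this repository):

    from hivemind import DHT
    from hivemind.client.beam_search import MoEBeamSearcher

    dht = DHT(start=True)
    searcher = MoEBeamSearcher(dht, uid_prefix='expert.', grid_size=(4, 4, 4))
    assert searcher.total_grid_size == 12  # sum(grid_size), used for the gating projection in moe.py below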

+ 56 - 38
hivemind/client/moe.py

@@ -14,7 +14,7 @@ from hivemind.client.beam_search import MoEBeamSearcher
 from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.server.expert_uid import UID_DELIMITER
-from hivemind.utils import nested_pack, nested_flatten
+from hivemind.utils import nested_pack, nested_flatten, nested_map
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
@@ -23,7 +23,7 @@ logger = get_logger(__name__)
 
 class RemoteMixtureOfExperts(nn.Module):
     """
-    A torch module that performs mixture of experts inference with a local gating function and multiple remote experts.
+    A torch module that performs Mixture-of-Experts inference with a local gating function and multiple remote experts.
     Natively supports pytorch autograd.
 
     :note: By default, not all experts are guaranteed to perform forward pass. Moreover, not all of those who ran
@@ -38,14 +38,15 @@ class RemoteMixtureOfExperts(nn.Module):
     :param k_best: average this many highest-scoring experts to compute activations
     :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
     :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
-    :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
      Any expert that didn't manage to return output after that delay is considered unavailable
+    :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
+    :param allow_zero_outputs: whether to return zeros if no experts respond on forward pass
     """
 
     def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, uid_prefix: str, k_best: int,
                  k_min: int = 1, forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
                  backward_k_min: int = 1, backward_timeout: Optional[float] = None, detect_anomalies: bool = False,
-                 **dht_kwargs):
+                 allow_zero_outputs: bool = False, **dht_kwargs):
         super().__init__()
         self.dht = dht
         self.beam_search = MoEBeamSearcher(dht, uid_prefix, grid_size, **dht_kwargs)
@@ -53,8 +54,10 @@ class RemoteMixtureOfExperts(nn.Module):
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
         self.timeout_after_k_min = timeout_after_k_min
         self.detect_anomalies = detect_anomalies
+        self.allow_zero_outputs = allow_zero_outputs
 
-        self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
+        # jointly predict logits for all grid dimensions
+        self.proj = nn.Linear(in_features, self.beam_search.total_grid_size)
         self._expert_info = None  # expert['info'] from one of experts in the grid
 
     def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
@@ -87,7 +90,8 @@ class RemoteMixtureOfExperts(nn.Module):
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
             DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
-            self.backward_timeout, self.detect_anomalies, self.info, *nested_flatten(((input, *args), kwargs)))
+            self.backward_timeout, self.detect_anomalies, self.allow_zero_outputs, self.info,
+            *nested_flatten(((input, *args), kwargs)))
         # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
 
         expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
@@ -97,6 +101,7 @@ class RemoteMixtureOfExperts(nn.Module):
         averaged_outputs_flat = [
             (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
             for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
+
         return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
 
     def compute_expert_scores(
@@ -152,13 +157,14 @@ class _RemoteCallMany(torch.autograd.Function):
     one expert succeeds for each input. For user-friendly version of this function, use RemoteMixtureOfExperts module.
 
     Note: experts that failed during forward will be assigned zero outputs and marked as mask[i, j] = 0,
-          experts that failed during backward will be treated as constants (i.e. gradients of through them are zeros)
+          experts that failed during backward will be treated as constants (i.e. gradients through them are zeros)
     """
 
     @classmethod
     def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
                 timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
-                detect_anomalies: bool, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+                detect_anomalies: bool, allow_zero_outputs: bool, info: Dict[str, Any],
+                *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
         assert not torch.is_grad_enabled()
         num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
 
@@ -181,32 +187,42 @@ class _RemoteCallMany(torch.autograd.Function):
                 new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
                 pending_tasks[new_task] = (i, j)
 
-        alive_grid_indices, alive_flat_outputs = cls._collect_responses(
+        responded_inds, alive_flat_outputs = cls._collect_responses(
             pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
-        if len(alive_grid_indices) == 0:
-            raise TimeoutError("Forward pass: no alive experts responded within timeout.")
+        if len(responded_inds) < k_min:
+            raise TimeoutError(f"Forward pass: less than {k_min} responded within timeout.")
+
+        if not isinstance(info['outputs_schema'], tuple):
+            outputs_schema = (info['outputs_schema'],)
+        else:
+            outputs_schema = info['outputs_schema']
+        outputs = nested_map(
+            lambda descriptor: descriptor.make_empty(num_samples, max_experts, device=flat_inputs[0].device).zero_(),
+            outputs_schema)
 
         # assemble responses
-        alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
-        mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
-        mask[alive_ii, alive_jj] = True
+        if len(responded_inds) > 0 or allow_zero_outputs:
+            batch_inds, expert_inds = map(lambda x: torch.as_tensor(x, device=flat_inputs[0].device, dtype=torch.long),
+                                          list(zip(*responded_inds)) or ([], []))
 
-        alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
-        # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
+            alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
+            # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
 
-        outputs = []
-        for response_stacked in alive_flat_outputs_stacked:
-            output = torch.zeros(
-                [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
-                dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
-            output[alive_ii, alive_jj] = response_stacked
-            outputs.append(output.to(flat_inputs[0].device))
+            for output, response_stacked in zip(outputs, alive_flat_outputs_stacked):
+                output[batch_inds, expert_inds] = response_stacked.to(output.device)
+
+        else:
+            raise RuntimeError('Forward pass: 0 experts responded within timeout and allow_zero_outputs is False')
+
+        mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
+        mask[batch_inds, expert_inds] = True
 
         # save individual outputs for backward pass
-        ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
+        ctx.save_for_backward(batch_inds, expert_inds, *flat_inputs_cpu)
         ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
                                   detect_anomalies)
-        return (mask,) + tuple(outputs)
+
+        return (mask,) + outputs
 
     @classmethod
     @once_differentiable
@@ -235,35 +251,37 @@ class _RemoteCallMany(torch.autograd.Function):
         for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
                                                     inputs_per_expert, grad_outputs_per_expert):
             expert = expert_per_sample[i.item()][j.item()]
-            stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
+            stub = _get_expert_stub(expert.endpoint)
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
             tensors_serialized = [serialize_torch_tensor(tensor, proto.compression)
                                   for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
             new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
             pending_tasks[new_task] = (i, j)
 
-        backward_survivor_indices, survivor_grad_inputs = cls._collect_responses(
+        survivor_inds, survivor_grad_inputs = cls._collect_responses(
             pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
-        if len(backward_survivor_indices) == 0:
-            raise TimeoutError("Backward pass: no alive experts responded within timeout.")
+        if len(survivor_inds) < backward_k_min:
+            raise TimeoutError(f"Backward pass: less than {backward_k_min} experts responded within timeout.")
 
         # assemble responses
-        backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices) or ([], []))
+        batch_inds, expert_inds = map(lambda x: torch.as_tensor(x, dtype=torch.long),
+                                      list(zip(*survivor_inds)) or ([], []))
 
         survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
         # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]
 
-        grad_inputs = []
-        for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
+        grad_inputs = nested_map(
+            lambda descr: descr.make_empty(num_samples, device=flat_grad_outputs[0].device).zero_(),
+            list(nested_flatten(info['forward_schema'])))
+
+        for grad_input, survivor_grad_stacked in zip(grad_inputs, survivor_grad_inputs_stacked):
             grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
-                (num_samples, max_experts, *flat_inputs_cpu[i].shape[1:]),
+                (num_samples, max_experts, *grad_input.shape[1:]),
                 device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
-            grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked
-
-            # sum gradients from each expert
-            grad_inputs.append(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))
+            grad_input_per_expert[batch_inds, expert_inds] = survivor_grad_stacked
+            grad_input.copy_(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))
 
-        return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
+        return (DUMMY, None, None, None, None, None, None, None, None, None, *grad_inputs)
 
     @staticmethod
     def _collect_responses(task_to_indices: Dict[grpc.Future, Tuple[int, int]], num_samples: int, k_min: int,
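
The refactored forward pass pre-allocates zeroed output buffers from outputs_schema and scatters whatever responses arrived into them, leaving all-zero rows for failed experts. A self-contained torch sketch of that pattern (shapes and the fake responses are made up for illustration):

    import torch

    num_samples, max_experts, out_dim = 4, 3, 16
    responded_inds = [(0, 1), (2, 0)]                        # (sample index, expert slot) pairs that responded
    responses = torch.randn(len(responded_inds), out_dim)

    output = torch.zeros(num_samples, max_experts, out_dim)  # stands in for descriptor.make_empty(...).zero_()
    batch_inds, expert_inds = map(torch.as_tensor, zip(*responded_inds))
    output[batch_inds, expert_inds] = responses              # non-responding experts keep zero outputs

    mask = torch.zeros(num_samples, max_experts, dtype=torch.bool)
    mask[batch_inds, expert_inds] = True                     # mask[i, j] = False marks failed experts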

+ 175 - 0
hivemind/client/switch_moe.py

@@ -0,0 +1,175 @@
+from __future__ import annotations
+
+from typing import Tuple, List
+
+import grpc
+import torch
+
+from hivemind.client.expert import RemoteExpert, DUMMY
+from hivemind.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
+from hivemind.server.expert_uid import UID_DELIMITER
+from hivemind.utils import nested_pack, nested_flatten
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
+    """
+    A module implementing Switch Transformers [1] Mixture-of-Experts inference with remote experts.
+
+    [1] Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.
+     William Fedus, Barret Zoph, Noam Shazeer. https://arxiv.org/abs/2101.03961
+
+    :note: By default, not all experts are guaranteed to perform forward pass. Moreover, not all of those who ran
+     forward pass are guaranteed to perform backward pass. In the latter case, gradient will be averaged without
+     the missing experts
+
+    :param in_features: common input size for experts and gating function
+    :param grid_size: dimensions that form expert uid (see below)
+    :param uid_prefix: common prefix for all expert uids (must end with '.')
+    :note: expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
+    :param dht: a DHT instance used to search for best experts
+    :param k_best: average this many highest-scoring experts to compute activations
+    :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
+    :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
+     Any expert that didn't manage to return output after that delay is considered unavailable
+    :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
+    :param allow_zero_outputs: whether to return just the input if no experts respond on forward pass
+    """
+
+    def __init__(self, *, grid_size: Tuple[int, ...], utilization_alpha: float = 0.9, grid_dropout: float = 1.0,
+                 jitter_eps: float = 1e-2, k_best=1, k_min=0, backward_k_min=0, allow_zero_outputs=True, **kwargs):
+        super().__init__(grid_size=grid_size, k_best=k_best, k_min=k_min, backward_k_min=backward_k_min,
+                         allow_zero_outputs=allow_zero_outputs, **kwargs)
+
+        initial_utilization = torch.cat(
+            [torch.tensor([1 / dim_size for _ in range(dim_size)], dtype=torch.float)
+             for dim_size in grid_size],
+        )
+        self.register_buffer('grid_utilization', initial_utilization)
+        self.utilization_alpha = utilization_alpha
+        self.grid_dropout = grid_dropout
+        self.jitter_eps = jitter_eps
+
+    def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
+        if input.ndim != 2:
+            input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
+        else:
+            input_for_gating = input
+
+        # Multiplicative jitter for regularized routing
+        jitter_noise = torch.empty_like(input_for_gating).uniform_(1 - self.jitter_eps, 1 + self.jitter_eps)
+        input_for_gating *= jitter_noise
+
+        # Compute scores, find most appropriate experts with beam search
+        grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)
+
+        grid_dropout_masks = (
+            (torch.rand(size=(dim_size,), dtype=input_for_gating.dtype, device=input_for_gating.device)
+             < self.grid_dropout) for dim_size in self.beam_search.grid_size
+        )
+        grid_scores_dropout = [torch.where(dropout_mask, grid_score,
+                                           torch.full((1,), float('-inf'), device=grid_score.device,
+                                                      dtype=grid_score.dtype))
+                               for grid_score, dropout_mask in zip(grid_scores, grid_dropout_masks)]
+
+        grid_softmax = [torch.softmax(grid_score, dim=-1) for grid_score in grid_scores_dropout]
+        chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
+            [scores.detach().cpu() for scores in grid_scores_dropout], self.k_best)
+
+        if self._expert_info is None:
+            try:
+                self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
+            except grpc.RpcError as e:
+                logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
+
+        expert_mask, *expert_outputs = _RemoteCallMany.apply(
+            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
+            self.backward_timeout, self.detect_anomalies, self.allow_zero_outputs, self.info,
+            *nested_flatten(((input, *args), kwargs)))
+        # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
+
+        batch_utilization = self._compute_batch_utilization(chosen_experts, expert_mask)
+        self.grid_utilization = \
+            self.utilization_alpha * self.grid_utilization + (1 - self.utilization_alpha) * batch_utilization
+
+        # compute expert probabilities as product across grid dimensions
+        expert_probs = self.compute_expert_scores(grid_softmax, chosen_experts)
+        masked_logits = torch.full((1,), float('-inf'), device=expert_probs.device, dtype=expert_probs.dtype)
+        expert_probs = torch.where(expert_mask, expert_probs, masked_logits)
+
+        # multiply outputs by expert probabilities
+        averaged_outputs_flat = [
+            (expert_probs[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
+            for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
+
+        packed_outputs = nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
+
+        # Load balancing loss: multiply fractions of probability mass and fractions of routed examples
+        # for each grid dimension, sum across all indices for a dimension. Optimizing this leads to uniform allocation
+        balancing_loss = torch.stack([torch.mean(dim_softmax.mean(0) * dim_utilization) * (dim_size ** 2)
+                                      for dim_softmax, dim_utilization, dim_size in
+                                      zip(grid_softmax, self.grid_utilization, self.beam_search.grid_size)]).sum()
+
+        # residual connection
+        if isinstance(packed_outputs, torch.Tensor):
+            packed_outputs = packed_outputs + input
+        else:
+            packed_outputs[0] = packed_outputs[0] + input
+
+        return packed_outputs, balancing_loss
+
+    @torch.no_grad()
+    def _compute_batch_utilization(self, batch_experts, expert_mask):
+        batch_utilization = [torch.zeros((dim_size,), dtype=self.grid_utilization.dtype,
+                                         device=self.grid_utilization.device)
+                             for dim_size in self.beam_search.grid_size]
+
+        # out of chosen_experts, select those for which expert_mask is True
+        for (sample_idx, expert_idx) in expert_mask.nonzero().numpy():
+            expert = batch_experts[sample_idx][expert_idx]
+            expert_indices = expert.uid[len(self.beam_search.uid_prefix):]
+            expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
+
+            for dim_index, dim_utilization in zip(expert_indices, batch_utilization):
+                dim_utilization[dim_index] += 1
+
+        return torch.cat([
+            torch.nn.functional.normalize(dim_utilization, p=1, dim=0)
+            for dim_utilization in batch_utilization
+        ])
+
+    def compute_expert_scores(
+            self, grid_probs: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
+        """
+        Compute scores for each expert by multiplying grid probabilities, autograd-friendly
+        :param grid_probs: list of torch tensors, i-th tensor contains scores for i-th grid dimension
+        :param batch_experts: list(batch) of lists(k) of up to k experts selected for this batch
+        :returns: a tensor of scores, float32[batch_size, k]
+        :note: if some rows in batch have less than max number of experts, their scores will be padded with -inf
+        """
+        expert_counts = list(map(len, batch_experts))
+        batch_size = len(batch_experts)
+        max_num_experts = max(expert_counts)
+        total_num_experts = sum(expert_counts)
+        expert_index_in_batch = torch.arange(total_num_experts, device=grid_probs[0].device)
+        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_probs[0].device), dim=-1)[:-1]
+        flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
+        flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
+        flat_experts = [expert for row in batch_experts for expert in row]
+
+        grid_indices = torch.zeros([len(flat_experts), len(grid_probs)], dtype=torch.int64)
+        for i, expert in enumerate(flat_experts):
+            expert_indices = expert.uid[len(self.beam_search.uid_prefix):]
+            expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
+            grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
+
+        scores_per_dim = [
+            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
+            for dim_scores, dim_indices in zip(grid_probs, grid_indices.T)]
+        flat_scores = torch.prod(torch.stack(scores_per_dim, dim=0), dim=0)
+
+        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_probs[0].device)
+        scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
+        return scores
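
For a single grid dimension, the load-balancing term above multiplies each expert's mean router probability by the fraction of examples routed to it and scales the mean by the squared number of experts, i.e. the Switch Transformers auxiliary loss N * sum_i(P_i * f_i). A small torch sketch with made-up routing statistics (in the module itself, the utilization is an exponential moving average across batches):

    import torch

    num_experts, batch_size = 4, 8
    router_probs = torch.softmax(torch.randn(batch_size, num_experts), dim=-1)  # grid_softmax for one dimension
    routed_to = torch.randint(num_experts, (batch_size,))                       # expert index chosen per sample
    utilization = torch.bincount(routed_to, minlength=num_experts).float() / batch_size

    # mean(P_i * f_i) * N**2 == N * sum(P_i * f_i); minimized when routing is uniform
    balancing_loss = torch.mean(router_probs.mean(0) * utilization) * num_experts ** 2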

+ 3 - 2
hivemind/dht/__init__.py

@@ -91,7 +91,7 @@ class DHT(mp.Process):
         """
         self.start()
         if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
+            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
 
     def shutdown(self) -> None:
         """ Shut down a running dht process """
@@ -186,6 +186,7 @@ class DHT(mp.Process):
             else:
                 future.set_result(await main_task)
         except BaseException as e:
+            logger.exception(f'Caught an exception when running a coroutine: {e}')
             if not future.done():
                 future.set_exception(e)
 
@@ -243,7 +244,7 @@ class DHT(mp.Process):
                                             f" Please ensure the node is connected or specify peers=... manually."))
 
     def declare_experts(self, uids, endpoint, wait: bool = True):
-        logger.warning("dht.declare_experts is scheduled for removal in 0.9.8, please use hivemind.declare_experts.",)
+        logger.warning("dht.declare_experts is scheduled for removal in 0.9.8, please use hivemind.declare_experts.")
         return hivemind.declare_experts(self, uids, endpoint, wait=wait)
 
     def get_experts(self, uids, expiration_time: Optional[DHTExpiration] = None,

+ 20 - 11
hivemind/server/__init__.py

@@ -21,7 +21,7 @@ from hivemind.server.layers import name_to_block, name_to_input
 from hivemind.server.layers import add_custom_models_from_file, schedule_name_to_scheduler
 from hivemind.server.runtime import Runtime
 from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
-from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
+from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger, BatchTensorDescriptor
 from hivemind.proto.runtime_pb2 import CompressionType
 
 logger = get_logger(__name__)
@@ -153,11 +153,11 @@ class Server(threading.Thread):
         optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
         device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
 
-        sample_input = name_to_input[expert_cls](4, hidden_dim)
+        sample_input = name_to_input[expert_cls](3, hidden_dim)
         if isinstance(sample_input, tuple):
-            args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
+            args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
         else:
-            args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)
+            args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)
 
         scheduler = schedule_name_to_scheduler[scheduler]
 
@@ -167,8 +167,6 @@ class Server(threading.Thread):
             expert = name_to_block[expert_cls](hidden_dim)
             experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert,
                                                          args_schema=args_schema,
-                                                         outputs_schema=hivemind.BatchTensorDescriptor(
-                                                             hidden_dim, compression=compression),
                                                          optimizer=optim_cls(expert.parameters()),
                                                          scheduler=scheduler,
                                                          num_warmup_steps=num_warmup_steps,
@@ -264,11 +262,15 @@ def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.End
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     pipe, runners_pipe = mp.Pipe(duplex=True)
     runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
-
     try:
         runner.start()
-        yield pipe.recv()  # once the server is ready, runner will send us a tuple(hostname, port, dht port)
-        pipe.send('SHUTDOWN')  # on exit from context, send shutdown signal
+        # once the server is ready, runner will send us either (False, exception) or (True, (server_port, dht_port))
+        start_ok, data = pipe.recv()
+        if start_ok:
+            yield data
+            pipe.send('SHUTDOWN')  # on exit from context, send shutdown signal
+        else:
+            raise RuntimeError(f"Server failed to start: {data}")
     finally:
         runner.join(timeout=shutdown_timeout)
         if runner.is_alive():
@@ -278,14 +280,21 @@ def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.End
 
 
 def _server_runner(pipe, *args, **kwargs):
-    server = Server.create(*args, start=True, **kwargs)
+    try:
+        server = Server.create(*args, start=True, **kwargs)
+    except Exception as e:
+        logger.exception(f"Encountered an exception when starting a server: {e}")
+        pipe.send((False, f'{type(e).__name__} {e}'))
+        return
+
     try:
         if server.dht is not None:
             dht_listen_on = hivemind.replace_port(server.dht.listen_on, server.dht.port)
         else:
             dht_listen_on = None
-        pipe.send((server.listen_on, dht_listen_on))
+        pipe.send((True, (server.listen_on, dht_listen_on)))
         pipe.recv()  # wait for shutdown signal
+
     finally:
         logger.info("Shutting down server...")
         server.shutdown()
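
A hedged sketch of how the new startup handshake surfaces failures to the caller; the server arguments are illustrative, only the (start_ok, data) protocol and the RuntimeError come from this diff:

    from hivemind import background_server

    try:
        with background_server(num_experts=1, expert_cls='ffn', hidden_dim=16,
                               device='cpu', no_dht=True) as (server_endpoint, dht_endpoint):
            print("server is listening on", server_endpoint)  # dht_endpoint is None with no_dht=True
    except RuntimeError as err:  # raised when the runner reports (False, exception description)
        print("server failed to start:", err)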

+ 2 - 5
hivemind/server/layers/__init__.py

@@ -1,12 +1,9 @@
-import torch
-
 name_to_block = {}
 name_to_input = {}
 
-from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup
-from hivemind.server.layers.custom_experts import add_custom_models_from_file, register_expert_class
-
 import hivemind.server.layers.common
 import hivemind.server.layers.dropout
+from hivemind.server.layers.custom_experts import add_custom_models_from_file, register_expert_class
+from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup
 
 schedule_name_to_scheduler = {'linear': get_linear_schedule_with_warmup, 'none': None}

+ 22 - 1
hivemind/server/layers/common.py

@@ -1,3 +1,5 @@
+import time
+
 import torch
 from torch import nn as nn
 
@@ -11,6 +13,8 @@ def gelu_fast(x):
 
 
 ffn_sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))
+
+
 @register_expert_class('ffn', ffn_sample_input)
 class FeedforwardBlock(nn.Module):
 
@@ -65,7 +69,9 @@ class TransformerEncoderLayer(nn.Module):
 
 transformer_sample_input = lambda batch_size, hid_dim: \
     (torch.empty((batch_size, 128, hid_dim)), \
-    torch.empty((batch_size, 128), dtype=torch.bool))
+     torch.empty((batch_size, 128), dtype=torch.bool))
+
+
 @register_expert_class('transformer', transformer_sample_input)
 class TunedTransformer(TransformerEncoderLayer):
 
@@ -74,6 +80,8 @@ class TunedTransformer(TransformerEncoderLayer):
 
 
 nop_sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))
+
+
 @register_expert_class('nop', nop_sample_input)
 class NopExpert(nn.Sequential):
 
@@ -83,3 +91,16 @@ class NopExpert(nn.Sequential):
 
     def forward(self, x):
         return x.clone()
+
+
+@register_expert_class('nop_delay', nop_sample_input)
+class DelayedNopExpert(nn.Sequential):
+
+    def __init__(self, hid_dim, delay=0.5):
+        super().__init__()
+        self.w = nn.Parameter(torch.zeros(0), requires_grad=True)
+        self.delay = delay
+
+    def forward(self, x):
+        time.sleep(self.delay)
+        return x.clone()

+ 3 - 3
hivemind/utils/tensor_descr.py

@@ -46,7 +46,7 @@ class TensorDescriptor(DescriptorBase):
 
 @dataclass(repr=True, frozen=True)
 class BatchTensorDescriptor(TensorDescriptor):
-    """ torch Tensor with a variable 0-th dimension, used to describe batched data """
+    """ torch.Tensor with a variable 0-th dimension, used to describe batched data """
 
     def __init__(self, *instance_size, **kwargs):  # compatibility: allow initializing with *size
         if len(instance_size) == 1 and isinstance(instance_size[0], (list, tuple, torch.Size)):
@@ -60,9 +60,9 @@ class BatchTensorDescriptor(TensorDescriptor):
                    pin_memory=safe_check_pinned(tensor),
                    compression=compression if tensor.is_floating_point() else CompressionType.NONE)
 
-    def make_empty(self, batch_size, **kwargs):
+    def make_empty(self, *batch_size, **kwargs):
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
-        return super().make_empty(size=(batch_size, *self.shape[1:]), **kwargs)
+        return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
 
 
 def safe_check_pinned(tensor: torch.Tensor) -> bool:
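
make_empty now accepts any number of leading dimensions instead of a single batch_size, which lets moe.py allocate [num_samples, max_experts, ...] buffers directly from the schema. A short sketch of the new call, assuming make_empty forwards size and device to torch.empty as it is used elsewhere in this diff:

    import torch
    from hivemind.utils.tensor_descr import BatchTensorDescriptor

    descr = BatchTensorDescriptor(16)                     # shape (None, 16): variable 0-th dimension
    buffer = descr.make_empty(4, 3, device='cpu').zero_()
    assert buffer.shape == torch.Size([4, 3, 16])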

+ 11 - 6
tests/custom_networks.py

@@ -4,11 +4,13 @@ import torch.nn.functional as F
 
 from hivemind.server.layers.custom_experts import register_expert_class
 
-sample_input = lambda batch_size, hidden_dim : torch.empty((batch_size, hidden_dim))
+sample_input = lambda batch_size, hidden_dim: torch.empty((batch_size, hidden_dim))
+
+
 @register_expert_class('perceptron', sample_input)
 class MultilayerPerceptron(nn.Module):
     def __init__(self, hidden_dim, num_classes=10):
-        super(MultilayerPerceptron, self).__init__()
+        super().__init__()
         self.layer1 = nn.Linear(hidden_dim, 2 * hidden_dim)
         self.layer2 = nn.Linear(2 * hidden_dim, 2 * hidden_dim)
         self.layer3 = nn.Linear(2 * hidden_dim, num_classes)
@@ -19,14 +21,17 @@ class MultilayerPerceptron(nn.Module):
         x = self.layer3(x)
         return x
 
-multihead_sample_input = lambda batch_size, hidden_dim : \
+
+multihead_sample_input = lambda batch_size, hidden_dim: \
     (torch.empty((batch_size, hidden_dim)),
-    torch.empty((batch_size, 2 * hidden_dim)),
-    torch.empty((batch_size, 3 * hidden_dim)),)
+     torch.empty((batch_size, 2 * hidden_dim)),
+     torch.empty((batch_size, 3 * hidden_dim)),)
+
+
 @register_expert_class('multihead', multihead_sample_input)
 class MultiheadNetwork(nn.Module):
     def __init__(self, hidden_dim, num_classes=10):
-        super(MultiheadNetwork, self).__init__()
+        super().__init__()
         self.layer1 = nn.Linear(hidden_dim, num_classes)
         self.layer2 = nn.Linear(2 * hidden_dim, num_classes)
         self.layer3 = nn.Linear(3 * hidden_dim, num_classes)

+ 13 - 14
tests/test_custom_expert.py

@@ -1,19 +1,17 @@
 import os
-import pytest
-from typing import Optional
 
+import pytest
 import torch
 
-import hivemind
 from hivemind import RemoteExpert, background_server
 
+
 @pytest.mark.forked
-def test_custom_expert(port: Optional[int] = None, hid_dim=16):
+def test_custom_expert(hid_dim=16):
     with background_server(
-        expert_cls='perceptron', num_experts=2, device='cpu',
-        hidden_dim=hid_dim, num_handlers=2, no_dht=True,
-        custom_module_path=os.path.join(os.path.dirname(__file__), 'custom_networks.py')) as (server_endpoint, _):
-
+            expert_cls='perceptron', num_experts=2, device='cpu',
+            hidden_dim=hid_dim, num_handlers=2, no_dht=True,
+            custom_module_path=os.path.join(os.path.dirname(__file__), 'custom_networks.py')) as (server_endpoint, _):
         expert0 = RemoteExpert('expert.0', server_endpoint)
         expert1 = RemoteExpert('expert.1', server_endpoint)
 
@@ -28,18 +26,19 @@ def test_custom_expert(port: Optional[int] = None, hid_dim=16):
             loss = output1.sum()
             loss.backward()
 
+
 @pytest.mark.forked
-def test_multihead_expert(port: Optional[int] = None, hid_dim=16):
+def test_multihead_expert(hid_dim=16):
     with background_server(
-        expert_cls='multihead', num_experts=2, device='cpu',
-        hidden_dim=hid_dim, num_handlers=2, no_dht=True,
-        custom_module_path=os.path.join(os.path.dirname(__file__), 'custom_networks.py')) as (server_endpoint, _):
-
+            expert_cls='multihead', num_experts=2, device='cpu',
+            hidden_dim=hid_dim, num_handlers=2, no_dht=True,
+            custom_module_path=os.path.join(os.path.dirname(__file__), 'custom_networks.py')) as (server_endpoint, _):
         expert0 = RemoteExpert('expert.0', server_endpoint)
         expert1 = RemoteExpert('expert.1', server_endpoint)
 
         for batch_size in (1, 4):
-            batch = (torch.randn(batch_size, hid_dim), torch.randn(batch_size, 2 * hid_dim), torch.randn(batch_size, 3 * hid_dim))
+            batch = (torch.randn(batch_size, hid_dim), torch.randn(batch_size, 2 * hid_dim),
+                     torch.randn(batch_size, 3 * hid_dim))
 
             output0 = expert0(*batch)
             output1 = expert1(*batch)

+ 3 - 3
tests/test_dht_experts.py

@@ -77,7 +77,7 @@ def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peer
 @pytest.mark.forked
 def test_dht_single_node():
     node = hivemind.DHT(start=True, expiration=999)
-    beam_search = MoEBeamSearcher(node, 'expert.')
+    beam_search = MoEBeamSearcher(node, 'expert.', grid_size=(10,))
 
     assert all(node.declare_experts(['expert.1', 'expert.2', 'expert.3'], f"{hivemind.LOCALHOST}:1337").values())
     assert len(node.declare_experts(["ffn.1", "ffn.2"], endpoint="that_place")) == 4
@@ -104,7 +104,7 @@ def test_dht_single_node():
     assert initial_beam[2][:2] == (0.0, 'expert.3.')
 
     with pytest.raises(AssertionError):
-        beam_search = MoEBeamSearcher(node, 'expert.1.ffn')
+        beam_search = MoEBeamSearcher(node, 'expert.1.ffn', (2, 2))
 
     with pytest.raises(AssertionError):
         beam_search.get_active_successors(['e.1.2.', 'e.2', 'e.4.5.'])
@@ -147,7 +147,7 @@ async def test_negative_caching():
 
     neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
     neg_caching_peer = hivemind.DHT(initial_peers=neighbors_i, cache_locally=False, start=True)
-    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix='ffn.', negative_caching=True)
+    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix='ffn.', grid_size=(10, 10, 10), negative_caching=True)
     # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
     assert len(beam_search.get_initial_beam(scores=[.1, .2, .3, .4, .5, .6], beam_size=3)) == 2
 

+ 21 - 3
tests/test_moe.py

@@ -18,13 +18,30 @@ def test_moe():
         dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
 
         dmoe = hivemind.RemoteMixtureOfExperts(
-            in_features=16, grid_size=(32, 32, 32), dht=dht, k_best=3, uid_prefix='ffn.')
+            in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix='ffn.')
 
-        for i in range(5):
+        for i in range(3):
             out = dmoe(torch.randn(10, 16))
             out.sum().backward()
 
 
+@pytest.mark.forked
+def test_no_experts():
+    all_expert_uids = [f'expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
+                       for _ in range(10)]
+    with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='nop_delay', num_handlers=1,
+                           hidden_dim=16) as (server_endpoint, dht_endpoint):
+        dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+
+        dmoe = hivemind.RemoteSwitchMixtureOfExperts(
+            in_features=16, grid_size=(4, 4, 4), dht=dht, uid_prefix='expert.', forward_timeout=0.1,
+            backward_timeout=0.1, allow_zero_outputs=True)
+
+        for i in range(3):
+            out, balancing_loss = dmoe(torch.randn(10, 16))
+            out.sum().backward()
+
+
 @pytest.mark.forked
 def test_call_many(hidden_dim=16):
     k_min = 1
@@ -33,6 +50,7 @@ def test_call_many(hidden_dim=16):
     forward_timeout = None
     backward_timeout = None
     detect_anomalies = False
+    allow_zero_outputs = False
     atol = 1e-5
 
     with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=hidden_dim,
@@ -44,7 +62,7 @@ def test_call_many(hidden_dim=16):
 
         mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
             DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []], k_min, backward_k_min, timeout_after_k_min,
-            forward_timeout, backward_timeout, detect_anomalies, e1.info, inputs
+            forward_timeout, backward_timeout, detect_anomalies, allow_zero_outputs, e1.info, inputs
         )
         assert mask.shape == (4, 3)
         assert expert_outputs.shape == (4, 3, hidden_dim)

+ 75 - 3
tests/test_training.py

@@ -1,13 +1,14 @@
+import time
 from functools import partial
 
-import time
 import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from sklearn.datasets import load_digits
 
-from hivemind import RemoteExpert, background_server, DHT, DecentralizedSGD
+from hivemind import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts, background_server, DHT, \
+    DecentralizedSGD
 
 
 @pytest.mark.forked
@@ -22,20 +23,91 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
         expert2 = RemoteExpert('expert.1', server_endpoint)
         model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
 
-        opt = torch.optim.SGD(model.parameters(), lr=0.05)
+        opt = SGD(model.parameters(), lr=0.05)
 
         for step in range(max_steps):
+            outputs = model(X_train)
+            loss = F.cross_entropy(outputs, y_train)
+            loss.backward()
+            opt.step()
             opt.zero_grad()
 
+            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
+            if accuracy >= threshold:
+                break
+
+        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
+
+
+@pytest.mark.forked
+def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=2):
+    dataset = load_digits(n_class=2)
+    X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
+    SGD = partial(torch.optim.SGD, lr=0.05)
+
+    all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
+    with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64, num_handlers=1) \
+            as (server_endpoint, dht_endpoint):
+        dht = DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+
+        moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix='expert.', k_best=2)
+        model = nn.Sequential(moe, nn.Linear(64, 2))
+
+        opt = SGD(model.parameters(), lr=0.05)
+
+        for step in range(max_steps):
             outputs = model(X_train)
             loss = F.cross_entropy(outputs, y_train)
             loss.backward()
             opt.step()
+            opt.zero_grad()
+
+            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
+            if accuracy >= threshold:
+                break
+
+        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
+
+
+class SwitchNetwork(nn.Module):
+    def __init__(self, dht, in_features, num_classes, num_experts):
+        super().__init__()
+        self.moe = RemoteSwitchMixtureOfExperts(in_features=in_features, grid_size=(num_experts,), dht=dht,
+                                                jitter_eps=0, uid_prefix='expert.', k_best=1,
+                                                k_min=1)
+        self.linear = nn.Linear(in_features, num_classes)
+
+    def forward(self, x):
+        moe_output, balancing_loss = self.moe(x)
+        return self.linear(moe_output), balancing_loss
+
+
+@pytest.mark.forked
+def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_experts=5):
+    dataset = load_digits(n_class=2)
+    X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
+    SGD = partial(torch.optim.SGD, lr=0.05)
+
+    all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
+    with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64,
+                           num_handlers=1) as (server_endpoint, dht_endpoint):
+        dht = DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+
+        model = SwitchNetwork(dht, 64, 2, num_experts)
+        opt = SGD(model.parameters(), lr=0.05)
+
+        for step in range(max_steps):
+            outputs, balancing_loss = model(X_train)
+            loss = F.cross_entropy(outputs, y_train) + 0.01 * balancing_loss
+            loss.backward()
+            opt.step()
+            opt.zero_grad()
 
             accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
             if accuracy >= threshold:
                 break
 
+        assert model.moe.grid_utilization.min().item() > (1 / num_experts) / 2
         assert accuracy >= threshold, f"too small accuracy: {accuracy}"