瀏覽代碼

Fix incorrect data types/values in RemoteSwitchMixtureOfExperts (#246)

* Fix incorrect dtypes/values in RemoteSwitchMixtureOfExperts

* Raise RuntimeError when no experts are found

* Reorder exception handling
Max Ryabinin 4 年之前
父節點
當前提交
0a1fdb172f
共有 2 個文件被更改,包括 10 次插入4 次删除
  1. 3 0
      hivemind/client/moe.py
  2. 7 4
      hivemind/client/switch_moe.py

+ 3 - 0
hivemind/client/moe.py

@@ -85,6 +85,9 @@ class RemoteMixtureOfExperts(nn.Module):
         if self._expert_info is None:
             try:
                 self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
+            except StopIteration:
+                raise RuntimeError("No responding experts found during beam search. Check that UID prefixes and "
+                                   "the grid size are consistent with running Server instances.")
             except grpc.RpcError as e:
                 logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
 

+ 7 - 4
hivemind/client/switch_moe.py

@@ -81,8 +81,11 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
         if self._expert_info is None:
             try:
                 self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
+            except StopIteration:
+                raise RuntimeError("No responding experts found during beam search. Check that UID prefixes and "
+                                   "the grid size are consistent with running Server instances.")
             except grpc.RpcError as e:
-                logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
+                logger.warning(f"Failed to get RemoteSwitchMixtureOfExperts.output_shape: {e}")
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
             DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
@@ -96,8 +99,8 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
 
         # compute expert probabilities as product across grid dimensions
         expert_probs = self.compute_expert_scores(grid_softmax, chosen_experts)
-        masked_logits = torch.full((1,), float('-inf'), device=expert_probs.device, dtype=expert_probs.dtype)
-        expert_probs = torch.where(expert_mask, expert_probs, masked_logits)
+        masked_probs = torch.zeros((1,), device=expert_probs.device, dtype=expert_probs.dtype)
+        expert_probs = torch.where(expert_mask, expert_probs, masked_probs)
 
         # multiply outputs by expert probabilities
         averaged_outputs_flat = [
@@ -127,7 +130,7 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
                              for dim_size in self.beam_search.grid_size]
 
         # out of chosen_experts, select those for which expert_mask is True
-        for (sample_idx, expert_idx) in expert_mask.nonzero().numpy():
+        for (sample_idx, expert_idx) in expert_mask.nonzero().cpu().numpy():
             expert = batch_experts[sample_idx][expert_idx]
             expert_indices = expert.uid[len(self.beam_search.uid_prefix):]
             expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))