|
@@ -81,8 +81,11 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
|
|
|
if self._expert_info is None:
|
|
|
try:
|
|
|
self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
|
|
|
+ except StopIteration:
|
|
|
+ raise RuntimeError("No responding experts found during beam search. Check that UID prefixes and "
|
|
|
+ "the grid size are consistent with running Server instances.")
|
|
|
except grpc.RpcError as e:
|
|
|
- logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
|
|
|
+ logger.warning(f"Failed to get RemoteSwitchMixtureOfExperts.output_shape: {e}")
|
|
|
|
|
|
expert_mask, *expert_outputs = _RemoteCallMany.apply(
|
|
|
DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
|
|
@@ -96,8 +99,8 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
|
|
|
|
|
|
# compute expert probabilities as product across grid dimensions
|
|
|
expert_probs = self.compute_expert_scores(grid_softmax, chosen_experts)
|
|
|
- masked_logits = torch.full((1,), float('-inf'), device=expert_probs.device, dtype=expert_probs.dtype)
|
|
|
- expert_probs = torch.where(expert_mask, expert_probs, masked_logits)
|
|
|
+ masked_probs = torch.zeros((1,), device=expert_probs.device, dtype=expert_probs.dtype)
|
|
|
+ expert_probs = torch.where(expert_mask, expert_probs, masked_probs)
|
|
|
|
|
|
# multiply outputs by expert probabilities
|
|
|
averaged_outputs_flat = [
|
|
@@ -127,7 +130,7 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
|
|
|
for dim_size in self.beam_search.grid_size]
|
|
|
|
|
|
# out of chosen_experts, select those for which expert_mask is True
|
|
|
- for (sample_idx, expert_idx) in expert_mask.nonzero().numpy():
|
|
|
+ for (sample_idx, expert_idx) in expert_mask.nonzero().cpu().numpy():
|
|
|
expert = batch_experts[sample_idx][expert_idx]
|
|
|
expert_indices = expert.uid[len(self.beam_search.uid_prefix):]
|
|
|
expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
|