@@ -148,14 +148,15 @@ class RemoteMixtureOfExperts(nn.Module):
 
     def compute_expert_scores(self, grid_scores: List[torch.Tensor],
                               batch_experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
-        flat_experts = [expert for row in batch_experts for expert in row]
         expert_counts = list(map(len, batch_experts))
+        batch_size = len(batch_experts)
         max_num_experts = max(expert_counts)
         total_num_experts = sum(expert_counts)
         expert_index_in_batch = torch.arange(total_num_experts, device=grid_scores[0].device)
         expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_scores[0].device), dim=-1)[:-1]
         flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
-        flat_local_indices = expert_index_in_batch - expert_strides[expert_row_ix]
+        flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
+        flat_experts = [expert for row in batch_experts for expert in row]
 
         grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
         for i, expert in enumerate(flat_experts):
@@ -169,8 +170,6 @@ class RemoteMixtureOfExperts(nn.Module):
         flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
 
 
-        batch_size = len(batch_experts)
-        max_num_experts = max(map(len, batch_experts))
         scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device)
         scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
         return scores
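
As a quick illustration of the index bookkeeping this patch relies on, the standalone toy sketch below (not code from the repository; the expert_counts values and random scores are made up) shows how expert_strides marks where each sample's experts start in the flattened list, how flat_batch_indices and flat_local_indices map every flattened expert back to its row and slot, and how the final scatter fills a padded [batch_size, max_num_experts] score matrix whose unused slots stay at -inf:

import torch

# Toy stand-in: 3 samples that selected 2, 3, and 1 experts respectively.
expert_counts = [2, 3, 1]
batch_size = len(expert_counts)          # 3
max_num_experts = max(expert_counts)     # 3
total_num_experts = sum(expert_counts)   # 6

# Same arithmetic as in the patch, minus the device handling.
expert_index_in_batch = torch.arange(total_num_experts)
expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts), dim=-1)[:-1]  # tensor([0, 2, 5])
flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]

print(flat_batch_indices.tolist())  # [0, 0, 1, 1, 1, 2] -- which sample each flat expert belongs to
print(flat_local_indices.tolist())  # [0, 1, 0, 1, 2, 0] -- its slot within that sample's row

# Scatter per-expert scores into a padded matrix; -inf marks slots with no
# expert, so they get zero weight after a softmax over each row.
flat_scores = torch.randn(total_num_experts, requires_grad=True)
scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'))
scores[flat_batch_indices, flat_local_indices] = flat_scores
print(scores.requires_grad)  # True: gradients flow back into flat_scores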