Răsfoiți Sursa

vectorize compute_scores better

justheuristic 5 ani în urmă
părinte
comite
57629b7bc8
1 a modificat fișierele cu 17 adăugiri și 11 ștergeri
  1. 17 11
      tesseract/client/moe.py

+ 17 - 11
tesseract/client/moe.py

@@ -61,7 +61,7 @@ class RemoteMixtureOfExperts(nn.Module):
         :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
         :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
         """
-        if self.allow_broadcasting and input.shape != 2:
+        if self.allow_broadcasting and input.ndim != 2:
             # flatten extra dimensions, apply the function and then un-flatten them back to normal like nn.Linear does
             flattened_dims = input.shape[:-1]
             input_flat = input.view(-1, input.shape[-1])
@@ -75,7 +75,7 @@ class RemoteMixtureOfExperts(nn.Module):
         chosen_experts = self.beam_search(grid_scores, self.k_best)
         # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
 
-        expert_logits = self._score_experts(grid_scores, chosen_experts)
+        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
 
         expert_inputs = ((input, *args), kwargs)
         input_schema = nested_map(lambda x: None, expert_inputs)
@@ -146,10 +146,16 @@ class RemoteMixtureOfExperts(nn.Module):
 
         return [[unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid] for row in beam]
 
-    def _score_experts(self, grid_scores: List[torch.Tensor],
-                       experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
-        flat_experts = [expert for row in experts for expert in row]
-        flat_batch_indices = torch.tensor([i for i, row in enumerate(experts) for uid in range(len(row))])
+    def compute_expert_scores(self, grid_scores: List[torch.Tensor],
+                              batch_experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
+        flat_experts = [expert for row in batch_experts for expert in row]
+        expert_counts = list(map(len, batch_experts))
+        max_num_experts = max(expert_counts)
+        total_num_experts = sum(expert_counts)
+        expert_index_in_batch = torch.arange(total_num_experts, device=grid_scores[0].device)
+        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_scores[0].device), dim=-1)[:-1]
+        flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
+        flat_local_indices = expert_index_in_batch - expert_strides[expert_row_ix]
 
         grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
         for i, expert in enumerate(flat_experts):
@@ -162,12 +168,12 @@ class RemoteMixtureOfExperts(nn.Module):
             for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
         flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
 
-        output_dicts = [dict() for _ in range(len(experts))]
-        for batch_i, expert, score in zip(check_numpy(flat_batch_indices),
-                                          flat_experts, flat_scores):
-            output_dicts[batch_i][expert] = score
 
-        return output_dicts
+        batch_size = len(batch_experts)
+        max_num_experts = max(map(len, batch_experts))
+        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device)
+        scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
+        return scores
 
 
 class _RemoteMoECall(torch.autograd.Function):