소스 검색

vectorize compute_scores better

justheuristic 5 년 전
부모
커밋
b4ee0ec5ab
1개의 변경된 파일, 3개의 추가 및 4개의 삭제
  1. 3 4
      tesseract/client/moe.py

+ 3 - 4
tesseract/client/moe.py

@@ -148,14 +148,15 @@ class RemoteMixtureOfExperts(nn.Module):
 
     def compute_expert_scores(self, grid_scores: List[torch.Tensor],
                               batch_experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
-        flat_experts = [expert for row in batch_experts for expert in row]
         expert_counts = list(map(len, batch_experts))
+        batch_size = len(batch_experts)
         max_num_experts = max(expert_counts)
         total_num_experts = sum(expert_counts)
         expert_index_in_batch = torch.arange(total_num_experts, device=grid_scores[0].device)
         expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_scores[0].device), dim=-1)[:-1]
         flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
-        flat_local_indices = expert_index_in_batch - expert_strides[expert_row_ix]
+        flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
+        flat_experts = [expert for row in batch_experts for expert in row]
 
         grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
         for i, expert in enumerate(flat_experts):
@@ -169,8 +170,6 @@ class RemoteMixtureOfExperts(nn.Module):
         flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
 
 
-        batch_size = len(batch_experts)
-        max_num_experts = max(map(len, batch_experts))
         scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device)
         scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
         return scores