|
@@ -61,7 +61,7 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
:param kwargs: extra keyword parameters that will be passed to each expert, batch-first
|
|
|
:returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
|
|
|
"""
|
|
|
- if self.allow_broadcasting and input.shape != 2:
|
|
|
+ if self.allow_broadcasting and input.ndim != 2:
|
|
|
# flatten extra dimensions, apply the function and then un-flatten them back to normal like nn.Linear does
|
|
|
flattened_dims = input.shape[:-1]
|
|
|
input_flat = input.view(-1, input.shape[-1])
|
|
@@ -75,7 +75,7 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
chosen_experts = self.beam_search(grid_scores, self.k_best)
|
|
|
# ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
|
|
|
|
|
|
- expert_logits = self._score_experts(grid_scores, chosen_experts)
|
|
|
+ expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
|
|
|
|
|
|
expert_inputs = ((input, *args), kwargs)
|
|
|
input_schema = nested_map(lambda x: None, expert_inputs)
|
|
@@ -146,10 +146,16 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
|
|
|
return [[unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid] for row in beam]
|
|
|
|
|
|
- def _score_experts(self, grid_scores: List[torch.Tensor],
|
|
|
- experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
|
|
|
- flat_experts = [expert for row in experts for expert in row]
|
|
|
- flat_batch_indices = torch.tensor([i for i, row in enumerate(experts) for uid in range(len(row))])
|
|
|
+ def compute_expert_scores(self, grid_scores: List[torch.Tensor],
|
|
|
+ batch_experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
|
|
|
+ flat_experts = [expert for row in batch_experts for expert in row]
|
|
|
+ expert_counts = list(map(len, batch_experts))
|
|
|
+ max_num_experts = max(expert_counts)
|
|
|
+ total_num_experts = sum(expert_counts)
|
|
|
+ expert_index_in_batch = torch.arange(total_num_experts, device=grid_scores[0].device)
|
|
|
+ expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_scores[0].device), dim=-1)[:-1]
|
|
|
+ flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
|
|
|
+ flat_local_indices = expert_index_in_batch - expert_strides[expert_row_ix]
|
|
|
|
|
|
grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
|
|
|
for i, expert in enumerate(flat_experts):
|
|
@@ -162,12 +168,12 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
|
|
|
flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
|
|
|
|
|
|
- output_dicts = [dict() for _ in range(len(experts))]
|
|
|
- for batch_i, expert, score in zip(check_numpy(flat_batch_indices),
|
|
|
- flat_experts, flat_scores):
|
|
|
- output_dicts[batch_i][expert] = score
|
|
|
|
|
|
- return output_dicts
|
|
|
+ batch_size = len(batch_experts)
|
|
|
+ max_num_experts = max(map(len, batch_experts))
|
|
|
+ scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device)
|
|
|
+ scores[flat_batch_indices, flat_local_indices] = flat_scores # backprop-able w.r.t. flat_scores
|
|
|
+ return scores
|
|
|
|
|
|
|
|
|
class _RemoteMoECall(torch.autograd.Function):
|