|
@@ -147,7 +147,7 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
|
|
|
def compute_expert_scores(
|
|
|
self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
|
|
|
- """ TODO docstring here """
|
|
|
+ """ TODO(jheuristic) docstring here """
|
|
|
expert_counts = list(map(len, batch_experts))
|
|
|
batch_size = len(batch_experts)
|
|
|
max_num_experts = max(expert_counts)
|