|
@@ -80,7 +80,7 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
|
|
|
async def _search():
|
|
|
coroutines = [asyncio.create_task(self.beam_search(
|
|
|
- [dim_scores[i] for dim_scores in grid_scores], self.k_best), name=f'beam_search_{i}')
|
|
|
+ [dim_scores[i] for dim_scores in grid_scores], self.k_best))
|
|
|
for i in range(len(input))]
|
|
|
return list(await asyncio.gather(*coroutines))
|
|
|
|
|
@@ -215,7 +215,7 @@ class _RemoteCallMany(torch.autograd.Function):
|
|
|
async def _forward():
|
|
|
# dispatch tasks to all remote experts, await responses
|
|
|
pending_tasks = {
|
|
|
- asyncio.create_task(cls._forward_one_expert((i, j), expert, flat_inputs_per_sample[i]), name=f'forward_expert_{j}_for_{i}')
|
|
|
+ asyncio.create_task(cls._forward_one_expert((i, j), expert, flat_inputs_per_sample[i]))
|
|
|
for i in range(num_samples) for j, expert in enumerate(experts_per_sample[i])
|
|
|
}
|
|
|
alive_grid_indices, alive_flat_outputs = await cls._wait_for_responses(
|
|
@@ -262,8 +262,7 @@ class _RemoteCallMany(torch.autograd.Function):
|
|
|
for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
|
|
|
inputs_per_expert, grad_outputs_per_expert):
|
|
|
pending_tasks.add(asyncio.create_task(
|
|
|
- cls._backward_one_expert((i, j), expert_per_sample[i.item()][j.item()], inputs_ij, grad_outputs_ij),
|
|
|
- name=f'backward_expert_{j}_for_{i}'))
|
|
|
+ cls._backward_one_expert((i, j), expert_per_sample[i.item()][j.item()], inputs_ij, grad_outputs_ij)))
|
|
|
|
|
|
backward_survivor_indices, survivor_grad_inputs = await cls._wait_for_responses(
|
|
|
pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
|