|
@@ -47,8 +47,10 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
super().__init__()
|
|
|
self.dht, self.grid_size, self.uid_prefix = dht, grid_size, uid_prefix
|
|
|
self.loop = loop or asyncio.new_event_loop()
|
|
|
+ # fmt:off
|
|
|
assert not self.loop.is_running(), "Event loop is already running. If in jupyter, please apply nest_asyncio " \
|
|
|
"(pip install nest_asyncio , https://pypi.org/project/nest-asyncio ) and send loop=asyncio.new_event_loop()"
|
|
|
+ # fmt:on
|
|
|
self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
|
|
|
self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
|
|
|
self.timeout_after_k_min = timeout_after_k_min
|
|
@@ -59,27 +61,25 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
|
|
|
def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
|
|
|
"""
|
|
|
- Choose k best experts with beam search, then call chosen experts and average their outputs.
|
|
|
- :param input: a tensor of values that are used to estimate gating function, batch-first
|
|
|
+ Choose k best experts with beam search, then call chosen experts and average their outputs. Input tensor is averaged over all
|
|
|
+ dimensions except first and last (we assume that extra dimensions represent sequence length or image dimensions)
|
|
|
+
|
|
|
+ :param input: a tensor of values that are used to estimate gating function, batch-first.
|
|
|
:param args: extra positional parameters that will be passed to each expert after input, batch-first
|
|
|
:param kwargs: extra keyword parameters that will be passed to each expert, batch-first
|
|
|
:returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
|
|
|
"""
|
|
|
- if self.allow_broadcasting and input.ndim != 2:
|
|
|
- # flatten extra dimensions, apply the function and then un-flatten them back to normal like nn.Linear does
|
|
|
- flattened_dims = input.shape[:-1]
|
|
|
- input_flat = input.view(-1, input.shape[-1])
|
|
|
- args_flat = [tensor.view(-1, tensor.shape[len(flattened_dims):]) for tensor in args]
|
|
|
- kwargs_flat = {key: tensor.view(-1, tensor.shape[len(flattened_dims):]) for key, tensor in kwargs.items()}
|
|
|
- out_flat = self.forward(input_flat, *args_flat, **kwargs_flat)
|
|
|
- return nested_map(lambda tensor: tensor.view(flattened_dims, tensor.shape[len(flattened_dims):]), out_flat)
|
|
|
+ if input.ndim != 2:
|
|
|
+ input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
|
|
|
+ else:
|
|
|
+ input_for_gating = input
|
|
|
|
|
|
# 1. compute scores and find most appropriate experts with beam search
|
|
|
- grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1)
|
|
|
+ grid_scores = self.proj(input_for_gating).split_with_sizes(self.grid_size, dim=-1)
|
|
|
|
|
|
async def _search():
|
|
|
coroutines = [asyncio.create_task(self.beam_search(
|
|
|
- [dim_scores[i] for dim_scores in grid_scores], self.k_best))
|
|
|
+ [dim_scores[i] for dim_scores in grid_scores], self.k_best), name=f'beam_search_{i}')
|
|
|
for i in range(len(input))]
|
|
|
return list(await asyncio.gather(*coroutines))
|
|
|
|
|
@@ -184,7 +184,9 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
def outputs_schema(self):
|
|
|
if self._outputs_schema is None:
|
|
|
# grab some expert to set ensemble output shape
|
|
|
- dummy_scores = self.proj(torch.randn(self.proj.in_features)).cpu().split_with_sizes(self.grid_size, dim=-1)
|
|
|
+ proj_device = self.proj.weight.device
|
|
|
+ dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
|
|
|
+ dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.grid_size, dim=-1)
|
|
|
dummy_experts = self.loop.run_until_complete(self.beam_search(dummy_scores, k_best=1))
|
|
|
self._outputs_schema = dummy_experts[0].info['outputs_schema']
|
|
|
return self._outputs_schema
|
|
@@ -212,7 +214,7 @@ class _RemoteCallMany(torch.autograd.Function):
|
|
|
async def _forward():
|
|
|
# dispatch tasks to all remote experts, await responses
|
|
|
pending_tasks = {
|
|
|
- asyncio.create_task(cls._forward_one_expert((i, j), expert, flat_inputs_per_sample[i]))
|
|
|
+ asyncio.create_task(cls._forward_one_expert((i, j), expert, flat_inputs_per_sample[i]), name=f'forward_expert_{j}_for_{i}')
|
|
|
for i in range(num_samples) for j, expert in enumerate(experts_per_sample[i])
|
|
|
}
|
|
|
alive_grid_indices, alive_flat_outputs = await cls._wait_for_responses(
|
|
@@ -259,8 +261,8 @@ class _RemoteCallMany(torch.autograd.Function):
|
|
|
for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
|
|
|
inputs_per_expert, grad_outputs_per_expert):
|
|
|
pending_tasks.add(asyncio.create_task(
|
|
|
- cls._backward_one_expert((i, j), expert_per_sample[i.item()][j.item()], inputs_ij, grad_outputs_ij)
|
|
|
- ))
|
|
|
+ cls._backward_one_expert((i, j), expert_per_sample[i.item()][j.item()], inputs_ij, grad_outputs_ij),
|
|
|
+ name=f'backward_expert_{j}_for_{i}'))
|
|
|
|
|
|
backward_survivor_indices, survivor_grad_inputs = await cls._wait_for_responses(
|
|
|
pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
|
|
@@ -280,6 +282,7 @@ class _RemoteCallMany(torch.autograd.Function):
|
|
|
grad_inputs.append(grad_input_per_expert.sum(dim=1)) # add up gradients from each expert
|
|
|
|
|
|
return (DUMMY, None, None, None, None, None, None, None, *grad_inputs)
|
|
|
+
|
|
|
return loop.run_until_complete(_backward())
|
|
|
|
|
|
@staticmethod
|
|
@@ -308,7 +311,7 @@ class _RemoteCallMany(torch.autograd.Function):
|
|
|
async def _wait_for_responses(
|
|
|
pending_tasks: Set[Awaitable[Tuple[Tuple[int, int], Tuple[torch.Tensor, ...]]]],
|
|
|
num_samples: int, k_min: int, timeout_total: Optional[float], timeout_after_k_min: Optional[float]
|
|
|
- ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
|
|
|
+ ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
|
|
|
""" await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers """
|
|
|
timeout_total = float('inf') if timeout_total is None else timeout_total
|
|
|
timeout_after_k_min = float('inf') if timeout_after_k_min is None else timeout_after_k_min
|