@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import time
-from typing import Tuple, List, Optional, Awaitable, Set, Dict
+from typing import Tuple, List, Optional, Awaitable, Set, Dict, Any
 
 import grpc.experimental.aio
 import torch
@@ -58,7 +58,7 @@ class RemoteMixtureOfExperts(nn.Module):
         self.allow_broadcasting = allow_broadcasting
 
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
-        self._outputs_schema = None  # expert['info'][outputs_schema] from one of experts in the grid
+        self._expert_info = None  # expert['info'] from one of experts in the grid
 
     def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
@@ -88,8 +88,8 @@ class RemoteMixtureOfExperts(nn.Module):
         # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
-            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min,
-            self.forward_timeout, self.backward_timeout, self.loop, *nested_flatten(((input, *args), kwargs)))
+            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
+            self.backward_timeout, self.loop, self.info, *nested_flatten(((input, *args), kwargs)))
         # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
 
         expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
@@ -99,7 +99,7 @@ class RemoteMixtureOfExperts(nn.Module):
         averaged_outputs_flat = [
             (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
             for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
-        return nested_pack(averaged_outputs_flat, self.outputs_schema)
+        return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
 
     async def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[RemoteExpert]:
         """
@@ -139,9 +139,9 @@ class RemoteMixtureOfExperts(nn.Module):
         beam_scores = expanded_scores[tuple(zip(*map(candidate_to_indices.get, beam)))]
         beam_experts = list(best_alive_prefixes.values())
 
-        if self._outputs_schema is None:
+        if self._expert_info is None:
             try:
-                self._outputs_schema = beam_experts[0].info['outputs_schema']
+                self._expert_info = beam_experts[0].info
             except grpc.RpcError as e:
                 logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
 
@@ -182,15 +182,15 @@ class RemoteMixtureOfExperts(nn.Module):
         return scores
 
     @property
-    def outputs_schema(self):
-        if self._outputs_schema is None:
+    def info(self):
+        if self._expert_info is None:
             # grab some expert to set ensemble output shape
             proj_device = self.proj.weight.device
             dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
             dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.grid_size, dim=-1)
             dummy_experts = self.loop.run_until_complete(self.beam_search(dummy_scores, k_best=1))
-            self._outputs_schema = dummy_experts[0].info['outputs_schema']
-        return self._outputs_schema
+            self._expert_info = dummy_experts[0].info
+        return self._expert_info
 
 
 class _RemoteCallMany(torch.autograd.Function):
@@ -206,7 +206,7 @@ class _RemoteCallMany(torch.autograd.Function):
     @classmethod
     def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
                 timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
-                loop: asyncio.base_events.BaseEventLoop, *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+                loop: asyncio.base_events.BaseEventLoop, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
         assert not torch.is_grad_enabled()
         num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
         flat_inputs_per_sample: List[Tuple[torch.Tensor, ...]] = list(zip(*(x.split(1, dim=0) for x in flat_inputs)))
@@ -215,7 +215,7 @@ class _RemoteCallMany(torch.autograd.Function):
         async def _forward():
             # dispatch tasks to all remote experts, await responses
             pending_tasks = {
-                asyncio.create_task(cls._forward_one_expert((i, j), expert, flat_inputs_per_sample[i]))
+                asyncio.create_task(cls._forward_one_expert((i, j), expert, info, flat_inputs_per_sample[i]))
                 for i in range(num_samples) for j, expert in enumerate(experts_per_sample[i])
             }
             alive_grid_indices, alive_flat_outputs = await cls._wait_for_responses(
@@ -239,7 +239,8 @@ class _RemoteCallMany(torch.autograd.Function):
 
             # save individual outputs for backward pass
             ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs)
-            ctx._saved_non_tensors = loop, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample
+            ctx._saved_non_tensors = loop, info, backward_k_min, backward_timeout,\
+                timeout_after_k_min, experts_per_sample
             return (mask,) + tuple(outputs)
 
         return loop.run_until_complete(_forward())
@@ -248,7 +249,7 @@ class _RemoteCallMany(torch.autograd.Function):
     @once_differentiable
     def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
         assert not torch.is_grad_enabled()
-        loop, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
+        loop, info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
         alive_ii, alive_jj, *flat_inputs = ctx.saved_tensors
         dummy_grad_mask, *flat_grad_outputs = raw_grads
         num_samples, max_experts = dummy_grad_mask.shape
@@ -261,8 +262,8 @@ class _RemoteCallMany(torch.autograd.Function):
             pending_tasks = set()
             for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
                                                         inputs_per_expert, grad_outputs_per_expert):
-                pending_tasks.add(asyncio.create_task(
-                    cls._backward_one_expert((i, j), expert_per_sample[i.item()][j.item()], inputs_ij, grad_outputs_ij)))
+                pending_tasks.add(asyncio.create_task(cls._backward_one_expert(
+                    (i, j), expert_per_sample[i.item()][j.item()], info, inputs_ij, grad_outputs_ij)))
 
             backward_survivor_indices, survivor_grad_inputs = await cls._wait_for_responses(
                 pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
@@ -281,28 +282,32 @@ class _RemoteCallMany(torch.autograd.Function):
 
             grad_inputs.append(grad_input_per_expert.sum(dim=1))  # add up gradients from each expert
 
-            return (DUMMY, None, None, None, None, None, None, None, *grad_inputs)
+            return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
 
         return loop.run_until_complete(_backward())
 
     @staticmethod
-    async def _forward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, inputs: Tuple[torch.Tensor]):
+    async def _forward_one_expert(
+            grid_indices: Tuple[int, ...], expert: RemoteExpert, info: Dict[str, Any], inputs: Tuple[torch.Tensor]):
         stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
         try:
             outputs = await stub.forward(runtime_pb2.ExpertRequest(
-                uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
+                uid=expert.uid, tensors=[serialize_torch_tensor(tensor, proto.compression) for tensor, proto in
+                                         zip(inputs, nested_flatten(info['forward_schema']))]))
             return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in outputs.tensors)
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})")
 
     @staticmethod
-    async def _backward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert,
+    async def _backward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, info: Dict[str, Any],
                                    inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor]):
         stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
-        payload = tuple(nested_flatten((inputs, grad_outputs)))
+        inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs)))
+        backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))
         try:
             grad_inputs = await stub.backward(runtime_pb2.ExpertRequest(
-                uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
+                uid=expert.uid, tensors=[serialize_torch_tensor(tensor, proto.compression)
+                                         for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]))
             return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors)
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"RemoteExpert {expert} failed backward: {error.code()} ({inputs}, {grad_outputs})")
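
Note on the serialization change in _forward_one_expert / _backward_one_expert above: each tensor is now zipped with the matching entry of the expert-reported schema (info['forward_schema'] on the forward path, forward_schema followed by outputs_schema on the backward path), so serialize_torch_tensor can apply that entry's compression setting per tensor. Below is a minimal standalone sketch of the zip-with-schema pairing pattern only; TensorSpec, serialize and serialize_with_schema are illustrative stand-ins, not hivemind APIs.

    import torch
    from dataclasses import dataclass
    from typing import Sequence, Tuple


    @dataclass
    class TensorSpec:
        """Stand-in for a schema proto entry; only the compression field matters here."""
        compression: str  # e.g. "NONE" or "FLOAT16"


    def serialize(tensor: torch.Tensor, compression: str) -> bytes:
        # Toy serializer: optionally downcast to fp16 before dumping raw bytes.
        if compression == "FLOAT16":
            tensor = tensor.to(torch.float16)
        return tensor.detach().cpu().numpy().tobytes()


    def serialize_with_schema(tensors: Sequence[torch.Tensor],
                              flat_schema: Sequence[TensorSpec]) -> Tuple[bytes, ...]:
        # Same pairing pattern as the diff: zip each flattened tensor with its
        # schema entry and let that entry pick the per-tensor compression.
        assert len(tensors) == len(flat_schema), "schema must describe every tensor"
        return tuple(serialize(t, spec.compression) for t, spec in zip(tensors, flat_schema))


    if __name__ == "__main__":
        inputs = (torch.randn(2, 3), torch.randn(2, 5))
        schema = (TensorSpec("FLOAT16"), TensorSpec("NONE"))
        payloads = serialize_with_schema(inputs, schema)
        print([len(p) for p in payloads])  # [12, 40]: first tensor shipped at half precision

The upshot is that the compression choice travels with the schema the expert advertises, rather than being hard-coded on the client side.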