@@ -1,14 +1,20 @@
-from functools import partial
-from typing import Tuple, List, Optional
+from __future__ import annotations
+import time
+import asyncio
+from typing import Tuple, List, Optional, Awaitable, Set, Dict

-import numpy as np
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
+import grpc.experimental.aio

-from hivemind.client.expert import RemoteExpert, _RemoteModuleCall, DUMMY
-from hivemind.utils import nested_map, run_and_await_k, nested_pack, nested_flatten, run_in_background, \
-    run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
+import hivemind
+from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
+from hivemind.utils import nested_map, nested_pack, nested_flatten, runtime_grpc, runtime_pb2, \
+    serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)


 class RemoteMixtureOfExperts(nn.Module):
@@ -25,30 +31,31 @@ class RemoteMixtureOfExperts(nn.Module):
     :param uid_prefix: common prefix for all expert uids
         expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
     :param dht: DHT where the experts reside
-    :param num_workers: number of threads for parallel dht operation
     :param k_best: queries this many experts with highest scores
     :param k_min: makes sure at least this many experts returned output
     :param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
         Any expert that didn't manage to return output after that delay is considered unavailable
-    :param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
     :param allow_broadcasting: if RemoteMixtureOfExperts if fed with input dimension above 2,
         allow_broadcasting=True will flatten first d-1 input dimensions, apply RemoteMixtureOfExperts and un-flatten again
         allow_broadcasting=False will raise an error
     """

-    def __init__(self, *, in_features, grid_size: Tuple[int], dht, k_best, k_min=1,
-                 forward_timeout=None, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
-                 uid_prefix='', expert_padding=None, allow_broadcasting=True):
+    def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, k_best: int, k_min: int = 1,
+                 forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
+                 backward_k_min: int = 1, backward_timeout: Optional[float] = None, uid_prefix='',
+                 allow_broadcasting=True, loop: asyncio.BaseEventLoop = None):
         super().__init__()
-        self.dht, self.grid_size = dht, grid_size
-        self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
+        self.dht, self.grid_size, self.uid_prefix = dht, grid_size, uid_prefix
+        self.loop = loop or asyncio.new_event_loop()
+        assert not self.loop.is_running(), "Event loop is already running. If in Jupyter, please apply nest_asyncio " \
+            "(pip install nest_asyncio, https://pypi.org/project/nest-asyncio) and pass loop=asyncio.new_event_loop()"
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
         self.timeout_after_k_min = timeout_after_k_min
         self.allow_broadcasting = allow_broadcasting

         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
-        self._outputs_schema = None
+        self._outputs_schema = None  # expert.info['outputs_schema'] from one of the experts in the grid

     def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
@@ -69,25 +76,31 @@ class RemoteMixtureOfExperts(nn.Module):

         # 1. compute scores and find most appropriate experts with beam search
         grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1)
-        chosen_experts = self.beam_search(grid_scores, self.k_best)
-        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch

-        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
+        async def _search():
+            coroutines = [asyncio.create_task(self.beam_search(
+                [dim_scores[i] for dim_scores in grid_scores], self.k_best))
+                for i in range(len(input))]
+            return list(await asyncio.gather(*coroutines))

-        expert_inputs = ((input, *args), kwargs)
-        input_schema = nested_map(lambda x: None, expert_inputs)
-        flat_inputs_per_expert = tuple(zip(*[tensor.split(1, dim=0) for tensor in nested_flatten(expert_inputs)]))
+        chosen_experts: List[List[RemoteExpert]] = self.loop.run_until_complete(_search())
+        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch

-        batch_jobs_args = tuple(
-            (expert_logits[i, :len(chosen_experts[i])], chosen_experts[i], self.k_min, self.timeout_after_k_min,
-             self.backward_k_min, self.forward_timeout, self.backward_timeout, input_schema, *flat_inputs_per_expert[i])
-            for i in range(len(input))
-        )
+        expert_mask, *expert_outputs = _RemoteCallMany.apply(
+            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min,
+            self.forward_timeout, self.backward_timeout, self.loop, *nested_flatten(((input, *args), kwargs)))
+        # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

-        averaged_outputs_flat = map(torch.cat, zip(*map_with_parallel_backward(_RemoteMoECall, *batch_jobs_args)))
+        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
+        masked_logits = torch.full((1,), float('-inf'), device=expert_logits.device, dtype=expert_logits.dtype)
+        expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
+        expert_weights = torch.softmax(expert_logits, dim=1)
+        averaged_outputs_flat = [
+            (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
+            for tensor in expert_outputs]  # ^-- weight each expert's output by its softmax weight, then sum over experts
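+        # shape check (illustrative): expert_weights is [batch_size, max_experts]; each expert_outputs tensor is
+        # [batch_size, max_experts, *extra_dims], so broadcasting the weights over the flattened trailing dims and
+        # summing over dim=1 leaves one tensor of shape [batch_size, *extra_dims] per output in the schema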
         return nested_pack(averaged_outputs_flat, self.outputs_schema)

-    def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]:
+    async def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[RemoteExpert]:
         """
         Find and return k best experts in the grid using (exact) beam search of the product space

@@ -99,51 +112,39 @@ class RemoteMixtureOfExperts(nn.Module):
         RemoteExpert instances for *up to* k_best experts
         """
         assert len(grid_scores) == len(self.grid_size)
-        assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)
-        batch_size = len(grid_scores[0])
-        beam = np.array([[self.uid_prefix]] * batch_size, dtype=object)  # [batch_size, up_to_beam_size]
-        scores = np.zeros([batch_size, 1], dtype=np.float64)
+        assert all(dim_scores.shape == (self.grid_size[dim_index],) for dim_index, dim_scores in enumerate(grid_scores))
+        grid_scores = [dim_scores.cpu().detach() for dim_scores in grid_scores]

-        delimiters = np.array(self.dht.UID_DELIMITER)[None, None, None]  # pre-compute numpy array for fast concat
+        beam_experts: List[RemoteExpert] = []
+        beam: List[str] = [self.uid_prefix]
+        beam_scores = torch.zeros(1)

         for dim_index, dim_scores in enumerate(grid_scores):
-            dim_scores = dim_scores.detach().cpu().numpy()
-            assert dim_scores.shape[-1] == self.grid_size[dim_index]
-
-            # create all possible successsors from current beam
-            dim_indices = np.arange(dim_scores.shape[1]).astype(str)
-            new_candidates = beam[:, :, None] + delimiters + dim_indices[None, None, :]
-            new_candidates = new_candidates.reshape([batch_size, -1])
+            # create all possible successors from current beam and sort them by total score
+            expanded_scores = beam_scores[:, None] + dim_scores[None, :]
+            sorted_indices = [(flat_i // len(dim_scores), flat_i % len(dim_scores))
+                              for flat_i in (-expanded_scores).flatten().argsort().numpy()]

-            new_scores = scores[:, :, None] + dim_scores[:, None, :]
-            new_scores = new_scores.reshape([batch_size, -1])
+            sorted_candidates = [f"{beam[row]}{self.dht.UID_DELIMITER}{col}" for row, col in sorted_indices]
+            candidate_to_indices = dict(zip(sorted_candidates, sorted_indices))

             # select k best candidates according to scores but only those that are still active
-            new_order = np.argsort(- new_scores, axis=-1)
-            top_alive_lookups = [
-                run_in_background(self.dht.first_k_active, cands[order], k_best, **kwargs)
-                for cands, order in zip(new_candidates, new_order)]
-
-            batch_cand_to_score = [
-                dict(zip(cands, cand_scores)) for cands, cand_scores in zip(new_candidates, new_scores)]
-
-            top_alive_prefixes = [result.result() for result in top_alive_lookups]
-            top_alive_scores = [list(map(cand_to_score.get, top_cands))
-                                for cand_to_score, top_cands in zip(batch_cand_to_score, top_alive_prefixes)]
-
-            # pad up to beam size
-            beam = np.array([row + [self.expert_padding] * (k_best - len(row))
-                             for row in top_alive_prefixes], dtype='object')
-            scores = np.array([row + [-float('inf')] * (k_best - len(row))
-                               for row in top_alive_scores], dtype='float32')
-
-        unique_experts = self.dht.get_experts(list(set(
-            uid for row in beam for uid in row if uid != self.expert_padding)))
+            best_alive_prefixes: Dict[str, RemoteExpert] = await self.dht.first_k_active(
+                uid_prefixes=sorted_candidates, k=k_best, return_future=True, **kwargs)
+            if not best_alive_prefixes:
+                logger.warning(f"Grid is empty: none of the candidates {sorted_candidates} are active")
+                break
+            beam = list(best_alive_prefixes.keys())
+            beam_scores = expanded_scores[tuple(zip(*map(candidate_to_indices.get, beam)))]
+            beam_experts = list(best_alive_prefixes.values())
+
         if self._outputs_schema is None:
-            self._outputs_schema = next(iter(unique_experts)).info['outputs_schema']
-        unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding}
+            try:
+                self._outputs_schema = beam_experts[0].info['outputs_schema']
+            except grpc.RpcError as e:
+                logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")

-        return [[unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid] for row in beam]
+        return beam_experts

     def compute_expert_scores(
             self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
@@ -164,11 +165,11 @@ class RemoteMixtureOfExperts(nn.Module):
         flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
         flat_experts = [expert for row in batch_experts for expert in row]

-        grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
+        grid_indices = torch.zeros([len(flat_experts), len(grid_scores)], dtype=torch.int64)
         for i, expert in enumerate(flat_experts):
             expert_indices = expert.uid[len(self.uid_prefix) + len(self.dht.UID_DELIMITER):]
             expert_indices = list(map(int, expert_indices.split(self.dht.UID_DELIMITER)))
-            grid_indices[i] = expert_indices
+            grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
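+            # e.g. (hypothetical uid) an expert named f'{self.uid_prefix}.3.7' contributes grid_indices[i] = [3, 7]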

         scores_per_dim = [
             dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
@@ -183,86 +184,156 @@ class RemoteMixtureOfExperts(nn.Module):
     def outputs_schema(self):
         if self._outputs_schema is None:
             # grab some expert to set ensemble output shape
-            dummy_scores = self.proj(torch.randn(1, self.proj.in_features)).split_with_sizes(self.grid_size, dim=-1)
-            self._outputs_schema = self.beam_search(dummy_scores, k_best=1)[0][0].info['outputs_schema']
+            dummy_scores = self.proj(torch.randn(self.proj.in_features)).cpu().split_with_sizes(self.grid_size, dim=-1)
+            dummy_experts = self.loop.run_until_complete(self.beam_search(dummy_scores, k_best=1))
+            self._outputs_schema = dummy_experts[0].info['outputs_schema']
         return self._outputs_schema


-class _RemoteMoECall(torch.autograd.Function):
+class _RemoteCallMany(torch.autograd.Function):
     """
-    Internal autograd-friendly function that calls multiple experts on the same input and averages their outputs.
-    This function that can recover from individual failures during forward and/or backward passes.
-    For user-friendly version of this function, use RemoteMixtureOfExperts module.
+    Internal autograd-friendly function that calls multiple experts on a batch of inputs and awaits their responses.
+    This function can recover from individual failures during forward and/or backward pass, as long as at least
+    one expert succeeds for each input. For a user-friendly version of this function, use the RemoteMixtureOfExperts module.
+
+    Note: experts that failed during forward will be assigned zero outputs and marked as mask[i, j] = 0;
+    experts that failed during backward will be treated as constants (i.e. gradients through them are zeros)
     """

     @classmethod
-    def forward(cls, ctx, expert_logits: torch.Tensor, experts: List[RemoteExpert],
-                k_min: int, timeout_after_k_min: float, backward_k_min: int, timeout_total: Optional[float],
-                backward_timeout: Optional[float], input_schema, *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
-        expert_args, expert_kwargs = nested_pack(flat_inputs, structure=input_schema)
-        assert expert_logits.ndim == 1 and len(expert_logits) == len(experts)
-
-        # 1. call experts and await results
-        jobs = [partial(cls._run_expert_forward, expert, *expert_args, **expert_kwargs) for expert in experts]
-        results = run_and_await_k(jobs, k=k_min, timeout_after_k=timeout_after_k_min, timeout_total=timeout_total)
-
-        alive_contexts, alive_outputs, alive_ix = zip(*[(result[0], result[1], ix) for ix, result in enumerate(results)
-                                                        if not isinstance(result, BaseException)])
-        #     ^               ^              ^-- a list of indices of experts that returned outputs in time
-        #     \               \-- list of outputs of every expert that didn't die on us
-        #     \-- a list of autograd contexts, used for parallel backward
-
-        # 2. compute softmax weights for alive experts and average outputs
-        alive_ix = torch.as_tensor(alive_ix, device=expert_logits.device)
-        alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
-
-        stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
-
-        flat_average_outputs = tuple((alive_expert_probs @ stacked_out.flatten(1)).view(*stacked_out.shape[1:])
-                                     for stacked_out in stacked_alive_outputs)
-
-        # 3. save individual outputs for backward pass
-        ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
-        ctx._saved_non_tensors = alive_contexts, backward_k_min, backward_timeout
-        return tuple(map(torch.Tensor.detach, flat_average_outputs))
+    def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
+                timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
+                loop: asyncio.base_events.BaseEventLoop, *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+        assert not torch.is_grad_enabled()
+        num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
+        flat_inputs_per_sample: List[Tuple[torch.Tensor, ...]] = list(zip(*(x.split(1, dim=0) for x in flat_inputs)))
+        assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
+
+        async def _forward():
+            # dispatch tasks to all remote experts, await responses
+            pending_tasks = {
+                asyncio.create_task(cls._forward_one_expert((i, j), expert, flat_inputs_per_sample[i]))
+                for i in range(num_samples) for j, expert in enumerate(experts_per_sample[i])
+            }
+            alive_grid_indices, alive_flat_outputs = await cls._wait_for_responses(
+                pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min)
+
+            # assemble responses
+            alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
+            mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
+            mask[alive_ii, alive_jj] = True
+
+            alive_flat_outputs_stacked = list(map(torch.cat, zip(*alive_flat_outputs)))
+            # list of torch tensors, where i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
+
+            outputs = []
+            for response_stacked in alive_flat_outputs_stacked:
+                output = torch.zeros(
+                    [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
+                    dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
+                output[alive_ii, alive_jj] = response_stacked
+                outputs.append(output)

+            # save expert indices and input tensors for the backward pass
+            ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs)
+            ctx._saved_non_tensors = loop, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample
+            return (mask,) + tuple(outputs)
+
+        return loop.run_until_complete(_forward())

     @classmethod
     @once_differentiable
-    def backward(cls, ctx, *grad_outputs_flat: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
-        """ Like normal backward, but we ignore any experts that failed during backward pass """
-        expert_logits, alive_ix, alive_expert_probas, *stacked_alive_outputs = ctx.saved_tensors
-        alive_contexts, backward_k_min, backward_timeout = ctx._saved_non_tensors
-
-        jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
-                for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
-        results = run_and_await_k(jobs, k=backward_k_min, timeout_after_k=backward_timeout, timeout_total=None)
-        backward_survivors_in_alive_ix, survived_grad_inputs = zip(*((i, grads) for i, grads in enumerate(results)))
-        backward_survivors_in_alive_ix = torch.as_tensor(backward_survivors_in_alive_ix, device=expert_logits.device)
-        backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
-        survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
-        weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
-        flat_grad_inputs = tuple((weight_ratios @ stacked_grad_inp.flatten(1)).view(stacked_grad_inp.shape[1:])
-                                 for stacked_grad_inp in map(torch.stack, zip(*survived_grad_inputs)))
-
-        # compute grad w.r.t. logits
-        grad_wrt_probs = sum(tuple(
-            torch.sum(grad_out[None, ...] * stacked_avive_out[backward_survivors_in_alive_ix],
-                      dim=tuple(range(1, stacked_avive_out.ndim)))
-            for grad_out, stacked_avive_out in zip(grad_outputs_flat, stacked_alive_outputs)
-        ))
-        softmax_jacobian = torch.diagflat(survived_probas) - torch.ger(survived_probas, survived_probas)
-        grad_wrt_survived_logits = grad_wrt_probs @ softmax_jacobian
-        grad_wrt_logits = torch.zeros_like(expert_logits).scatter(0, backward_survivors_ix, grad_wrt_survived_logits)
-
-        return (grad_wrt_logits, None, None, None, None, None, None, None, *flat_grad_inputs)
+    def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
+        assert not torch.is_grad_enabled()
+        loop, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
+        alive_ii, alive_jj, *flat_inputs = ctx.saved_tensors
+        dummy_grad_mask, *flat_grad_outputs = raw_grads
+        num_samples, max_experts = dummy_grad_mask.shape
+
+        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs))
+        grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs))
+
+        async def _backward():
+            # dispatch tasks to all remote experts, await responses
+            pending_tasks = set()
+            for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
+                                                        inputs_per_expert, grad_outputs_per_expert):
+                pending_tasks.add(asyncio.create_task(
+                    cls._backward_one_expert((i, j), expert_per_sample[i.item()][j.item()], inputs_ij, grad_outputs_ij)
+                ))
+
+            backward_survivor_indices, survivor_grad_inputs = await cls._wait_for_responses(
+                pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
+
+            # assemble responses
+            backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices))
+            survivor_grad_inputs_stacked = list(map(torch.cat, zip(*survivor_grad_inputs)))
+            # list of torch tensors, where i-th tensor is of shape [num_backward_survivors, *flat_inputs[i].shape]
+
+            grad_inputs = []
+            for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
+                grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
+                    (num_samples, max_experts, *flat_inputs[i].shape[1:]),
+                    device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
+                grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked
+
+                grad_inputs.append(grad_input_per_expert.sum(dim=1))  # add up gradients from each expert
+
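+            # the returned tuple mirrors forward()'s signature: a gradient for `dummy`, one None for each of the
+            # seven non-tensor arguments (experts_per_sample, k_min, backward_k_min, timeout_after_k_min,
+            # forward_timeout, backward_timeout, loop), then one gradient per tensor in *flat_inputs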
+            return (DUMMY, None, None, None, None, None, None, None, *grad_inputs)
+        return loop.run_until_complete(_backward())
+
+    @staticmethod
+    async def _forward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, inputs: Tuple[torch.Tensor]):
+        stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
+        try:
+            outputs = await stub.forward(runtime_pb2.ExpertRequest(
+                uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
+            return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in outputs.tensors)
+        except grpc.experimental.aio.AioRpcError as error:
+            logger.warning(f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})")
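+        # on failure the coroutine implicitly returns None; _wait_for_responses() below skips such results
+        # via its `if not task.result()` check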

     @staticmethod
-    def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
-        """ Call remote expert and return flattened outputs. Compatible with concurrent autograd. """
-        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.stub, *nested_flatten((args, kwargs)))
+    async def _backward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert,
+                                   inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor]):
+        stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
+        payload = tuple(nested_flatten((inputs, grad_outputs)))
+        try:
+            grad_inputs = await stub.backward(runtime_pb2.ExpertRequest(
+                uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
+            return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors)
+        except grpc.experimental.aio.AioRpcError as error:
+            logger.warning(f"RemoteExpert {expert} failed backward: {error.code()} ({inputs}, {grad_outputs})")

     @staticmethod
-    def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
-        backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
-        grad_dummy, no_grad_uid, no_grad_stub, *grad_inputs = backward_result
-        return grad_inputs
+    async def _wait_for_responses(
+            pending_tasks: Set[Awaitable[Tuple[Tuple[int, int], Tuple[torch.Tensor, ...]]]],
+            num_samples: int, k_min: int, timeout_total: Optional[float], timeout_after_k_min: Optional[float]
+    ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
+        """ await at least k_min results per sample and any straggler that arrives within timeout_after_k_min; cancel the rest """
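+        # e.g. (hypothetical values) with k_min=2 and timeout_after_k_min=0.5: once every sample has two successful
+        # responses, the deadline shrinks to at most 0.5 s from now; whatever is still pending then gets cancelled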
+        timeout_total = float('inf') if timeout_total is None else timeout_total
+        timeout_after_k_min = float('inf') if timeout_after_k_min is None else timeout_after_k_min
+        num_successful_tasks = [0 for _ in range(num_samples)]
+        pending_samples = num_samples  # samples for which we have less than k_min results
+        finished_indices, finished_outputs = [], []
+        t_finish = time.perf_counter() + timeout_total
+
+        while pending_tasks and time.perf_counter() <= t_finish:
+            finished_tasks, pending_tasks = await asyncio.wait(pending_tasks, return_when=asyncio.FIRST_COMPLETED,
+                                                               timeout=t_finish - time.perf_counter())
+            for task in finished_tasks:
+                if not task.result():
+                    continue
+                task_indices, task_flat_outputs = await task
+                finished_indices.append(task_indices)
+                finished_outputs.append(task_flat_outputs)
+
+                sample_index = task_indices[0]
+                num_successful_tasks[sample_index] += 1
+                if num_successful_tasks[sample_index] == k_min:
+                    pending_samples -= 1
+                    if pending_samples <= 0:  # every sample has at least k_min results; await stragglers for at most timeout_after_k_min
+                        t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
+
+        for task in pending_tasks:
+            task.cancel()
+        return finished_indices, finished_outputs