from __future__ import annotations

import time
from concurrent.futures import Future
from queue import Empty, Queue
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

from hivemind.compression import serialize_torch_tensor
from hivemind.dht import DHT
from hivemind.moe.client.beam_search import MoEBeamSearcher
from hivemind.moe.client.expert import DUMMY, RemoteExpert, expert_backward, expert_forward, get_server_stub
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import UID_DELIMITER
from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
from hivemind.utils import nested_flatten, nested_map, nested_pack
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


class RemoteMixtureOfExperts(nn.Module):
    """
    A torch module that performs Mixture-of-Experts inference with a local gating function and multiple remote
    experts. Natively supports pytorch autograd.

    :note: By default, not all experts are guaranteed to perform a forward pass. Moreover, not all of those that ran
     the forward pass are guaranteed to perform a backward pass. In the latter case, the gradient will be averaged
     without the missing experts.

    :param in_features: common input size for the experts and the gating function
    :param grid_size: dimensions that form expert uid (see below)
    :param uid_prefix: common prefix for all expert uids (must end with '.')
    :note: expert uid follows the pattern {uid_prefix}{0...grid_size[0]}.{0...grid_size[1]}. ... .{0...grid_size[-1]}
    :param dht: a DHT instance used to search for best experts
    :param k_best: query this many highest-scoring experts and average their outputs to compute activations
    :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
    :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
     Any expert that didn't manage to return output after that delay is considered unavailable
    :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
    :param allow_zero_outputs: whether to return zeros if no experts respond on forward pass
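
    Example: a minimal usage sketch. The peer addresses, grid size and feature dimension below are placeholders;
    the sketch assumes a reachable hivemind DHT whose servers host experts named ``expert.0.0`` ... ``expert.15.15``
    that accept 512-dimensional inputs::

        dht = DHT(initial_peers=[...], start=True)  # fill in multiaddrs of existing DHT peers
        moe = RemoteMixtureOfExperts(
            in_features=512, grid_size=(16, 16), dht=dht, uid_prefix="expert.", k_best=4
        )
        prediction = moe(torch.randn(32, 512))  # calls the chosen experts, averages outputs of those that respond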
- """

    def __init__(
        self,
        *,
        in_features,
        grid_size: Tuple[int, ...],
        dht: DHT,
        uid_prefix: str,
        k_best: int,
        k_min: int = 1,
        forward_timeout: Optional[float] = None,
        timeout_after_k_min: Optional[float] = None,
        backward_k_min: int = 1,
        backward_timeout: Optional[float] = None,
        detect_anomalies: bool = False,
        allow_zero_outputs: bool = False,
        **dht_kwargs,
    ):
        super().__init__()
        self.dht = dht
        self.beam_search = MoEBeamSearcher(dht, uid_prefix, grid_size, **dht_kwargs)
        self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
        self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
        self.timeout_after_k_min = timeout_after_k_min
        self.detect_anomalies = detect_anomalies
        self.allow_zero_outputs = allow_zero_outputs

        # jointly predict logits for all grid dimensions
        self.proj = nn.Linear(in_features, self.beam_search.total_grid_size)
        self._expert_info = None  # expert['info'] from one of experts in the grid

    def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
        """
        Choose k best experts with beam search, then call the chosen experts and average their outputs.
        The input tensor is averaged over all dimensions except for the first and last
        (we assume that extra dimensions represent sequence length or image height/width).

        :param input: a tensor of values that are used to estimate the gating function, batch-first
        :param args: extra positional parameters that will be passed to each expert after input, batch-first
        :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
        :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
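
        Example of the gating-input reduction (shapes are illustrative): for a sequence batch, the gating function
        only sees the per-sample mean over the intermediate dimensions, while the experts receive the full tensor::

            input = torch.randn(4, 100, 512)                                     # [batch_size, seq_len, in_features]
            input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))   # -> [4, 512]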
- """
        if input.ndim != 2:
            input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
        else:
            input_for_gating = input

        # 1. compute scores and find most appropriate experts with beam search
        grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)

        chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
            [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best
        )

        if self._expert_info is None:
            try:
                self._expert_info = next(expert.info for experts_i in chosen_experts for expert in experts_i)
            except StopIteration:
                raise RuntimeError(
                    "No responding experts found during beam search. Check that UID prefixes and "
                    "the grid size are consistent with running Server instances."
                )
            except P2PDaemonError as e:
                logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")

        expert_mask, *expert_outputs = _RemoteCallMany.apply(
            DUMMY,
            chosen_experts,
            self.k_min,
            self.backward_k_min,
            self.timeout_after_k_min,
            self.forward_timeout,
            self.backward_timeout,
            self.detect_anomalies,
            self.allow_zero_outputs,
            self.info,
            *nested_flatten(((input, *args), kwargs)),
        )
        # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
        masked_logits = torch.full((1,), float("-inf"), device=expert_logits.device, dtype=expert_logits.dtype)
        expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
        expert_weights = torch.softmax(expert_logits, dim=1)

        averaged_outputs_flat = [
            (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
            for tensor in expert_outputs
        ]  # ^-- multiply by softmax weights along first 2 axes
        return nested_pack(averaged_outputs_flat, self.info["outputs_schema"])

    def compute_expert_scores(
        self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]
    ) -> torch.Tensor:
        """
        Compute scores for each expert by adding up grid scores, autograd-friendly.

        :param grid_scores: list of torch tensors, i-th tensor contains scores for i-th grid dimension
        :param batch_experts: list(batch) of lists(k) of up to k experts selected for this batch
        :returns: a tensor of scores, float32[batch_size, k]
        :note: if some rows in the batch have fewer than the max number of experts, their scores are padded with -inf
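
        Example (a worked toy case): with ``uid_prefix="expert."`` and two grid dimensions, the expert
        ``"expert.3.7"`` maps to grid indices ``(3, 7)``, so its score for sample ``i`` is
        ``grid_scores[0][i, 3] + grid_scores[1][i, 7]``.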
- """
        expert_counts = list(map(len, batch_experts))
        batch_size = len(batch_experts)
        max_num_experts = max(expert_counts)
        total_num_experts = sum(expert_counts)
        device = grid_scores[0].device

        expert_index_in_batch = torch.arange(total_num_experts, device=device)
        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
        flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
        flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
        flat_experts = [expert for row in batch_experts for expert in row]

        grid_indices = torch.zeros([len(flat_experts), len(grid_scores)], dtype=torch.int64)
        for i, expert in enumerate(flat_experts):
            expert_indices = expert.uid[len(self.beam_search.uid_prefix) :]
            expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
            grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)

        scores_per_dim = [
            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
            for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)
        ]
        flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)

        scores = torch.full((batch_size, max_num_experts), fill_value=-float("inf"), device=device)
        scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
        return scores

    @property
    def info(self):
        if self._expert_info is None:
            # grab some expert to set ensemble output shape
            proj_device = self.proj.weight.device
            dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
            dummy_scores = dummy_scores_concat.cpu().detach().split_with_sizes(self.beam_search.grid_size, dim=-1)
            dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
            self._expert_info = dummy_experts[0].info
        return self._expert_info


class _RemoteCallMany(torch.autograd.Function):
    """
    Internal autograd-friendly function that calls multiple experts on a batch of inputs and awaits responses.
    This function can recover from individual failures during the forward and/or backward pass as long as at least
    one expert succeeds for each input. For a user-friendly version of this function, use the RemoteMixtureOfExperts
    module.

    Note: experts that failed during forward will be assigned zero outputs and marked as mask[i, j] = 0,
    experts that failed during backward will be treated as constants (i.e. gradients through them are zeros)
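
    Example of the failure semantics (illustrative): if sample ``i`` was dispatched to experts ``j = 0, 1, 2`` and
    expert ``1`` timed out during forward, then ``mask[i] == [True, False, True]`` and every returned output tensor
    has zeros at ``[i, 1]``; RemoteMixtureOfExperts then masks the corresponding logit with ``-inf`` before the
    softmax, so the failed expert does not contribute to the average.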
- """

    @classmethod
    def forward(
        cls,
        ctx,
        dummy,
        experts_per_sample: List[List[RemoteExpert]],
        k_min: int,
        backward_k_min: int,
        timeout_after_k_min: float,
        forward_timeout: Optional[float],
        backward_timeout: Optional[float],
        detect_anomalies: bool,
        allow_zero_outputs: bool,
        info: Dict[str, Any],
        *flat_inputs: torch.Tensor,
    ) -> Tuple[torch.Tensor]:
        assert not torch.is_grad_enabled()
        num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))

        flat_inputs_cpu = []
        for tensor in flat_inputs:
            if detect_anomalies and not tensor.isfinite().all():
                raise ValueError("One of inputs has nan/inf values")
            flat_inputs_cpu.append(tensor.cpu())

        flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
        assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples

        # dispatch tasks to all remote experts, collect responses
        pending_tasks: Dict[Future, Tuple[int, int]] = {}
        for i in range(num_samples):
            for j, expert in enumerate(experts_per_sample[i]):
                stub = get_server_stub(expert.p2p, expert.peer_id)
                serialized_tensors = (
                    serialize_torch_tensor(tensor, proto.compression)
                    for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
                )
                new_task = RemoteExpertWorker.run_coroutine(
                    expert_forward(expert.uid, flat_inputs_per_sample[i], serialized_tensors, stub),
                    return_future=True,
                )
                pending_tasks[new_task] = (i, j)

        responded_inds, alive_flat_outputs = cls._collect_responses(
            pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies
        )
        if len(responded_inds) < k_min:
            raise TimeoutError(f"Forward pass: fewer than {k_min} experts responded within timeout")

        if not isinstance(info["outputs_schema"], tuple):
            outputs_schema = (info["outputs_schema"],)
        else:
            outputs_schema = info["outputs_schema"]
        outputs = nested_map(
            lambda descriptor: descriptor.make_zeros(num_samples, max_experts, device=flat_inputs[0].device),
            outputs_schema,
        )

        # assemble responses
        if len(responded_inds) > 0 or allow_zero_outputs:
            batch_inds, expert_inds = map(
                lambda x: torch.as_tensor(x, device=flat_inputs[0].device, dtype=torch.long),
                list(zip(*responded_inds)) or ([], []),
            )

            alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
            # ^-- torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]

            for output, response_stacked in zip(outputs, alive_flat_outputs_stacked):
                output[batch_inds, expert_inds] = response_stacked.to(output.device)

        else:
            raise RuntimeError("Forward pass: 0 experts responded within timeout and allow_zero_outputs is False")

        mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
        mask[batch_inds, expert_inds] = True

        # save individual outputs for backward pass
        ctx.save_for_backward(batch_inds, expert_inds, *flat_inputs_cpu)
        ctx._saved_non_tensors = (
            info,
            backward_k_min,
            backward_timeout,
            timeout_after_k_min,
            experts_per_sample,
            detect_anomalies,
        )
        return (mask,) + outputs

    @classmethod
    @once_differentiable
    def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
        assert not torch.is_grad_enabled()
        (
            info,
            backward_k_min,
            backward_timeout,
            timeout_after_k_min,
            expert_per_sample,
            detect_anomalies,
        ) = ctx._saved_non_tensors
        alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors

        dummy_grad_mask, *flat_grad_outputs = raw_grads
        flat_grad_outputs_cpu = []
        for tensor in flat_grad_outputs:
            if detect_anomalies and not tensor.isfinite().all():
                raise ValueError("One of gradients has nan/inf values")
            flat_grad_outputs_cpu.append(tensor.cpu())

        num_samples, max_experts = dummy_grad_mask.shape

        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
        grad_outputs_per_expert = zip(
            *(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu)
        )
        backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))

        # dispatch tasks to all remote experts, collect responses
        pending_tasks = {}
        for i, j, inputs_ij, grad_outputs_ij in zip(
            alive_ii.cpu().numpy(), alive_jj.cpu().numpy(), inputs_per_expert, grad_outputs_per_expert
        ):
            expert: RemoteExpert = expert_per_sample[i.item()][j.item()]
            stub = get_server_stub(expert.p2p, expert.peer_id)
            inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
            serialized_tensors = (
                serialize_torch_tensor(tensor, proto.compression)
                for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
            )
            new_task = RemoteExpertWorker.run_coroutine(
                expert_backward(expert.uid, inputs_and_grad_outputs, serialized_tensors, stub), return_future=True
            )
            pending_tasks[new_task] = (i, j)

        survivor_inds, survivor_grad_inputs = cls._collect_responses(
            pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies
        )
        if len(survivor_inds) < backward_k_min:
            raise TimeoutError(f"Backward pass: fewer than {backward_k_min} experts responded within timeout")

        # assemble responses
        batch_inds, expert_inds = map(
            lambda x: torch.as_tensor(x, dtype=torch.long), list(zip(*survivor_inds)) or ([], [])
        )

        survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
        # ^-- torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]

        grad_inputs = nested_map(
            lambda descr: descr.make_zeros(num_samples, device=flat_grad_outputs[0].device),
            list(nested_flatten(info["forward_schema"])),
        )

        for grad_input, survivor_grad_stacked in zip(grad_inputs, survivor_grad_inputs_stacked):
            grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
                (num_samples, max_experts, *grad_input.shape[1:]),
                device=survivor_grad_stacked.device,
                dtype=survivor_grad_stacked.dtype,
            )
            grad_input_per_expert[batch_inds, expert_inds] = survivor_grad_stacked
            grad_input.copy_(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))

        return (DUMMY, None, None, None, None, None, None, None, None, None, *grad_inputs)

    @staticmethod
    def _collect_responses(
        task_to_indices: Dict[Future, Tuple[int, int]],
        num_samples: int,
        k_min: int,
        timeout_total: Optional[float],
        timeout_after_k_min: Optional[float],
        detect_anomalies: bool,
    ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
        """await at least k_min results per sample and any stragglers that arrive within timeout_after_k_min, then cancel the rest"""
        timeout_total = float("inf") if timeout_total is None else timeout_total
        timeout_after_k_min = float("inf") if timeout_after_k_min is None else timeout_after_k_min
        num_successful_tasks = [0 for _ in range(num_samples)]
        pending_samples = num_samples  # samples for which we have fewer than k_min results
        finished_indices, finished_outputs = [], []
        t_finish = time.perf_counter() + timeout_total
        pending_tasks = set(task_to_indices.keys())
        finished_tasks = Queue()

        try:
            # the algorithm below is essentially futures.as_completed, re-implemented so that the deadline can be
            # tightened once every sample has k_min successful results and the remaining stragglers can be cancelled
            for task in pending_tasks:
                task.add_done_callback(finished_tasks.put)

            for _ in range(len(task_to_indices)):
                timeout = max(0.0, t_finish - time.perf_counter()) if t_finish != float("inf") else None
                task = finished_tasks.get(timeout=timeout)
                pending_tasks.discard(task)

                task_output = _process_dispatched_task(task, detect_anomalies)
                if task_output is not None:
                    finished_indices.append(task_to_indices[task])
                    finished_outputs.append(task_output)

                    # count how many successes we have for each input sample
                    sample_index = task_to_indices[task][0]
                    num_successful_tasks[sample_index] += 1
                    if num_successful_tasks[sample_index] == k_min:
                        pending_samples -= 1
                        if pending_samples <= 0:
                            # all samples have at least k_min results; await stragglers for at most timeout_after_k_min
                            t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)

        except Empty:
            pass  # we reached t_finish, this is normal behavior
        finally:
            for task in pending_tasks:
                task.cancel()
        return finished_indices, finished_outputs


def _process_dispatched_task(task: Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
    # check cancelled() first: calling exception() on a cancelled future would raise CancelledError
    if task.cancelled() or task.exception():
        logger.warning(f"Task {task} failed: {'cancelled' if task.cancelled() else type(task.exception())}")
        return None

    outputs = task.result()
    for tensor in outputs:
        if detect_anomalies and not tensor.isfinite().all():
            logger.error(f"Task {task} failed: output tensor contains nan/inf values")
            return None
    return outputs