@@ -14,7 +14,7 @@ from hivemind.client.beam_search import MoEBeamSearcher
 from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.server.expert_uid import UID_DELIMITER
-from hivemind.utils import nested_pack, nested_flatten
+from hivemind.utils import nested_pack, nested_flatten, nested_map
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
@@ -23,7 +23,7 @@ logger = get_logger(__name__)
 
 class RemoteMixtureOfExperts(nn.Module):
     """
-    A torch module that performs mixture of experts inference with a local gating function and multiple remote experts.
+    A torch module that performs Mixture-of-Experts inference with a local gating function and multiple remote experts.
     Natively supports pytorch autograd.
 
     :note: By default, not all experts are guaranteed to perform forward pass. Moreover, not all of those who ran
@@ -38,14 +38,15 @@ class RemoteMixtureOfExperts(nn.Module):
     :param k_best: average this many highest-scoring experts to compute activations
     :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
     :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
-    :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
         Any expert that didn't manage to return output after that delay is considered unavailable
+    :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
+    :param allow_zero_outputs: whether to return zeros if no experts respond on forward pass
     """
 
     def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, uid_prefix: str, k_best: int,
                  k_min: int = 1, forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
                  backward_k_min: int = 1, backward_timeout: Optional[float] = None, detect_anomalies: bool = False,
-                 **dht_kwargs):
+                 allow_zero_outputs: bool = False, **dht_kwargs):
         super().__init__()
         self.dht = dht
         self.beam_search = MoEBeamSearcher(dht, uid_prefix, grid_size, **dht_kwargs)
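
For context, a minimal usage sketch of the constructor above (illustrative only, not part of the patch): it assumes a reachable hivemind DHT whose peers serve experts under the given uid_prefix; the peer address, grid shape and feature sizes below are placeholders.

# Illustrative usage sketch (not part of this diff). Assumes a running DHT with experts
# announced under uid_prefix; all concrete values below are placeholders.
import torch
import hivemind

dht = hivemind.DHT(initial_peers=["SOME_PEER_ADDRESS"], start=True)
moe = hivemind.RemoteMixtureOfExperts(
    in_features=512, grid_size=(32, 32), dht=dht, uid_prefix='expert.',
    k_best=4, k_min=1, allow_zero_outputs=True)  # new flag introduced by this patch

x = torch.randn(16, 512)
y = moe(x)  # gating -> beam search over the DHT -> remote forward passes -> weighted average
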
@@ -53,8 +54,10 @@ class RemoteMixtureOfExperts(nn.Module):
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
         self.timeout_after_k_min = timeout_after_k_min
         self.detect_anomalies = detect_anomalies
+        self.allow_zero_outputs = allow_zero_outputs
 
-        self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
+        # jointly predict logits for all grid dimensions
+        self.proj = nn.Linear(in_features, self.beam_search.total_grid_size)
         self._expert_info = None  # expert['info'] from one of experts in the grid
 
     def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
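
A side note on the nn.Linear change above (not part of the patch): the gating layer still emits one logit per coordinate of every grid dimension, assuming MoEBeamSearcher.total_grid_size equals sum(grid_size). A standalone sketch of how such joint logits are typically split into per-dimension scores:

# Standalone sketch (not part of this diff): joint gating logits carved into per-dimension
# score vectors, assuming total_grid_size == sum(grid_size).
import torch
import torch.nn as nn

grid_size = (4, 8)                          # illustrative 4 x 8 expert grid
proj = nn.Linear(16, sum(grid_size))        # same output width the patch computes via total_grid_size
grid_logits = proj(torch.randn(2, 16))      # [batch, 4 + 8]
grid_scores = grid_logits.split_with_sizes(list(grid_size), dim=-1)
assert [s.shape[-1] for s in grid_scores] == [4, 8]   # one score vector per grid dimension
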
@@ -87,7 +90,8 @@ class RemoteMixtureOfExperts(nn.Module):
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
             DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
-            self.backward_timeout, self.detect_anomalies, self.info, *nested_flatten(((input, *args), kwargs)))
+            self.backward_timeout, self.detect_anomalies, self.allow_zero_outputs, self.info,
+            *nested_flatten(((input, *args), kwargs)))
         # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
 
         expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
@@ -97,6 +101,7 @@ class RemoteMixtureOfExperts(nn.Module):
         averaged_outputs_flat = [
             (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
             for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
+
         return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
 
     def compute_expert_scores(
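
An aside on the averaging step above (not part of the patch): each stacked expert output is weighted by the softmax expert weights and reduced over the expert axis. A toy sketch with made-up shapes:

# Toy sketch of the weighted averaging above (not part of this diff); shapes are made up.
import torch

batch_size, max_experts, hidden = 3, 4, 5
expert_weights = torch.softmax(torch.randn(batch_size, max_experts), dim=-1)  # from expert_logits in practice
tensor = torch.randn(batch_size, max_experts, hidden)                         # one expert output, already stacked

weighted = (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape)
averaged = weighted.sum(dim=1)      # [batch_size, hidden]; zeroed rows from failed experts contribute nothing
assert averaged.shape == (batch_size, hidden)
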
@@ -152,13 +157,14 @@ class _RemoteCallMany(torch.autograd.Function):
     one expert succeeds for each input. For user-friendly version of this function, use RemoteMixtureOfExperts module.
 
     Note: experts that failed during forward will be assigned zero outputs and marked as mask[i, j] = 0,
-        experts that failed during backward will be treated as constants (i.e. gradients of through them are zeros)
+        experts that failed during backward will be treated as constants (i.e. gradients through them are zeros)
     """
 
     @classmethod
     def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
                 timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
-                detect_anomalies: bool, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+                detect_anomalies: bool, allow_zero_outputs: bool, info: Dict[str, Any],
+                *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
         assert not torch.is_grad_enabled()
         num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
 
@@ -181,32 +187,42 @@ class _RemoteCallMany(torch.autograd.Function):
                 new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
                 pending_tasks[new_task] = (i, j)
 
-        alive_grid_indices, alive_flat_outputs = cls._collect_responses(
+        responded_inds, alive_flat_outputs = cls._collect_responses(
             pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
-        if len(alive_grid_indices) == 0:
-            raise TimeoutError("Forward pass: no alive experts responded within timeout.")
+        if len(responded_inds) < k_min:
+            raise TimeoutError(f"Forward pass: less than {k_min} experts responded within timeout.")
+
+        if not isinstance(info['outputs_schema'], tuple):
+            outputs_schema = (info['outputs_schema'],)
+        else:
+            outputs_schema = info['outputs_schema']
+        outputs = nested_map(
+            lambda descriptor: descriptor.make_empty(num_samples, max_experts, device=flat_inputs[0].device).zero_(),
+            outputs_schema)
 
         # assemble responses
-        alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
-        mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
-        mask[alive_ii, alive_jj] = True
+        if len(responded_inds) > 0 or allow_zero_outputs:
+            batch_inds, expert_inds = map(lambda x: torch.as_tensor(x, device=flat_inputs[0].device, dtype=torch.long),
+                                          list(zip(*responded_inds)) or ([], []))
 
-        alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
-        # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
+            alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
+            # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
 
-        outputs = []
-        for response_stacked in alive_flat_outputs_stacked:
-            output = torch.zeros(
-                [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
-                dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
-            output[alive_ii, alive_jj] = response_stacked
-            outputs.append(output.to(flat_inputs[0].device))
+            for output, response_stacked in zip(outputs, alive_flat_outputs_stacked):
+                output[batch_inds, expert_inds] = response_stacked.to(output.device)
+
+        else:
+            raise RuntimeError('Forward pass: 0 experts responded within timeout and allow_zero_outputs is False')
+
+        mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
+        mask[batch_inds, expert_inds] = True
 
         # save individual outputs for backward pass
-        ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
+        ctx.save_for_backward(batch_inds, expert_inds, *flat_inputs_cpu)
         ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
                                   detect_anomalies)
-        return (mask,) + tuple(outputs)
+
+        return (mask,) + outputs
 
     @classmethod
     @once_differentiable
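
A note on the restructured forward pass above (not part of the patch): output buffers are now preallocated as zeros from the experts' output schema, responses that did arrive are scattered into them by (sample, expert) index, and with allow_zero_outputs an all-zero buffer plus an all-False mask becomes an acceptable result. A toy sketch of that scatter, using a hypothetical descriptor in place of hivemind's tensor descriptors:

# Toy sketch of the "preallocate zeros, scatter responses" pattern above (not part of this diff);
# HypotheticalDescriptor merely mimics the make_empty(...) call used in the patch.
import torch

class HypotheticalDescriptor:
    def __init__(self, *trailing_shape):
        self.trailing_shape = trailing_shape

    def make_empty(self, *leading_shape, device=None):
        return torch.empty(*leading_shape, *self.trailing_shape, device=device)

num_samples, max_experts = 4, 3
descriptor = HypotheticalDescriptor(8)                      # each expert returns an 8-dim vector
output = descriptor.make_empty(num_samples, max_experts).zero_()

responded_inds = [(0, 0), (2, 1)]                           # (sample, expert) pairs that answered in time
batch_inds, expert_inds = map(torch.as_tensor, zip(*responded_inds))
responses = torch.randn(len(responded_inds), 8)             # stacked responses from alive experts
output[batch_inds, expert_inds] = responses                 # everything else stays zero

mask = torch.zeros(num_samples, max_experts, dtype=torch.bool)
mask[batch_inds, expert_inds] = True                        # marks which (sample, expert) cells hold real outputs
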
@@ -235,35 +251,37 @@ class _RemoteCallMany(torch.autograd.Function):
         for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
                                                     inputs_per_expert, grad_outputs_per_expert):
             expert = expert_per_sample[i.item()][j.item()]
-            stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
+            stub = _get_expert_stub(expert.endpoint)
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
             tensors_serialized = [serialize_torch_tensor(tensor, proto.compression)
                                   for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
             new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
             pending_tasks[new_task] = (i, j)
 
-        backward_survivor_indices, survivor_grad_inputs = cls._collect_responses(
+        survivor_inds, survivor_grad_inputs = cls._collect_responses(
             pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
-        if len(backward_survivor_indices) == 0:
-            raise TimeoutError("Backward pass: no alive experts responded within timeout.")
+        if len(survivor_inds) < backward_k_min:
+            raise TimeoutError(f"Backward pass: less than {backward_k_min} experts responded within timeout.")
 
         # assemble responses
-        backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices) or ([], []))
+        batch_inds, expert_inds = map(lambda x: torch.as_tensor(x, dtype=torch.long),
+                                      list(zip(*survivor_inds)) or ([], []))
 
         survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
         # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]
 
-        grad_inputs = []
-        for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
+        grad_inputs = nested_map(
+            lambda descr: descr.make_empty(num_samples, device=flat_grad_outputs[0].device).zero_(),
+            list(nested_flatten(info['forward_schema'])))
+
+        for grad_input, survivor_grad_stacked in zip(grad_inputs, survivor_grad_inputs_stacked):
             grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
-                (num_samples, max_experts, *flat_inputs_cpu[i].shape[1:]),
+                (num_samples, max_experts, *grad_input.shape[1:]),
                 device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
-            grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked
-
-            # sum gradients from each expert
-            grad_inputs.append(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))
+            grad_input_per_expert[batch_inds, expert_inds] = survivor_grad_stacked
+            grad_input.copy_(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))
 
-        return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
+        return (DUMMY, None, None, None, None, None, None, None, None, None, *grad_inputs)
 
     @staticmethod
     def _collect_responses(task_to_indices: Dict[grpc.Future, Tuple[int, int]], num_samples: int, k_min: int,
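
A closing note on the backward return value (not part of the patch): torch.autograd.Function.backward must return exactly one gradient per argument of forward, and non-tensor arguments receive None; the tuple above grows by one None because forward gained the allow_zero_outputs flag. A minimal standalone illustration of that contract:

# Minimal illustration of the autograd contract behind the extra None (not part of this diff):
# backward returns one gradient slot per forward argument; non-differentiable flags get None.
import torch

class ToyCall(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, k_min, allow_zero_outputs):
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # one slot per forward argument: x, k_min, allow_zero_outputs
        return grad_output * 2, None, None

x = torch.randn(3, requires_grad=True)
ToyCall.apply(x, 1, True).sum().backward()
print(x.grad)   # tensor([2., 2., 2.])
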