@@ -36,12 +36,14 @@ class RemoteMixtureOfExperts(nn.Module):
:param k_best: average this many highest-scoring experts to compute activations
:param k_min: make sure at least this many experts returned output (i.e. didn't fail)
:param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
Any expert that didn't manage to return output after that delay is considered unavailable
+ :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
"""
def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, uid_prefix: str, k_best: int,
k_min: int = 1, forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
- backward_k_min: int = 1, backward_timeout: Optional[float] = None, **dht_kwargs):
+ backward_k_min: int = 1, backward_timeout: Optional[float] = None, detect_anomalies: bool = False,
+ **dht_kwargs):
super().__init__()
if not uid_prefix.endswith(hivemind.dht.UID_DELIMITER):
uid_prefix += hivemind.dht.UID_DELIMITER
@@ -51,6 +53,7 @@ class RemoteMixtureOfExperts(nn.Module):
self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
self.timeout_after_k_min = timeout_after_k_min
+ self.detect_anomalies = detect_anomalies
self.proj = nn.Linear(in_features, sum(grid_size)) # jointly predict logits for all grid dimensions
self._expert_info = None # expert['info'] from one of experts in the grid
@@ -85,7 +88,7 @@ class RemoteMixtureOfExperts(nn.Module):
expert_mask, *expert_outputs = _RemoteCallMany.apply(
DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
- self.backward_timeout, self.info, *nested_flatten(((input, *args), kwargs)))
+ self.backward_timeout, self.detect_anomalies, self.info, *nested_flatten(((input, *args), kwargs)))
# ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
@@ -156,11 +159,16 @@ class _RemoteCallMany(torch.autograd.Function):
@classmethod
def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
- info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+ detect_anomalies: bool, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
assert not torch.is_grad_enabled()
num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
- flat_inputs_cpu = [tensor.cpu() for tensor in flat_inputs]
+ flat_inputs_cpu = []
+ for tensor in flat_inputs:
+ if detect_anomalies and not tensor.isfinite().all():
+ raise ValueError("One of inputs has nan/inf values")
+ flat_inputs_cpu.append(tensor.cpu())
+
flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
@@ -175,7 +183,7 @@ class _RemoteCallMany(torch.autograd.Function):
pending_tasks[new_task] = (i, j)
alive_grid_indices, alive_flat_outputs = cls._collect_responses(
- pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min)
+ pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
if len(alive_grid_indices) == 0:
raise TimeoutError("Forward pass: no alive experts responded within timeout.")
@@ -197,18 +205,25 @@ class _RemoteCallMany(torch.autograd.Function):
# save individual outputs for backward pass
ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
- ctx._saved_non_tensors = info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample
+ ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
+ detect_anomalies)
return (mask,) + tuple(outputs)
@classmethod
@once_differentiable
def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
assert not torch.is_grad_enabled()
- info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
+ (info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample,
+ detect_anomalies) = ctx._saved_non_tensors
alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors
dummy_grad_mask, *flat_grad_outputs = raw_grads
- flat_grad_outputs_cpu = [tensor.cpu() for tensor in flat_grad_outputs]
+
+ flat_grad_outputs_cpu = []
+ for tensor in flat_grad_outputs:
+ if detect_anomalies and not tensor.isfinite().all():
+ raise ValueError("One of gradients has nan/inf values")
+ flat_grad_outputs_cpu.append(tensor.cpu())
num_samples, max_experts = dummy_grad_mask.shape
@@ -229,7 +244,7 @@ class _RemoteCallMany(torch.autograd.Function):
pending_tasks[new_task] = (i, j)
backward_survivor_indices, survivor_grad_inputs = cls._collect_responses(
- pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
+ pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
if len(backward_survivor_indices) == 0:
raise TimeoutError("Backward pass: no alive experts responded within timeout.")
@@ -249,11 +264,11 @@ class _RemoteCallMany(torch.autograd.Function):
# sum gradients from each expert
grad_inputs.append(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))
- return (DUMMY, None, None, None, None, None, None, None, *grad_inputs)
+ return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
@staticmethod
def _collect_responses(task_to_indices: Dict[grpc.Future, Tuple[int, int]], num_samples: int, k_min: int,
- timeout_total: Optional[float], timeout_after_k_min: Optional[float]
+ timeout_total: Optional[float], timeout_after_k_min: Optional[float], detect_anomalies: bool
) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
""" await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers """
timeout_total = float('inf') if timeout_total is None else timeout_total
@@ -275,20 +290,18 @@ class _RemoteCallMany(torch.autograd.Function):
task = finished_tasks.get(timeout=timeout)
pending_tasks.discard(task)
- if task.exception() or task.cancelled():
- logger.warning(f"Task {task} failed: {type(task.exception())}")
- continue
+ task_output = _process_dispatched_task(task, detect_anomalies)
+ if task_output is not None:
+ finished_indices.append(task_to_indices[task])
+ finished_outputs.append(task_output)
- finished_indices.append(task_to_indices[task])
- finished_outputs.append(tuple(deserialize_torch_tensor(tensor) for tensor in task.result().tensors))
-
- # count how many successes we have for each input sample
- sample_index = task_to_indices[task][0]
- num_successful_tasks[sample_index] += 1
- if num_successful_tasks[sample_index] == k_min:
- pending_samples -= 1
- if pending_samples <= 0: # all tasks finished, await stragglers for at most timeout_after_k_min
- t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
+ # count how many successes we have for each input sample
+ sample_index = task_to_indices[task][0]
+ num_successful_tasks[sample_index] += 1
+ if num_successful_tasks[sample_index] == k_min:
+ pending_samples -= 1
+ if pending_samples <= 0: # all tasks finished, await stragglers for at most timeout_after_k_min
+ t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
except Empty:
pass # we reached t_finish, this is normal behavior
@@ -296,3 +309,19 @@ class _RemoteCallMany(torch.autograd.Function):
for task in pending_tasks:
task.cancel()
return finished_indices, finished_outputs
+
+
+def _process_dispatched_task(task: grpc.Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor, ...]]:
+ if task.exception() or task.cancelled():
+ logger.warning(f"Task {task} failed: {type(task.exception())}")
+ return None
+
+ deserialized_outputs = []
+ for tensor in task.result().tensors:
+ deserialized_tensor = deserialize_torch_tensor(tensor)
+ if detect_anomalies and not deserialized_tensor.isfinite().all():
+ logger.error(f"Task {task} failed: output tensor contains nan/inf values")
+ return None
+ deserialized_outputs.append(deserialized_tensor)
+
+ return tuple(deserialized_outputs)
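
A minimal usage sketch of the new flag (illustrative only, not part of the patch): it assumes RemoteMixtureOfExperts is importable from the top-level hivemind package and that expert servers are already running and announced in the DHT under the chosen uid_prefix; the layer sizes, grid, prefix, and DHT setup below are placeholders.

import torch
import hivemind

# assumption: experts named "expert.<i>.<j>" are already being served and reachable via this DHT
dht = hivemind.DHT(start=True)

moe = hivemind.RemoteMixtureOfExperts(
    in_features=512, grid_size=(4, 4), dht=dht, uid_prefix="expert.",
    k_best=4, detect_anomalies=True)  # defaults to False; enables the NaN/inf checks added in this patch

x = torch.randn(8, 512, requires_grad=True)
y = moe(x)            # forward raises ValueError if any input tensor contains NaN/inf
y.sum().backward()    # backward raises ValueError if any gradient tensor contains NaN/inf;
                      # expert responses containing NaN/inf are dropped as if that expert had failed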