
Add anomaly detection to RemoteMixtureOfExperts (#132)

* RemoteMixtureOfExperts anomaly detection
Max Ryabinin 4 years ago
parent
commit
aecff2286d
4 changed files with 109 additions and 29 deletions
  1. hivemind/__init__.py (+1 -1)
  2. hivemind/client/moe.py (+53 -24)
  3. requirements.txt (+1 -1)
  4. tests/test_moe.py (+54 -3)

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.18'
+__version__ = '0.8.19'

+ 53 - 24
hivemind/client/moe.py

@@ -36,12 +36,14 @@ class RemoteMixtureOfExperts(nn.Module):
     :param k_best: average this many highest-scoring experts to compute activations
     :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
     :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
+    :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
      Any expert that didn't manage to return output after that delay is considered unavailable
     """
 
     def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, uid_prefix: str, k_best: int,
                  k_min: int = 1, forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
-                 backward_k_min: int = 1, backward_timeout: Optional[float] = None, **dht_kwargs):
+                 backward_k_min: int = 1, backward_timeout: Optional[float] = None, detect_anomalies: bool = False,
+                 **dht_kwargs):
         super().__init__()
         if not uid_prefix.endswith(hivemind.dht.UID_DELIMITER):
             uid_prefix += hivemind.dht.UID_DELIMITER
@@ -51,6 +53,7 @@ class RemoteMixtureOfExperts(nn.Module):
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
         self.timeout_after_k_min = timeout_after_k_min
+        self.detect_anomalies = detect_anomalies
 
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
         self._expert_info = None  # expert['info'] from one of experts in the grid
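
With this change, anomaly detection is opt-in at construction time. A minimal usage sketch mirroring the new test further below (a DHT and a server hosting experts named `expert.0` through `expert.2` are assumed to be running already; their setup is omitted here):

```python
import torch
import hivemind

# assumes a DHT that can discover experts 'expert.0' .. 'expert.2' served elsewhere
dht = hivemind.DHT(start=True)

dmoe = hivemind.RemoteMixtureOfExperts(in_features=16, grid_size=(3,), dht=dht, k_best=3,
                                       uid_prefix='expert.', detect_anomalies=True)

x = torch.randn(4, 16)
x[0, 0] = float('nan')
dmoe(x)  # raises ValueError because the input contains nan
```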
@@ -85,7 +88,7 @@ class RemoteMixtureOfExperts(nn.Module):
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
             DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
-            self.backward_timeout, self.info, *nested_flatten(((input, *args), kwargs)))
+            self.backward_timeout, self.detect_anomalies, self.info, *nested_flatten(((input, *args), kwargs)))
         # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
 
         expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
@@ -156,11 +159,16 @@ class _RemoteCallMany(torch.autograd.Function):
     @classmethod
     def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
                 timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
-                info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+                detect_anomalies: bool, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
         assert not torch.is_grad_enabled()
         num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
 
-        flat_inputs_cpu = [tensor.cpu() for tensor in flat_inputs]
+        flat_inputs_cpu = []
+        for tensor in flat_inputs:
+            if detect_anomalies and not tensor.isfinite().all():
+                raise ValueError("One of inputs has nan/inf values")
+            flat_inputs_cpu.append(tensor.cpu())
+
         flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
         assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
 
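The new input guard relies on `Tensor.isfinite()`, which is elementwise `False` for nan and ±inf, so a single `.all()` covers both anomaly types. A standalone sketch of the same check (the helper name here is made up for illustration):

```python
import torch

def check_finite(tensor: torch.Tensor, kind: str = "input") -> None:
    # isfinite() is False for nan, +inf and -inf; .all() is False if any element is anomalous
    if not tensor.isfinite().all():
        raise ValueError(f"One of {kind}s has nan/inf values")

check_finite(torch.tensor([1.0, 2.0]))             # passes silently
# check_finite(torch.tensor([1.0, float('inf')]))  # would raise ValueError
```
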
@@ -175,7 +183,7 @@ class _RemoteCallMany(torch.autograd.Function):
                 pending_tasks[new_task] = (i, j)
 
         alive_grid_indices, alive_flat_outputs = cls._collect_responses(
-            pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min)
+            pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
         if len(alive_grid_indices) == 0:
             raise TimeoutError("Forward pass: no alive experts responded within timeout.")
 
@@ -197,18 +205,25 @@ class _RemoteCallMany(torch.autograd.Function):
 
         # save individual outputs for backward pass
         ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
-        ctx._saved_non_tensors = info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample
+        ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
+                                  detect_anomalies)
         return (mask,) + tuple(outputs)
 
     @classmethod
     @once_differentiable
     def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
         assert not torch.is_grad_enabled()
-        info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
+        (info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample,
+         detect_anomalies) = ctx._saved_non_tensors
         alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors
 
         dummy_grad_mask, *flat_grad_outputs = raw_grads
-        flat_grad_outputs_cpu = [tensor.cpu() for tensor in flat_grad_outputs]
+
+        flat_grad_outputs_cpu = []
+        for tensor in flat_grad_outputs:
+            if detect_anomalies and not tensor.isfinite().all():
+                raise ValueError("One of gradients has nan/inf values")
+            flat_grad_outputs_cpu.append(tensor.cpu())
 
         num_samples, max_experts = dummy_grad_mask.shape
 
@@ -229,7 +244,7 @@ class _RemoteCallMany(torch.autograd.Function):
             pending_tasks[new_task] = (i, j)
 
         backward_survivor_indices, survivor_grad_inputs = cls._collect_responses(
-            pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
+            pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
         if len(backward_survivor_indices) == 0:
             raise TimeoutError("Backward pass: no alive experts responded within timeout.")
 
@@ -249,11 +264,11 @@ class _RemoteCallMany(torch.autograd.Function):
             # sum gradients from each expert
             grad_inputs.append(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))
 
-        return (DUMMY, None, None, None, None, None, None, None, *grad_inputs)
+        return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
 
     @staticmethod
     def _collect_responses(task_to_indices: Dict[grpc.Future, Tuple[int, int]], num_samples: int, k_min: int,
-                           timeout_total: Optional[float], timeout_after_k_min: Optional[float]
+                           timeout_total: Optional[float], timeout_after_k_min: Optional[float], detect_anomalies: bool
                            ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
         """ await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers """
         timeout_total = float('inf') if timeout_total is None else timeout_total
@@ -275,20 +290,18 @@ class _RemoteCallMany(torch.autograd.Function):
                 task = finished_tasks.get(timeout=timeout)
                 pending_tasks.discard(task)
 
-                if task.exception() or task.cancelled():
-                    logger.warning(f"Task {task} failed: {type(task.exception())}")
-                    continue
+                task_output = _process_dispatched_task(task, detect_anomalies)
+                if task_output is not None:
+                    finished_indices.append(task_to_indices[task])
+                    finished_outputs.append(task_output)
 
-                finished_indices.append(task_to_indices[task])
-                finished_outputs.append(tuple(deserialize_torch_tensor(tensor) for tensor in task.result().tensors))
-
-                # count how many successes we have for each input sample
-                sample_index = task_to_indices[task][0]
-                num_successful_tasks[sample_index] += 1
-                if num_successful_tasks[sample_index] == k_min:
-                    pending_samples -= 1
-                    if pending_samples <= 0:  # all tasks finished, await stragglers for at most timeout_after_k_min
-                        t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
+                    # count how many successes we have for each input sample
+                    sample_index = task_to_indices[task][0]
+                    num_successful_tasks[sample_index] += 1
+                    if num_successful_tasks[sample_index] == k_min:
+                        pending_samples -= 1
+                        if pending_samples <= 0:  # all tasks finished, await stragglers for at most timeout_after_k_min
+                            t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
 
         except Empty:
             pass  # we reached t_finish, this is normal behavior
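
For context, `_collect_responses` waits until every sample has at least `k_min` successful responses, then gives stragglers at most `timeout_after_k_min` extra seconds before cancelling them; with this commit, anomalous responses simply no longer count as successes. A simplified single-sample sketch of that policy, not the actual implementation (which drains a callback queue instead of polling):

```python
import time

def collect_with_straggler_timeout(futures, k_min, timeout_total=None, timeout_after_k_min=None):
    # illustrative only: futures are concurrent.futures-style objects (done/exception/result/cancel);
    # None timeouts mean "wait indefinitely", matching the Optional[float] defaults above
    deadline = time.perf_counter() + (timeout_total if timeout_total is not None else float('inf'))
    grace = timeout_after_k_min if timeout_after_k_min is not None else float('inf')
    pending, finished = set(futures), []
    while pending and time.perf_counter() < deadline:
        for task in [f for f in pending if f.done()]:
            pending.discard(task)
            if task.exception() is None:      # failed (or, with detect_anomalies, anomalous) tasks are skipped
                finished.append(task.result())
                if len(finished) == k_min:    # enough results: only await stragglers for a short grace period
                    deadline = min(deadline, time.perf_counter() + grace)
        time.sleep(0.01)                      # busy-wait for brevity
    for task in pending:
        task.cancel()
    return finished
```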
@@ -296,3 +309,19 @@ class _RemoteCallMany(torch.autograd.Function):
             for task in pending_tasks:
                 task.cancel()
         return finished_indices, finished_outputs
+
+
+def _process_dispatched_task(task: grpc.Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
+    if task.exception() or task.cancelled():
+        logger.warning(f"Task {task} failed: {type(task.exception())}")
+        return None
+
+    deserialized_outputs = []
+    for tensor in task.result().tensors:
+        deserialized_tensor = deserialize_torch_tensor(tensor)
+        if detect_anomalies and not deserialized_tensor.isfinite().all():
+            logger.error(f"Task {task} failed: output tensor contains nan/inf values")
+            return None
+        deserialized_outputs.append(deserialized_tensor)
+
+    return tuple(deserialized_outputs)

+ 1 - 1
requirements.txt

@@ -1,5 +1,5 @@
 PyYAML
-torch>=1.3.0
+torch>=1.6.0
 numpy>=1.17
 prefetch_generator>=1.0.1
 msgpack>=0.5.6
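
The torch pin bump appears tied to the anomaly checks: the code above calls the `Tensor.isfinite()` method which, as far as I know, only became available as a tensor method around PyTorch 1.6, while the functional `torch.isfinite` also exists in earlier releases:

```python
import torch

x = torch.tensor([0.0, float('inf'), float('nan')])
print(torch.isfinite(x))   # functional form: tensor([ True, False, False])
print(x.isfinite().all())  # method form used in moe.py: tensor(False)
```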

+ 54 - 3
tests/test_moe.py

@@ -2,9 +2,11 @@ import grpc
 import numpy as np
 import pytest
 import torch
+
 import hivemind
-from hivemind.client.expert import DUMMY
 from hivemind import background_server
+from hivemind.client.expert import DUMMY
+from hivemind.server import layers
 
 
 @pytest.mark.forked
@@ -30,6 +32,7 @@ def test_call_many():
     backward_k_min = 1
     forward_timeout = None
     backward_timeout = None
+    detect_anomalies = False
     atol = 1e-5
 
     with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=8, hidden_dim=64,
@@ -40,8 +43,8 @@ def test_call_many():
         e5 = hivemind.RemoteExpert(f'thisshouldnotexist', '127.0.0.1:80')
 
         mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
-            DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
-            k_min, backward_k_min, timeout_after_k_min, forward_timeout, backward_timeout, e1.info, inputs
+            DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []], k_min, backward_k_min, timeout_after_k_min,
+            forward_timeout, backward_timeout, detect_anomalies, e1.info, inputs
         )
         assert mask.shape == (4, 3)
         assert expert_outputs.shape == (4, 3, 64)
@@ -169,3 +172,51 @@ def test_compute_expert_scores():
                     "compute_expert_scores returned incorrect score"
     finally:
         dht.shutdown()
+
+
+@pytest.mark.forked
+def test_client_anomaly_detection():
+    HID_DIM = 16
+
+    experts = {}
+    for i in range(4):
+        expert = layers.name_to_block['ffn'](HID_DIM)
+        experts[f'expert.{i}'] = hivemind.ExpertBackend(name=f'expert.{i}',
+                                                        expert=expert, opt=torch.optim.Adam(expert.parameters()),
+                                                        args_schema=(hivemind.BatchTensorDescriptor(HID_DIM),),
+                                                        outputs_schema=hivemind.BatchTensorDescriptor(HID_DIM),
+                                                        max_batch_size=16,
+                                                        )
+
+    experts['expert.3'].expert.layers[0].weight.data[0, 0] = float('nan')
+
+    dht = hivemind.DHT(start=True, expiration=999)
+    server = hivemind.Server(dht, experts, num_connection_handlers=1)
+    server.start()
+    try:
+        server.ready.wait()
+
+        dmoe = hivemind.RemoteMixtureOfExperts(in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix='expert.',
+                                               detect_anomalies=True)
+
+        input = torch.randn(1, 16)
+        input[0, 0] = float('nan')
+
+        with pytest.raises(ValueError):
+            dmoe(input)
+
+        input[0, 0] = 0
+        output = dmoe(input)
+
+        inf_loss = float('inf') * output.sum()
+        with pytest.raises(ValueError):
+            inf_loss.backward()
+
+        dmoe = hivemind.RemoteMixtureOfExperts(in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix='expert.',
+                                               detect_anomalies=True)
+        output = dmoe(input)
+        assert output.isfinite().all()
+
+
+    finally:
+        server.shutdown()