
Add anomaly detection to RemoteMixtureOfExperts (#132)

* RemoteMixtureOfExperts anomaly detection
Max Ryabinin 4 years ago
parent
commit
aecff2286d
4 changed files with 109 additions and 29 deletions
  1. hivemind/__init__.py  +1 -1
  2. hivemind/client/moe.py  +53 -24
  3. requirements.txt  +1 -1
  4. tests/test_moe.py  +54 -3

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.18'
+__version__ = '0.8.19'

+ 53 - 24
hivemind/client/moe.py

@@ -36,12 +36,14 @@ class RemoteMixtureOfExperts(nn.Module):
     :param k_best: average this many highest-scoring experts to compute activations
     :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
     :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
+    :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
      Any expert that didn't manage to return output after that delay is considered unavailable
     """
 
     def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, uid_prefix: str, k_best: int,
                  k_min: int = 1, forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
-                 backward_k_min: int = 1, backward_timeout: Optional[float] = None, **dht_kwargs):
+                 backward_k_min: int = 1, backward_timeout: Optional[float] = None, detect_anomalies: bool = False,
+                 **dht_kwargs):
         super().__init__()
         if not uid_prefix.endswith(hivemind.dht.UID_DELIMITER):
             uid_prefix += hivemind.dht.UID_DELIMITER
@@ -51,6 +53,7 @@ class RemoteMixtureOfExperts(nn.Module):
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
         self.timeout_after_k_min = timeout_after_k_min
+        self.detect_anomalies = detect_anomalies
 
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
         self._expert_info = None  # expert['info'] from one of experts in the grid
@@ -85,7 +88,7 @@ class RemoteMixtureOfExperts(nn.Module):
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
             DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
-            self.backward_timeout, self.info, *nested_flatten(((input, *args), kwargs)))
+            self.backward_timeout, self.detect_anomalies, self.info, *nested_flatten(((input, *args), kwargs)))
         # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
 
         expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
@@ -156,11 +159,16 @@ class _RemoteCallMany(torch.autograd.Function):
     @classmethod
     def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
                 timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
-                info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+                detect_anomalies: bool, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
         assert not torch.is_grad_enabled()
         num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
 
-        flat_inputs_cpu = [tensor.cpu() for tensor in flat_inputs]
+        flat_inputs_cpu = []
+        for tensor in flat_inputs:
+            if detect_anomalies and not tensor.isfinite().all():
+                raise ValueError("One of inputs has nan/inf values")
+            flat_inputs_cpu.append(tensor.cpu())
+
         flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
         assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
 
@@ -175,7 +183,7 @@ class _RemoteCallMany(torch.autograd.Function):
                 pending_tasks[new_task] = (i, j)
 
         alive_grid_indices, alive_flat_outputs = cls._collect_responses(
-            pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min)
+            pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
         if len(alive_grid_indices) == 0:
             raise TimeoutError("Forward pass: no alive experts responded within timeout.")
 
@@ -197,18 +205,25 @@ class _RemoteCallMany(torch.autograd.Function):
 
         # save individual outputs for backward pass
         ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
-        ctx._saved_non_tensors = info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample
+        ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
+                                  detect_anomalies)
         return (mask,) + tuple(outputs)
 
     @classmethod
     @once_differentiable
     def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
         assert not torch.is_grad_enabled()
-        info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
+        (info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample,
+         detect_anomalies) = ctx._saved_non_tensors
         alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors
 
         dummy_grad_mask, *flat_grad_outputs = raw_grads
-        flat_grad_outputs_cpu = [tensor.cpu() for tensor in flat_grad_outputs]
+
+        flat_grad_outputs_cpu = []
+        for tensor in flat_grad_outputs:
+            if detect_anomalies and not tensor.isfinite().all():
+                raise ValueError("One of gradients has nan/inf values")
+            flat_grad_outputs_cpu.append(tensor.cpu())
 
         num_samples, max_experts = dummy_grad_mask.shape
 
@@ -229,7 +244,7 @@ class _RemoteCallMany(torch.autograd.Function):
             pending_tasks[new_task] = (i, j)
 
         backward_survivor_indices, survivor_grad_inputs = cls._collect_responses(
-            pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
+            pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
         if len(backward_survivor_indices) == 0:
             raise TimeoutError("Backward pass: no alive experts responded within timeout.")
 
@@ -249,11 +264,11 @@ class _RemoteCallMany(torch.autograd.Function):
             # sum gradients from each expert
             grad_inputs.append(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))
 
-        return (DUMMY, None, None, None, None, None, None, None, *grad_inputs)
+        return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
 
     @staticmethod
     def _collect_responses(task_to_indices: Dict[grpc.Future, Tuple[int, int]], num_samples: int, k_min: int,
-                           timeout_total: Optional[float], timeout_after_k_min: Optional[float]
+                           timeout_total: Optional[float], timeout_after_k_min: Optional[float], detect_anomalies: bool
                            ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
         """ await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers """
         timeout_total = float('inf') if timeout_total is None else timeout_total
@@ -275,20 +290,18 @@ class _RemoteCallMany(torch.autograd.Function):
                 task = finished_tasks.get(timeout=timeout)
                 pending_tasks.discard(task)
 
-                if task.exception() or task.cancelled():
-                    logger.warning(f"Task {task} failed: {type(task.exception())}")
-                    continue
+                task_output = _process_dispatched_task(task, detect_anomalies)
+                if task_output is not None:
+                    finished_indices.append(task_to_indices[task])
+                    finished_outputs.append(task_output)
 
-                finished_indices.append(task_to_indices[task])
-                finished_outputs.append(tuple(deserialize_torch_tensor(tensor) for tensor in task.result().tensors))
-
-                # count how many successes we have for each input sample
-                sample_index = task_to_indices[task][0]
-                num_successful_tasks[sample_index] += 1
-                if num_successful_tasks[sample_index] == k_min:
-                    pending_samples -= 1
-                    if pending_samples <= 0:  # all tasks finished, await stragglers for at most timeout_after_k_min
-                        t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
+                    # count how many successes we have for each input sample
+                    sample_index = task_to_indices[task][0]
+                    num_successful_tasks[sample_index] += 1
+                    if num_successful_tasks[sample_index] == k_min:
+                        pending_samples -= 1
+                        if pending_samples <= 0:  # all tasks finished, await stragglers for at most timeout_after_k_min
+                            t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
 
         except Empty:
             pass  # we reached t_finish, this is normal behavior
@@ -296,3 +309,19 @@ class _RemoteCallMany(torch.autograd.Function):
             for task in pending_tasks:
                 task.cancel()
         return finished_indices, finished_outputs
+
+
+def _process_dispatched_task(task: grpc.Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
+    if task.exception() or task.cancelled():
+        logger.warning(f"Task {task} failed: {type(task.exception())}")
+        return None
+
+    deserialized_outputs = []
+    for tensor in task.result().tensors:
+        deserialized_tensor = deserialize_torch_tensor(tensor)
+        if detect_anomalies and not deserialized_tensor.isfinite().all():
+            logger.error(f"Task {task} failed: output tensor contains nan/inf values")
+            return None
+        deserialized_outputs.append(deserialized_tensor)
+
+    return tuple(deserialized_outputs)
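
For orientation, here is a minimal usage sketch of the new flag from the client side. It is not part of this diff: the DHT/server setup is assumed (experts registered under the `expert.` prefix, as in the test added to tests/test_moe.py below), and the shapes are illustrative.

    import torch
    import hivemind

    # Assumed setup: a hivemind.Server with 'ffn' experts named 'expert.0'..'expert.3'
    # is already running and announced in `dht`, as in tests/test_moe.py below.
    dht = hivemind.DHT(start=True)

    dmoe = hivemind.RemoteMixtureOfExperts(
        in_features=16, grid_size=(4,), dht=dht, uid_prefix='expert.', k_best=4,
        detect_anomalies=True)  # check inputs, expert outputs and gradients for nan/inf

    x = torch.randn(1, 16)
    x[0, 0] = float('nan')
    try:
        dmoe(x)  # raises ValueError: the nan input is caught before dispatching to experts
    except ValueError:
        print("nan/inf detected in input")

With detect_anomalies left at its default (False), the same call would forward the nan input to the experts unchecked, matching the previous behavior.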

+ 1 - 1
requirements.txt

@@ -1,5 +1,5 @@
 PyYAML
-torch>=1.3.0
+torch>=1.6.0
 numpy>=1.17
 prefetch_generator>=1.0.1
 msgpack>=0.5.6
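
A note on the dependency bump above: the anomaly checks call Tensor.isfinite() as a method, which (as far as I can tell) is what motivates requiring torch>=1.6.0. A minimal self-check of the pattern used in moe.py, assuming torch>=1.6 is installed:

    import torch

    x = torch.tensor([1.0, float('nan'), float('inf')])
    if not x.isfinite().all():  # same finiteness test as in the anomaly checks above
        print("tensor contains nan/inf values")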

+ 54 - 3
tests/test_moe.py

@@ -2,9 +2,11 @@ import grpc
 import numpy as np
 import pytest
 import torch
+
 import hivemind
-from hivemind.client.expert import DUMMY
 from hivemind import background_server
+from hivemind.client.expert import DUMMY
+from hivemind.server import layers
 
 
 @pytest.mark.forked
@@ -30,6 +32,7 @@ def test_call_many():
     backward_k_min = 1
     forward_timeout = None
     backward_timeout = None
+    detect_anomalies = False
     atol = 1e-5
 
     with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=8, hidden_dim=64,
@@ -40,8 +43,8 @@ def test_call_many():
         e5 = hivemind.RemoteExpert(f'thisshouldnotexist', '127.0.0.1:80')
 
         mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
-            DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
-            k_min, backward_k_min, timeout_after_k_min, forward_timeout, backward_timeout, e1.info, inputs
+            DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []], k_min, backward_k_min, timeout_after_k_min,
+            forward_timeout, backward_timeout, detect_anomalies, e1.info, inputs
         )
         assert mask.shape == (4, 3)
         assert expert_outputs.shape == (4, 3, 64)
@@ -169,3 +172,51 @@ def test_compute_expert_scores():
                     "compute_expert_scores returned incorrect score"
                     "compute_expert_scores returned incorrect score"
     finally:
     finally:
         dht.shutdown()
         dht.shutdown()
+
+
+@pytest.mark.forked
+def test_client_anomaly_detection():
+    HID_DIM = 16
+
+    experts = {}
+    for i in range(4):
+        expert = layers.name_to_block['ffn'](HID_DIM)
+        experts[f'expert.{i}'] = hivemind.ExpertBackend(name=f'expert.{i}',
+                                                        expert=expert, opt=torch.optim.Adam(expert.parameters()),
+                                                        args_schema=(hivemind.BatchTensorDescriptor(HID_DIM),),
+                                                        outputs_schema=hivemind.BatchTensorDescriptor(HID_DIM),
+                                                        max_batch_size=16,
+                                                        )
+
+    experts['expert.3'].expert.layers[0].weight.data[0, 0] = float('nan')
+
+    dht = hivemind.DHT(start=True, expiration=999)
+    server = hivemind.Server(dht, experts, num_connection_handlers=1)
+    server.start()
+    try:
+        server.ready.wait()
+
+        dmoe = hivemind.RemoteMixtureOfExperts(in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix='expert.',
+                                               detect_anomalies=True)
+
+        input = torch.randn(1, 16)
+        input[0, 0] = float('nan')
+
+        with pytest.raises(ValueError):
+            dmoe(input)
+
+        input[0, 0] = 0
+        output = dmoe(input)
+
+        inf_loss = float('inf') * output.sum()
+        with pytest.raises(ValueError):
+            inf_loss.backward()
+
+        dmoe = hivemind.RemoteMixtureOfExperts(in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix='expert.',
+                                               detect_anomalies=True)
+        output = dmoe(input)
+        assert output.isfinite().all()
+
+
+    finally:
+        server.shutdown()