
Deterministic dropout layer (#31)

* Add DeterministicDropout with tests

* Enable determinism tests

* Create server before shutting down

* Revert changes

* Save mask for backward

* Save mask for backward

* Disable torch.jit (does not work for custom modules)

* Custom inputs for experts

* Custom inputs for experts

* Custom inputs for experts

* Check that CUDA is available before pinning

* Move dropout before first linear

* Move dropout before first linear

* Set requires_grad only for floating point inputs

* Fix missing commas in test

* Verify that the test can fail

* Revert "Verify that the test can fail"

This reverts commit bab56437a5c39221c99ab06f14be5ff122e036dc.

* Add docstring for DeterministicDropout

* Reflect current state of affairs with randomness in docstrings

* More verbosity in docstrings

* Move TODOs to comments
Max Ryabinin, 5 years ago
parent commit 9573455c99

+ 10 - 5
hivemind/runtime/expert_backend.py

@@ -21,6 +21,9 @@ class ExpertBackend(nn.Module):
         - Experts must always receive the same set of \*args and \*\*kwargs and produce output tensors of the same type
         - All \*args, \*\*kwargs and outputs must be **tensors** where the 0-th dimension represents the batch size
         - We recommend using experts that are ~invariant to the order in which they process batches
+        - Using randomness (e.g. Dropout) leads to different random samples being drawn at forward and backward. If you want to ensure consistency,
+            you should explicitly register these random variables as model outputs, so that they are sent back to the client.
+            See hivemind.utils.custom_layers.DeterministicDropout for an example.
 
     :param opt: torch optimizer to be applied on every backward call
     :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
@@ -65,7 +68,8 @@ class ExpertBackend(nn.Module):
 
            It should return gradients w.r.t. inputs that follow ``nested_flatten(self.outputs_schema)``;
 
-           .. todo state - we recommend stateless but you can save state if you want. disable batchnorm track running stats
+           .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
+           .. For now, either register all buffers as outputs or avoid stateful experts
 
         """
         args, kwargs = nested_pack(inputs, structure=self.forward_schema)
@@ -89,15 +93,17 @@ class ExpertBackend(nn.Module):
            Runtime doesn't guarantee that backward will be performed in the same order and for the same data
            as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward.
 
-           .. todo state, randomness, etc
+           .. todo correct state handling (see forward)
 
            Please make sure to call ``ExpertBackend.apply_gradients`` **within** this method, otherwise the expert will not train
         """
         (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
 
         with torch.enable_grad():
-            args = [tensor.detach().requires_grad_(True) for tensor in args]
-            kwargs = {input_key: tensor.detach().requires_grad_(True) for input_key, tensor in kwargs.items()}
+            args = [tensor.detach().requires_grad_(True) if tensor.dtype in (torch.half, torch.float, torch.double)
+                    else tensor.detach() for tensor in args]
+            kwargs = {input_key: (tensor.detach().requires_grad_(True) if tensor.dtype in (torch.half, torch.float, torch.double)
+                                  else tensor.detach()) for input_key, tensor in kwargs.items()}
 
             outputs = self.expert(*args, **kwargs)
             assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"
@@ -129,4 +135,3 @@ class ExpertBackend(nn.Module):
     def get_pools(self) -> Sequence[TaskPool]:
         """ return all pools that should be processed by ``Runtime`` """
         return self.forward_pool, self.backward_pool
-
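
A minimal client-side sketch of the convention described above (the uid, port, and tensor shapes are illustrative; they mirror the determinism test added below): the client samples the dropout mask once and passes it as an explicit input, so the expert applies the same randomness in its forward pass and in the forward re-run inside backward.

import torch
import hivemind

x = torch.randn(32, 1024, requires_grad=True)
# keep probability 0.8 corresponds to drop_prob=0.2; the mask is sampled once, on the client
mask = torch.bernoulli(torch.full((32, 1024), 0.8)).to(torch.long)

expert = hivemind.RemoteExpert(uid='expert.0', port=8080)  # placeholder uid/port
out = expert(x, mask)   # the mask travels with the request as a regular input
out.sum().backward()    # the backward call reuses the same mask on the expert side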

+ 31 - 0
hivemind/utils/custom_layers.py

@@ -0,0 +1,31 @@
+import torch.autograd
+import torch.nn as nn
+
+
+class DeterministicDropoutFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, keep_prob, mask):
+        ctx.keep_prob = keep_prob
+        ctx.save_for_backward(mask)
+        return x * mask / keep_prob
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return ctx.saved_tensors[0] * grad_output / ctx.keep_prob, None, None
+
+
+class DeterministicDropout(nn.Module):
+    """
+    Custom dropout layer which accepts the dropout mask as an input (drop_prob is only used to scale the kept activations).
+    Can be used with RemoteExpert/ExpertBackend to ensure that the dropout mask is the same at the forward and backward steps.
+    """
+
+    def __init__(self, drop_prob):
+        super().__init__()
+        self.keep_prob = 1 - drop_prob
+
+    def forward(self, x, mask):
+        if self.training:
+            return DeterministicDropoutFunction.apply(x, self.keep_prob, mask)
+        else:
+            return x
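
A minimal local sketch of how the layer is meant to be used (assuming hivemind is importable): since the caller supplies the mask, repeating the call with the same mask reproduces both the output and the gradient exactly.

import torch
from hivemind.utils.custom_layers import DeterministicDropout

layer = DeterministicDropout(drop_prob=0.2)
x = torch.randn(8, 16, requires_grad=True)
mask = torch.bernoulli(torch.full((8, 16), 0.8))  # keep_prob = 1 - drop_prob

out_a, out_b = layer(x, mask), layer(x, mask)
grad_a, = torch.autograd.grad(out_a.sum(), x, retain_graph=True)
grad_b, = torch.autograd.grad(out_b.sum(), x)
assert torch.allclose(out_a, out_b) and torch.allclose(grad_a, grad_b)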

+ 1 - 1
hivemind/utils/proto.py

@@ -45,7 +45,7 @@ class BatchTensorProto(TensorProto):
     @classmethod
     def from_tensor(cls, tensor: torch.Tensor):
         return cls(*tensor.shape[1:], dtype=tensor.dtype, layout=tensor.layout,
-                   device=tensor.device, requires_grad=tensor.requires_grad, pin_memory=tensor.is_pinned())
+                   device=tensor.device, requires_grad=tensor.requires_grad, pin_memory=torch.cuda.is_available() and tensor.is_pinned())
 
     def make_empty(self, batch_size, **kwargs):
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
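
The guard above avoids querying pinned-memory status on CPU-only setups. A sketch of the same pattern in isolation (the helper name is illustrative, not part of the library):

import torch

def safe_is_pinned(tensor: torch.Tensor) -> bool:
    # Without CUDA there is no pinned memory, so report False instead of asking the CUDA backend
    return torch.cuda.is_available() and tensor.is_pinned()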

+ 23 - 1
tests/test_moe.py

@@ -43,6 +43,27 @@ def test_remote_module_call():
     assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol, atol), "incorrect gradient w.r.t. logits"
 
 
+def test_determinism():
+    rtol = 0
+    atol = 1e-6
+
+    xx = torch.randn(32, 1024, requires_grad=True)
+    mask = torch.randint(0, 2, (32, 1024))  # upper bound is exclusive: (0, 2) yields a proper 0/1 mask
+
+    with background_server(num_experts=1, device='cpu', expert_cls='det_dropout',
+                           no_optimizer=True, no_dht=True) as (localhost, server_port, dht_port):
+        expert = hivemind.RemoteExpert(uid=f'expert.0', port=server_port)
+
+        out = expert(xx, mask)
+        out_rerun = expert(xx, mask)
+
+        grad, = torch.autograd.grad(out.sum(), xx, retain_graph=True)
+        grad_rerun, = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)
+
+    assert torch.allclose(out, out_rerun, rtol, atol), "Dropout layer outputs are non-deterministic."
+    assert torch.allclose(grad, grad_rerun, rtol, atol), "Gradients are non-deterministic."
+
+
 def test_compute_expert_scores():
     try:
         dht = hivemind.DHTNode(port=hivemind.find_open_port(), start=True)
@@ -69,4 +90,5 @@ def test_compute_expert_scores():
 
 if __name__ == '__main__':
     test_remote_module_call()
-    test_compute_expert_scores()
+    test_compute_expert_scores()
+    test_determinism()

+ 21 - 3
tests/test_utils/layers.py

@@ -1,5 +1,7 @@
 import torch
-from torch import nn as nn
+import torch.nn as nn
+
+from hivemind.utils.custom_layers import DeterministicDropout
 
 
 class FeedforwardBlock(nn.Module):
@@ -60,9 +62,25 @@ class NopExpert(nn.Sequential):
         return x.clone()
 
 
+class DeterministicDropoutNetwork(nn.Module):
+    def __init__(self, hid_dim, dropout_prob):
+        super().__init__()
+        self.linear_in = nn.Linear(hid_dim, 2 * hid_dim)
+        self.activation = nn.ReLU()
+        self.dropout = DeterministicDropout(dropout_prob)
+        self.linear_out = nn.Linear(2 * hid_dim, hid_dim)
+
+    def forward(self, x, mask):
+        x = self.linear_in(self.dropout(x, mask))
+        return self.linear_out(self.activation(x))
+
+
 name_to_block = {'ffn': lambda hid_dim: FeedforwardBlock(hid_dim),
                  'transformer': lambda hid_dim: TransformerEncoderLayer(hid_dim, nhead=16),
-                 'nop': lambda hid_dim: NopExpert(hid_dim)}
+                 'nop': lambda hid_dim: NopExpert(hid_dim),
+                 'det_dropout': lambda hid_dim: DeterministicDropoutNetwork(hid_dim, dropout_prob=0.2)}
 name_to_input = {'ffn': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
                  'transformer': lambda batch_size, hid_dim: torch.empty((batch_size, 512, hid_dim)),
-                 'nop': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))}
+                 'nop': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
+                 'det_dropout': lambda batch_size, hid_dim:
+                 (torch.empty((batch_size, hid_dim)), torch.randint(0, 2, (batch_size, hid_dim)))}

+ 10 - 4
tests/test_utils/run_server.py

@@ -5,7 +5,7 @@ import argparse
 
 import torch
 import hivemind
-from .layers import name_to_block
+from .layers import name_to_block, name_to_input
 
 
 def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
@@ -27,7 +27,7 @@ def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
             dht_root = hivemind.DHTNode(
                 *initial_peers, port=root_port or hivemind.find_open_port(), start=True)
             print(f"Initializing DHT with port {dht_root.port}")
-            initial_peers = (('localhost', dht_root.port), )
+            initial_peers = (('localhost', dht_root.port),)
         else:
             print("Bootstrapping dht with peers:", initial_peers)
             if root_port is not None:
@@ -38,14 +38,20 @@ def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
         if verbose:
             print(f"Running dht node on port {dht.port}")
 
+    sample_input = name_to_input[expert_cls](4, hidden_dim)
+    if isinstance(sample_input, tuple):
+        args_schema = tuple(hivemind.BatchTensorProto.from_tensor(arg) for arg in sample_input)
+    else:
+        args_schema = (hivemind.BatchTensorProto.from_tensor(sample_input),)
+
     # initialize experts
     experts = {}
     for i in range(num_experts):
-        expert = torch.jit.script(name_to_block[expert_cls](hidden_dim))
+        expert = name_to_block[expert_cls](hidden_dim)
         opt = torch.optim.SGD(expert.parameters(), 0.0) if no_optimizer else torch.optim.Adam(expert.parameters())
         expert_uid = f'{expert_prefix}{UID_DELIMETER}{i + expert_offset}'
         experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
-                                                     args_schema=(hivemind.BatchTensorProto(hidden_dim),),
+                                                     args_schema=args_schema,
                                                      outputs_schema=hivemind.BatchTensorProto(hidden_dim),
                                                      max_batch_size=max_batch_size,
                                                      )
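
With these changes, a two-argument expert such as 'det_dropout' gets one BatchTensorProto per argument. A sketch of the resulting schema, assuming hidden_dim=1024 and the sample inputs defined in layers.py (only the floating-point activations will receive requires_grad during backward):

import torch
import hivemind

sample_input = (torch.empty((4, 1024)), torch.randint(0, 2, (4, 1024)))
args_schema = tuple(hivemind.BatchTensorProto.from_tensor(arg) for arg in sample_input)
# args_schema[0] describes the float32 activations, args_schema[1] the integer dropout mask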