
Deterministic dropout layer (#31)

* Add DeterministicDropout with tests

* Enable determinism tests

* Create server before shutting down

* Revert changes

* Save mask for backward

* Save mask for backward

* Disable torch.jit (does not work for custom modules)

* Custom inputs for experts

* Custom inputs for experts

* Custom inputs for experts

* Check that CUDA is available before pinning

* Move dropout before first linear

* Move dropout before first linear

* Set requires_grad only for floating point inputs

* Fix missing commas in test

* Verify that the test can fail

* Revert "Verify that the test can fail"

This reverts commit bab56437a5c39221c99ab06f14be5ff122e036dc.

* Add docstring for DeterministicDropout

* Reflect current state of affairs with randomness in docstrings

* More verbosity in docstrings

* Move TODOs to comments
Max Ryabinin, 5 years ago
parent commit 9573455c99

+ 10 - 5
hivemind/runtime/expert_backend.py

@@ -21,6 +21,9 @@ class ExpertBackend(nn.Module):
        - Experts must always receive the same set of \*args and \*\*kwargs and produce output tensors of same type
        - All \*args, \*\*kwargs and outputs must be **tensors** where 0-th dimension represents to batch size
        - We recommend using experts that are ~invariant to the order in which they process batches
+        - Using randomness (e.g. Dropout) leads to different samples on the forward and backward passes. If you want to ensure consistency,
+            explicitly register these random variables as model outputs so that they are sent back to the client.
+            See hivemind.utils.custom_layers.DeterministicDropout for an example.
 
 
     :param opt: torch optimizer to be applied on every backward call
     :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
@@ -65,7 +68,8 @@ class ExpertBackend(nn.Module):
 
 
           It should return gradients w.r.t. inputs that follow ``nested_flatten(self.outputs_schema)``;
 
 
-           .. todo state - we recommend stateless but you can save state if you want. disable batchnorm track running stats
+           .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
+           .. For now, either register all buffers as outputs or avoid stateful experts
 
 
         """
         """
         args, kwargs = nested_pack(inputs, structure=self.forward_schema)
         args, kwargs = nested_pack(inputs, structure=self.forward_schema)
@@ -89,15 +93,17 @@ class ExpertBackend(nn.Module):
           Runtime doesn't guarantee that backward will be performed in the same order and for the same data
           as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward.
 
 
-           .. todo state, randomness, etc
+           .. todo correct state handling (see forward)
 
 
           Please make sure to call ``ExpertBackend.apply_gradients`` **within** this method, otherwise the expert will not train
        """
        (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
 
 
        with torch.enable_grad():
-            args = [tensor.detach().requires_grad_(True) for tensor in args]
-            kwargs = {input_key: tensor.detach().requires_grad_(True) for input_key, tensor in kwargs.items()}
+            args = [tensor.detach().requires_grad_(True) if tensor.dtype in (torch.half, torch.float, torch.double)
+                    else tensor.detach() for tensor in args]
+            kwargs = {input_key: (tensor.detach().requires_grad_(True) if tensor.dtype in (torch.half, torch.float, torch.double)
+                                  else tensor.detach()) for input_key, tensor in kwargs.items()}
 
 
            outputs = self.expert(*args, **kwargs)
            assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"
@@ -129,4 +135,3 @@ class ExpertBackend(nn.Module):
    def get_pools(self) -> Sequence[TaskPool]:
        """ return all pools that should be processed by ``Runtime`` """
        return self.forward_pool, self.backward_pool
-
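The dtype check introduced in backward above matters because requires_grad_(True) is only legal for floating-point tensors; integer inputs such as a dropout mask must be passed through detached. A minimal standalone sketch of the same pattern (detach_and_require_grads is a hypothetical helper for illustration, not part of hivemind):

    import torch

    def detach_and_require_grads(tensors):
        # Request gradients only for floating-point tensors; integer inputs
        # (e.g. a dropout mask) are merely detached, mirroring ExpertBackend.backward.
        float_dtypes = (torch.half, torch.float, torch.double)
        return [t.detach().requires_grad_(True) if t.dtype in float_dtypes else t.detach()
                for t in tensors]

    x, mask = torch.randn(4, 8), torch.randint(0, 2, (4, 8))
    x_ready, mask_ready = detach_and_require_grads([x, mask])
    assert x_ready.requires_grad and not mask_ready.requires_grad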

+ 31 - 0
hivemind/utils/custom_layers.py

@@ -0,0 +1,31 @@
+import torch.autograd
+import torch.nn as nn
+
+
+class DeterministicDropoutFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, keep_prob, mask):
+        ctx.keep_prob = keep_prob
+        ctx.save_for_backward(mask)
+        return x * mask / keep_prob
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return ctx.saved_tensors[0] * grad_output / ctx.keep_prob, None, None
+
+
+class DeterministicDropout(nn.Module):
+    """
+    Custom dropout layer which accepts dropout mask as an input (drop_prob is only used for scaling input activations).
+    Can be used with RemoteExpert/ExpertBackend to ensure that dropout mask is the same at forward and backward steps
+    """
+
+    def __init__(self, drop_prob):
+        super().__init__()
+        self.keep_prob = 1 - drop_prob
+
+    def forward(self, x, mask):
+        if self.training:
+            return DeterministicDropoutFunction.apply(x, self.keep_prob, mask)
+        else:
+            return x
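For context, a minimal usage sketch of the layer added above: the keep-mask is sampled once by the caller and passed in explicitly, so repeating the call reproduces exactly the same activations (and hence the same backward pass). The mask construction below is illustrative; the PR's own test builds it with torch.randint instead.

    import torch
    from hivemind.utils.custom_layers import DeterministicDropout

    layer = DeterministicDropout(drop_prob=0.2)
    x = torch.randn(4, 16, requires_grad=True)
    # Bernoulli keep-mask fixed up front: 1 keeps a unit, 0 drops it.
    mask = (torch.rand(4, 16) > 0.2).to(torch.float32)

    out_first = layer(x, mask)
    out_again = layer(x, mask)
    assert torch.equal(out_first, out_again)  # same mask -> identical outputs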

+ 1 - 1
hivemind/utils/proto.py

@@ -45,7 +45,7 @@ class BatchTensorProto(TensorProto):
    @classmethod
    def from_tensor(cls, tensor: torch.Tensor):
        return cls(*tensor.shape[1:], dtype=tensor.dtype, layout=tensor.layout,
-                   device=tensor.device, requires_grad=tensor.requires_grad, pin_memory=tensor.is_pinned())
+                   device=tensor.device, requires_grad=tensor.requires_grad, pin_memory=torch.cuda.is_available() and tensor.is_pinned())
 
 
    def make_empty(self, batch_size, **kwargs):
        assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
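The added torch.cuda.is_available() guard mirrors the "Check that CUDA is available before pinning" commit above: pinned memory is only meaningful with a CUDA runtime, and querying it on CUDA-less hosts can fail. A hypothetical standalone version of the same guard:

    import torch

    def safe_is_pinned(tensor: torch.Tensor) -> bool:
        # Only query pinned status when a CUDA runtime is present.
        return torch.cuda.is_available() and tensor.is_pinned()

    print(safe_is_pinned(torch.zeros(2, 2)))  # False on a CPU-only machine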

+ 23 - 1
tests/test_moe.py

@@ -43,6 +43,27 @@ def test_remote_module_call():
    assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol, atol), "incorrect gradient w.r.t. logits"
 
 
 
 
+def test_determinism():
+    rtol = 0
+    atol = 1e-6
+
+    xx = torch.randn(32, 1024, requires_grad=True)
+    mask = torch.randint(0, 1, (32, 1024))
+
+    with background_server(num_experts=1, device='cpu', expert_cls='det_dropout',
+                           no_optimizer=True, no_dht=True) as (localhost, server_port, dht_port):
+        expert = hivemind.RemoteExpert(uid=f'expert.0', port=server_port)
+
+        out = expert(xx, mask)
+        out_rerun = expert(xx, mask)
+
+        grad, = torch.autograd.grad(out.sum(), xx, retain_graph=True)
+        grad_rerun, = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)
+
+    assert torch.allclose(out, out_rerun, rtol, atol), "Dropout layer outputs are non-deterministic."
+    assert torch.allclose(grad, grad_rerun, rtol, atol), "Gradients are non-deterministic."
+
+
def test_compute_expert_scores():
    try:
        dht = hivemind.DHTNode(port=hivemind.find_open_port(), start=True)
@@ -69,4 +90,5 @@ def test_compute_expert_scores():
 
 
if __name__ == '__main__':
    test_remote_module_call()
-    test_compute_expert_scores()
+    test_compute_expert_scores()
+    test_determinism()

+ 21 - 3
tests/test_utils/layers.py

@@ -1,5 +1,7 @@
 import torch
-from torch import nn as nn
+import torch.nn as nn
+
+from hivemind.utils.custom_layers import DeterministicDropout
 
 
 
 
 class FeedforwardBlock(nn.Module):
@@ -60,9 +62,25 @@ class NopExpert(nn.Sequential):
         return x.clone()
 
 
 
 
+class DeterministicDropoutNetwork(nn.Module):
+    def __init__(self, hid_dim, dropout_prob):
+        super().__init__()
+        self.linear_in = nn.Linear(hid_dim, 2 * hid_dim)
+        self.activation = nn.ReLU()
+        self.dropout = DeterministicDropout(dropout_prob)
+        self.linear_out = nn.Linear(2 * hid_dim, hid_dim)
+
+    def forward(self, x, mask):
+        x = self.linear_in(self.dropout(x, mask))
+        return self.linear_out(self.activation(x))
+
+
 name_to_block = {'ffn': lambda hid_dim: FeedforwardBlock(hid_dim),
                  'transformer': lambda hid_dim: TransformerEncoderLayer(hid_dim, nhead=16),
-                 'nop': lambda hid_dim: NopExpert(hid_dim)}
+                 'nop': lambda hid_dim: NopExpert(hid_dim),
+                 'det_dropout': lambda hid_dim: DeterministicDropoutNetwork(hid_dim, dropout_prob=0.2)}
 name_to_input = {'ffn': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
                  'transformer': lambda batch_size, hid_dim: torch.empty((batch_size, 512, hid_dim)),
-                 'nop': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))}
+                 'nop': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
+                 'det_dropout': lambda batch_size, hid_dim:
+                 (torch.empty((batch_size, hid_dim)), torch.randint(0, 1, (batch_size, hid_dim)))}
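A quick local sanity sketch of the new test expert (no server involved): the network takes the dropout mask as a second input, so identical (x, mask) pairs give identical outputs. The import path assumes the tests directory is importable as a package; the mask uses an upper bound of 2 because torch.randint's upper bound is exclusive.

    import torch
    from tests.test_utils.layers import DeterministicDropoutNetwork  # assumed import path

    net = DeterministicDropoutNetwork(hid_dim=1024, dropout_prob=0.2)
    x = torch.randn(2, 1024)
    mask = torch.randint(0, 2, (2, 1024)).float()

    assert torch.equal(net(x, mask), net(x, mask))  # same inputs and mask -> same output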

+ 10 - 4
tests/test_utils/run_server.py

@@ -5,7 +5,7 @@ import argparse
 
 
 import torch
 import hivemind
-from .layers import name_to_block
+from .layers import name_to_block, name_to_input
 
 
 
 
 def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
@@ -27,7 +27,7 @@ def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
             dht_root = hivemind.DHTNode(
                 *initial_peers, port=root_port or hivemind.find_open_port(), start=True)
             print(f"Initializing DHT with port {dht_root.port}")
-            initial_peers = (('localhost', dht_root.port), )
+            initial_peers = (('localhost', dht_root.port),)
         else:
             print("Bootstrapping dht with peers:", initial_peers)
             if root_port is not None:
@@ -38,14 +38,20 @@ def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
         if verbose:
             print(f"Running dht node on port {dht.port}")
 
 
+    sample_input = name_to_input[expert_cls](4, hidden_dim)
+    if isinstance(sample_input, tuple):
+        args_schema = tuple(hivemind.BatchTensorProto.from_tensor(arg) for arg in sample_input)
+    else:
+        args_schema = (hivemind.BatchTensorProto.from_tensor(sample_input),)
+
     # initialize experts
     experts = {}
     for i in range(num_experts):
-        expert = torch.jit.script(name_to_block[expert_cls](hidden_dim))
+        expert = name_to_block[expert_cls](hidden_dim)
         opt = torch.optim.SGD(expert.parameters(), 0.0) if no_optimizer else torch.optim.Adam(expert.parameters())
         expert_uid = f'{expert_prefix}{UID_DELIMETER}{i + expert_offset}'
         experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
-                                                     args_schema=(hivemind.BatchTensorProto(hidden_dim),),
+                                                     args_schema=args_schema,
                                                      outputs_schema=hivemind.BatchTensorProto(hidden_dim),
                                                      max_batch_size=max_batch_size,
                                                      )
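To illustrate the schema inference added above: when name_to_input returns a tuple (as 'det_dropout' does), each sample tensor becomes its own BatchTensorProto, so the backend accepts several positional inputs. A condensed sketch with illustrative shapes:

    import torch
    import hivemind

    # A sample batch for a two-input expert: activations plus an integer dropout mask.
    sample_input = (torch.empty(4, 1024), torch.randint(0, 2, (4, 1024)))

    args_schema = tuple(hivemind.BatchTensorProto.from_tensor(arg) for arg in sample_input)
    assert len(args_schema) == 2  # one schema entry per expert argument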