
Add tool for custom user experts (#189)

- Implement a decorator (register_expert_class) for user-defined expert modules
- Add a command-line argument to run_server.py that takes the path to a file with user-defined modules
- Move all existing layers onto this API
- Add tests; they are currently marked as skipped due to a bug in background_server

Co-authored-by: justheuristic <justheuristic@gmail.com>
romakail 4 years ago
parent
commit
ca5c7610ae
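
For context, a file registered through the new API would look roughly like the sketch below (a minimal example of the decorator's intended use; the file name my_experts.py and the names 'my_ffn' / MyFeedforward are illustrative, not part of this commit):

    import torch
    import torch.nn as nn

    from hivemind.server.layers.custom_experts import register_expert_class

    # sample_input builds a dummy batch of the shape this expert consumes;
    # the server uses it when setting up the expert backend.
    sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))

    @register_expert_class('my_ffn', sample_input)
    class MyFeedforward(nn.Module):
        def __init__(self, hid_dim):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(hid_dim, 4 * hid_dim), nn.ReLU(), nn.Linear(4 * hid_dim, hid_dim))

        def forward(self, x):
            return self.layers(x)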

+ 3 - 0
hivemind/hivemind_cli/run_server.py

@@ -56,6 +56,9 @@ def main():
     parser.add_argument('--stats_report_interval', type=int, required=False,
                         help='Interval between two reports of batch processing performance statistics')
 
+    parser.add_argument('--custom_module_path', type=str, required=False,
+                        help='Path to a file with custom nn.Module classes, each wrapped in the register_expert_class decorator')
+
     # fmt:on
     args = vars(parser.parse_args())
     args.pop('config', None)
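
With this flag, the server imports the user file before validating the expert class (see the change to hivemind/server/__init__.py below), so a custom expert can be served directly, e.g. python hivemind/hivemind_cli/run_server.py --custom_module_path my_experts.py --expert_cls my_ffn (an illustrative invocation using the names from the sketch above).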

+ 7 - 3
hivemind/server/__init__.py

@@ -17,7 +17,8 @@ from hivemind.server.checkpoints import CheckpointSaver, load_experts, is_direct
 from hivemind.server.connection_handler import ConnectionHandler
 from hivemind.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
 from hivemind.server.expert_backend import ExpertBackend
-from hivemind.server.layers import name_to_block, name_to_input, schedule_name_to_scheduler
+from hivemind.server.layers import name_to_block, name_to_input
+from hivemind.server.layers import add_custom_models_from_file, schedule_name_to_scheduler
 from hivemind.server.runtime import Runtime
 from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
 from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
@@ -72,8 +73,8 @@ class Server(threading.Thread):
                expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
                num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, max_batch_size=4096,
                device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
-               compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, *, start: bool,
-               **kwargs) -> Server:
+               compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, custom_module_path=None,
+               *, start: bool, **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -109,6 +110,9 @@ class Server(threading.Thread):
         :param start: if True, starts server right away and returns when server is ready for requests
         :param stats_report_interval: interval between two reports of batch processing performance statistics
         """
+        if custom_module_path is not None:
+            add_custom_models_from_file(custom_module_path)
+
         if len(kwargs) != 0:
             logger.info("Ignored kwargs:", kwargs)
         assert expert_cls in name_to_block
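
The same hook is available programmatically; a minimal sketch (assuming the factory shown in this hunk is Server.create and that it accepts num_experts as the tests below do via background_server):

    import hivemind

    # custom_module_path is imported before the expert_cls assertion above runs,
    # so 'perceptron' from tests/custom_networks.py is already in name_to_block.
    server = hivemind.Server.create(
        expert_cls='perceptron', hidden_dim=16, num_experts=1, num_handlers=1,
        device='cpu', no_dht=True, start=True,
        custom_module_path='tests/custom_networks.py')
    server.shutdown()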

+ 3 - 0
hivemind/server/expert_backend.py

@@ -96,6 +96,9 @@ class ExpertBackend:
         """
         args, kwargs = nested_pack(inputs, structure=self.forward_schema)
 
+        if args[0].shape[0] == 0:
+            raise RuntimeError("Batch should contain at least one sample")
+
         with torch.no_grad():
             outputs = self.expert(*args, **kwargs)
 

+ 6 - 13
hivemind/server/layers/__init__.py

@@ -1,19 +1,12 @@
 import torch
 
-from hivemind.server.layers.common import FeedforwardBlock, TransformerEncoderLayer, NopExpert
-from hivemind.server.layers.dropout import DeterministicDropout, DeterministicDropoutNetwork
-from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup
+name_to_block = {}
+name_to_input = {}
 
-name_to_block = {'ffn': lambda hid_dim: FeedforwardBlock(hid_dim),
-                 'transformer': lambda hid_dim: TransformerEncoderLayer(hid_dim, dim_feedforward=4 * hid_dim, nhead=16),
-                 'nop': lambda hid_dim: NopExpert(hid_dim),
-                 'det_dropout': lambda hid_dim: DeterministicDropoutNetwork(hid_dim, dropout_prob=0.2)}
+from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup
+from hivemind.server.layers.custom_experts import add_custom_models_from_file, register_expert_class
 
-name_to_input = {'ffn': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
-                 'transformer': lambda batch_size, hid_dim:
-                 (torch.empty((batch_size, 128, hid_dim)), torch.empty((batch_size, 128), dtype=torch.bool)),
-                 'nop': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
-                 'det_dropout': lambda batch_size, hid_dim:
-                 (torch.empty((batch_size, hid_dim)), torch.randint(0, 1, (batch_size, hid_dim)))}
+import hivemind.server.layers.common
+import hivemind.server.layers.dropout
 
 schedule_name_to_scheduler = {'linear': get_linear_schedule_with_warmup, 'none': None}
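
Note the import order above: name_to_block and name_to_input are defined before custom_experts is imported, so the circular import (custom_experts pulls these dicts back out of this package) resolves cleanly, and importing common and dropout afterwards populates the registries through their decorators. A quick sanity check (illustrative, not part of the commit):

    from hivemind.server.layers import name_to_block, name_to_input

    # the built-in experts were registered as a side effect of the imports above
    assert 'ffn' in name_to_block and 'det_dropout' in name_to_input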

+ 18 - 0
hivemind/server/layers/common.py

@@ -1,6 +1,8 @@
 import torch
 from torch import nn as nn
 
+from hivemind.server.layers.custom_experts import register_expert_class
+
 
 # https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py
 @torch.jit.script
@@ -8,7 +10,10 @@ def gelu_fast(x):
     return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
 
 
+ffn_sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))
+@register_expert_class('ffn', ffn_sample_input)
 class FeedforwardBlock(nn.Module):
+
     def __init__(self, hid_dim):
         super().__init__()
         self.ffn = nn.Linear(hid_dim, 4 * hid_dim)
@@ -58,7 +63,20 @@ class TransformerEncoderLayer(nn.Module):
         return src
 
 
+transformer_sample_input = lambda batch_size, hid_dim: (
+    torch.empty((batch_size, 128, hid_dim)),
+    torch.empty((batch_size, 128), dtype=torch.bool))
+@register_expert_class('transformer', transformer_sample_input)
+class TunedTransformer(TransformerEncoderLayer):
+
+    def __init__(self, hid_dim):
+        super().__init__(hid_dim, dim_feedforward=4 * hid_dim, nhead=16)
+
+
+nop_sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))
+@register_expert_class('nop', nop_sample_input)
 class NopExpert(nn.Sequential):
+
     def __init__(self, hid_dim):
         super().__init__()
         self.w = nn.Parameter(torch.zeros(0), requires_grad=True)

+ 34 - 0
hivemind/server/layers/custom_experts.py

@@ -0,0 +1,34 @@
+import os
+import importlib.util
+from typing import Callable, Type
+
+import torch
+import torch.nn as nn
+
+from hivemind.server.layers import name_to_block, name_to_input
+
+
+def add_custom_models_from_file(path: str):
+    spec = importlib.util.spec_from_file_location(
+        "custom_module", os.path.abspath(path))
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+
+
+def register_expert_class(name: str, sample_input: Callable[[int, int], torch.Tensor]):
+    """
+    Adds a custom user expert to a hivemind server.
+    :param name: the name of the expert. It must not coincide with the names of existing modules\
+        ('ffn', 'transformer', 'nop', 'det_dropout')
+    :param sample_input: a function that takes batch_size and hid_dim and returns a \
+        sample input batch for the module
+    :returns: the decorated class, unchanged
+    """
+    def _register_expert_class(custom_class: Type[nn.Module]):
+        if name in name_to_block or name in name_to_input:
+            raise RuntimeError(f"An expert with name '{name}' is already registered")
+        name_to_block[name] = custom_class
+        name_to_input[name] = sample_input
+
+        return custom_class
+    return _register_expert_class
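
Since registration happens when the decorator is applied, a name collision surfaces at class-definition time. A hypothetical test of that failure mode (not part of this commit):

    import pytest
    import torch
    import torch.nn as nn

    from hivemind.server.layers.custom_experts import register_expert_class

    def test_duplicate_name_rejected():
        with pytest.raises(RuntimeError):
            # 'ffn' is already taken by the built-in feedforward expert
            @register_expert_class('ffn', lambda bs, hd: torch.empty((bs, hd)))
            class DuplicateFFN(nn.Module):
                pass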

+ 8 - 2
hivemind/server/layers/dropout.py

@@ -1,8 +1,11 @@
 import torch.autograd
 from torch import nn as nn
 
+from hivemind.server.layers.custom_experts import register_expert_class
+
 
 class DeterministicDropoutFunction(torch.autograd.Function):
+
     @staticmethod
     def forward(ctx, x, keep_prob, mask):
         ctx.keep_prob = keep_prob
@@ -30,9 +33,12 @@ class DeterministicDropout(nn.Module):
         else:
             return x
 
-
+dropout_sample_input = lambda batch_size, hid_dim: \
+    (torch.empty((batch_size, hid_dim)), torch.randint(0, 1, (batch_size, hid_dim)))
+@register_expert_class('det_dropout', dropout_sample_input)
 class DeterministicDropoutNetwork(nn.Module):
-    def __init__(self, hid_dim, dropout_prob):
+
+    def __init__(self, hid_dim, dropout_prob=0.2):
         super().__init__()
         self.linear_in = nn.Linear(hid_dim, 2 * hid_dim)
         self.activation = nn.ReLU()

+ 36 - 0
tests/custom_networks.py

@@ -0,0 +1,36 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from hivemind.server.layers.custom_experts import register_expert_class
+
+sample_input = lambda batch_size, hidden_dim: torch.empty((batch_size, hidden_dim))
+@register_expert_class('perceptron', sample_input)
+class MultilayerPerceptron(nn.Module):
+    def __init__(self, hidden_dim, num_classes=10):
+        super(MultilayerPerceptron, self).__init__()
+        self.layer1 = nn.Linear(hidden_dim, 2 * hidden_dim)
+        self.layer2 = nn.Linear(2 * hidden_dim, 2 * hidden_dim)
+        self.layer3 = nn.Linear(2 * hidden_dim, num_classes)
+
+    def forward(self, x):
+        x = F.relu(self.layer1(x))
+        x = F.relu(self.layer2(x))
+        x = self.layer3(x)
+        return x
+
+multihead_sample_input = lambda batch_size, hidden_dim: (
+    torch.empty((batch_size, hidden_dim)),
+    torch.empty((batch_size, 2 * hidden_dim)),
+    torch.empty((batch_size, 3 * hidden_dim)))
+@register_expert_class('multihead', multihead_sample_input)
+class MultiheadNetwork(nn.Module):
+    def __init__(self, hidden_dim, num_classes=10):
+        super(MultiheadNetwork, self).__init__()
+        self.layer1 = nn.Linear(hidden_dim, num_classes)
+        self.layer2 = nn.Linear(2 * hidden_dim, num_classes)
+        self.layer3 = nn.Linear(3 * hidden_dim, num_classes)
+
+    def forward(self, x1, x2, x3):
+        x = self.layer1(x1) + self.layer2(x2) + self.layer3(x3)
+        return x

+ 50 - 0
tests/test_custom_expert.py

@@ -0,0 +1,50 @@
+import os
+import pytest
+from typing import Optional
+
+import torch
+
+import hivemind
+from hivemind import RemoteExpert, background_server
+
+@pytest.mark.forked
+def test_custom_expert(port: Optional[int] = None, hid_dim=16):
+    with background_server(
+        expert_cls='perceptron', num_experts=2, device='cpu',
+        hidden_dim=hid_dim, num_handlers=2, no_dht=True,
+        custom_module_path=os.path.join(os.path.dirname(__file__), 'custom_networks.py')) as (server_endpoint, _):
+
+        expert0 = RemoteExpert('expert.0', server_endpoint)
+        expert1 = RemoteExpert('expert.1', server_endpoint)
+
+        for batch_size in (1, 4):
+            batch = torch.randn(batch_size, hid_dim)
+
+            output0 = expert0(batch)
+            output1 = expert1(batch)
+
+            loss = output0.sum()
+            loss.backward()
+            loss = output1.sum()
+            loss.backward()
+
+@pytest.mark.forked
+def test_multihead_expert(port: Optional[int] = None, hid_dim=16):
+    with background_server(
+        expert_cls='multihead', num_experts=2, device='cpu',
+        hidden_dim=hid_dim, num_handlers=2, no_dht=True,
+        custom_module_path=os.path.join(os.path.dirname(__file__), 'custom_networks.py')) as (server_endpoint, _):
+
+        expert0 = RemoteExpert('expert.0', server_endpoint)
+        expert1 = RemoteExpert('expert.1', server_endpoint)
+
+        for batch_size in (1, 4):
+            batch = (torch.randn(batch_size, hid_dim), torch.randn(batch_size, 2 * hid_dim), torch.randn(batch_size, 3 * hid_dim))
+
+            output0 = expert0(*batch)
+            output1 = expert1(*batch)
+
+            loss = output0.sum()
+            loss.backward()
+            loss = output1.sum()
+            loss.backward()
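
Both tests exercise the full path: the custom module is loaded through custom_module_path, the registered experts are served, and gradients flow back through RemoteExpert. They can be run with pytest tests/test_custom_expert.py once the background_server issue noted in the commit message is resolved.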