Start the server with a terminal command, fix default config.yml (#108)

* Add terminal command for running the server

* Refactor layers, restore snake case for arguments

* Fix default config
Max Ryabinin · committed 4 years ago · commit 70dadfb8b5

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.5'
+__version__ = '0.8.6'

+ 1 - 1
hivemind/dht/__init__.py

@@ -328,7 +328,7 @@ class DHT(mp.Process):
         :returns: a ordered dict{uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts
             The keys in the returned dict are ordered same as in uid_prefixes.
         """
-        logger.warning("first_k_active is deprecated and will be removed in 0.8.6")
+        logger.warning("first_k_active is deprecated and will be removed in 0.8.7")
         assert not isinstance(uid_prefixes, str), "please provide a list/tuple of prefixes as the first argument"
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_first_k_active', [],

+ 11 - 11
hivemind/server/__init__.py

@@ -1,23 +1,24 @@
 from __future__ import annotations
+
 import multiprocessing as mp
 import multiprocessing.synchronize
-import threading
 import random
+import threading
 from contextlib import contextmanager
 from functools import partial
+from typing import Dict, Optional, Tuple, List
 
 import torch
-from typing import Dict, Optional, Tuple, List
 
 import hivemind
 from hivemind.dht import DHT
-from hivemind.server.runtime import Runtime
-from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
-from hivemind.server.expert_backend import ExpertBackend
 from hivemind.server.checkpoint_saver import CheckpointSaver
 from hivemind.server.connection_handler import ConnectionHandler
 from hivemind.server.dht_handler import DHTHandlerThread
+from hivemind.server.expert_backend import ExpertBackend
 from hivemind.server.layers import name_to_block, name_to_input
+from hivemind.server.runtime import Runtime
+from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
 from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
 
 logger = get_logger(__name__)
@@ -66,7 +67,7 @@ class Server(threading.Thread):
 
     @staticmethod
     def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
-               expert_cls='ffn', hidden_dim=1024, Optimizer=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
+               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
                device=None, no_dht=False, initial_peers=(), dht_port=None, verbose=True,
                *, start: bool, **kwargs) -> Server:
         """
@@ -76,12 +77,12 @@ class Server(threading.Thread):
         :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
          means "sample random experts between myprefix.0.0 and myprefix.255.255;
         :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
-        :param expert_cls: expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
+        :param expert_cls: expert type from hivemind.server.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
         :param hidden_dim: main dimension for expert_cls
         :param num_handlers: server will use this many parallel processes to handle incoming requests
         :param max_batch_size: total num examples in the same batch will not exceed this value
         :param device: all experts will use this device in torch notation; default: cuda if available else cpu
-        :param Optimizer: uses this optimizer to train all experts
+        :param optim_cls: uses this optimizer to train all experts
         :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: a list of peers that will introduce this node to the dht,\
          e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers
@@ -112,7 +113,7 @@ class Server(threading.Thread):
 
         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
-        Optimizer = Optimizer if Optimizer is not None else partial(torch.optim.SGD, lr=0.0)
+        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
         device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
 
         sample_input = name_to_input[expert_cls](4, hidden_dim)
@@ -129,7 +130,7 @@ class Server(threading.Thread):
             experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert,
                                                          args_schema=args_schema,
                                                          outputs_schema=hivemind.BatchTensorDescriptor(hidden_dim),
-                                                         opt=Optimizer(expert.parameters()),
+                                                         opt=optim_cls(expert.parameters()),
                                                          max_batch_size=max_batch_size,
                                                          )
         # actually start server
@@ -314,4 +315,3 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
         logger.warning(f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
                        f"{attempts_per_expert * num_experts} attempts")
     return found_uids
-
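
With the keyword renamed from Optimizer to optim_cls, a minimal sketch of calling the refactored API directly could look as follows. The parameter values are illustrative, not defaults; as in the code above, passing optim_cls=None falls back to SGD with lr=0.0.

from functools import partial

import torch

from hivemind.server import Server

server = Server.create(listen_on='0.0.0.0:*', num_experts=4, expert_pattern='expert.[0:4].[0:4]',
                       expert_cls='ffn', hidden_dim=1024,
                       optim_cls=partial(torch.optim.SGD, lr=0.01),
                       no_dht=True, start=True)  # start=True launches the background thread immediately
server.join()  # Server subclasses threading.Thread, so join() blocks until shutdown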

+ 2 - 74
hivemind/server/layers/__init__.py

@@ -1,79 +1,7 @@
 import torch
-import torch.nn as nn
-
-from hivemind.server.layers.dropout import DeterministicDropout
-
-
-class FeedforwardBlock(nn.Module):
-    def __init__(self, hid_dim):
-        super().__init__()
-        self.layers = nn.Sequential(
-            nn.Linear(hid_dim, 4 * hid_dim),
-            nn.LayerNorm(4 * hid_dim),
-            nn.ReLU(inplace=True),
-            nn.Linear(4 * hid_dim, 4 * hid_dim),
-            nn.LayerNorm(4 * hid_dim),
-            nn.ReLU(inplace=True),
-            nn.Linear(4 * hid_dim, hid_dim),
-        )
-
-    def forward(self, x):
-        return x + self.layers(x)
-
-
-class TransformerEncoderLayer(nn.Module):
-    """
-    A slight modification of torch.nn.TransformerEncoderLayer which allows for torch.jit scripting
-    """
-
-    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
-        super().__init__()
-        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
-        # Implementation of Feedforward model
-        self.linear1 = nn.Linear(d_model, dim_feedforward)
-        self.dropout = nn.Dropout(dropout)
-        self.linear2 = nn.Linear(dim_feedforward, d_model)
-
-        self.norm1 = nn.LayerNorm(d_model)
-        self.norm2 = nn.LayerNorm(d_model)
-        self.dropout1 = nn.Dropout(dropout)
-        self.dropout2 = nn.Dropout(dropout)
-
-        self.activation = torch.nn.GELU()
-
-    def forward(self, src):
-        src.transpose_(0, 1)
-        src2 = self.self_attn(src, src, src)[0]
-        src = src + self.dropout1(src2)
-        src = self.norm1(src)
-        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
-        src = src + self.dropout2(src2)
-        src = self.norm2(src)
-        src.transpose_(0, 1)
-        return src
-
-
-class NopExpert(nn.Sequential):
-    def __init__(self, hid_dim):
-        super().__init__()
-        self.w = nn.Parameter(torch.zeros(0), requires_grad=True)
-
-    def forward(self, x):
-        return x.clone()
-
-
-class DeterministicDropoutNetwork(nn.Module):
-    def __init__(self, hid_dim, dropout_prob):
-        super().__init__()
-        self.linear_in = nn.Linear(hid_dim, 2 * hid_dim)
-        self.activation = nn.ReLU()
-        self.dropout = DeterministicDropout(dropout_prob)
-        self.linear_out = nn.Linear(2 * hid_dim, hid_dim)
-
-    def forward(self, x, mask):
-        x = self.linear_in(self.dropout(x, mask))
-        return self.linear_out(self.activation(x))
 
+from hivemind.server.layers.common import FeedforwardBlock, TransformerEncoderLayer, NopExpert
+from hivemind.server.layers.dropout import DeterministicDropout, DeterministicDropoutNetwork
 
 name_to_block = {'ffn': lambda hid_dim: FeedforwardBlock(hid_dim),
                  'transformer': lambda hid_dim: TransformerEncoderLayer(hid_dim, nhead=16),
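
The registry itself is unchanged by the split into common.py and dropout.py: name_to_block maps an expert_cls string to a module constructor, while name_to_input (whose definition is not part of this diff) builds the dummy batch that Server.create uses for schema inference. A rough sketch of the constructor side, with 1024 as an example hidden size:

import torch

from hivemind.server.layers import name_to_block

expert = name_to_block['ffn'](1024)  # FeedforwardBlock with hid_dim=1024
batch = torch.randn(4, 1024)
out = expert(batch)                  # residual feedforward block, out.shape == (4, 1024)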

+ 60 - 0
hivemind/server/layers/common.py

@@ -0,0 +1,60 @@
+import torch
+from torch import nn as nn
+
+
+class FeedforwardBlock(nn.Module):
+    def __init__(self, hid_dim):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(hid_dim, 4 * hid_dim),
+            nn.LayerNorm(4 * hid_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(4 * hid_dim, 4 * hid_dim),
+            nn.LayerNorm(4 * hid_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(4 * hid_dim, hid_dim),
+        )
+
+    def forward(self, x):
+        return x + self.layers(x)
+
+
+class TransformerEncoderLayer(nn.Module):
+    """
+    A slight modification of torch.nn.TransformerEncoderLayer which allows for torch.jit scripting
+    """
+
+    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = torch.nn.GELU()
+
+    def forward(self, src):
+        src.transpose_(0, 1)
+        src2 = self.self_attn(src, src, src)[0]
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        src.transpose_(0, 1)
+        return src
+
+
+class NopExpert(nn.Sequential):
+    def __init__(self, hid_dim):
+        super().__init__()
+        self.w = nn.Parameter(torch.zeros(0), requires_grad=True)
+
+    def forward(self, x):
+        return x.clone()
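
Per its docstring, this block is a slight modification of torch.nn.TransformerEncoderLayer that allows torch.jit scripting. A hedged sketch of that intended use, with sizes matching the 'transformer' entry in name_to_block:

import torch

from hivemind.server.layers.common import TransformerEncoderLayer

layer = TransformerEncoderLayer(d_model=1024, nhead=16)  # same arguments as name_to_block['transformer']
scripted = torch.jit.script(layer)                       # should compile, per the docstring's claim
out = scripted(torch.randn(2, 128, 1024))                # (batch, seq_len, d_model); transposed internally
assert out.shape == (2, 128, 1024)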

+ 14 - 0
hivemind/server/layers/dropout.py

@@ -1,5 +1,6 @@
 import torch.autograd
 import torch.nn as nn
+from torch import nn as nn
 
 
 class DeterministicDropoutFunction(torch.autograd.Function):
@@ -29,3 +30,16 @@ class DeterministicDropout(nn.Module):
             return DeterministicDropoutFunction.apply(x, self.keep_prob, mask)
         else:
             return x
+
+
+class DeterministicDropoutNetwork(nn.Module):
+    def __init__(self, hid_dim, dropout_prob):
+        super().__init__()
+        self.linear_in = nn.Linear(hid_dim, 2 * hid_dim)
+        self.activation = nn.ReLU()
+        self.dropout = DeterministicDropout(dropout_prob)
+        self.linear_out = nn.Linear(2 * hid_dim, hid_dim)
+
+    def forward(self, x, mask):
+        x = self.linear_in(self.dropout(x, mask))
+        return self.linear_out(self.activation(x))
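
DeterministicDropoutNetwork now lives next to the dropout primitive it wraps. A rough usage sketch, assuming the caller supplies a 0/1 mask with the same shape as the input, which is the convention test_determinism in tests/test_moe.py relies on for reproducible outputs:

import torch

from hivemind.server.layers.dropout import DeterministicDropoutNetwork

net = DeterministicDropoutNetwork(hid_dim=1024, dropout_prob=0.1)
x = torch.randn(32, 1024)
mask = torch.randint(0, 2, (32, 1024))  # externally supplied mask instead of fresh randomness per call
y1, y2 = net(x, mask), net(x, mask)
assert torch.allclose(y1, y2)           # same input and mask should yield identical outputs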

+ 17 - 0
hivemind/utils/threading.py

@@ -1,6 +1,10 @@
 import os
 from concurrent.futures import Future, ThreadPoolExecutor
 
+from hivemind.utils import get_logger
+
+logger = get_logger(__name__)
+
 EXECUTOR_PID, GLOBAL_EXECUTOR = None, None
 
 
@@ -11,3 +15,16 @@ def run_in_background(func: callable, *args, **kwargs) -> Future:
         GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=os.environ.get("HIVEMIND_THREADS", float('inf')))
         EXECUTOR_PID = os.getpid()
     return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
+
+
+def increase_file_limit(new_soft=2 ** 15, new_hard=2 ** 15):
+    """ Increase the maximum number of open files. On Linux, this allows spawning more processes/threads. """
+    try:
+        import resource  # local import to avoid ImportError for Windows users
+        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+        new_soft = max(soft, new_soft)
+        new_hard = max(hard, new_hard)
+        logger.info(f"Increasing file limit: soft {soft}=>{new_soft}, hard {hard}=>{new_hard}")
+        return resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, new_hard))
+    except Exception as e:
+        logger.warning(f"Failed to increase file limit: {e}")
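
increase_file_limit is now importable from hivemind.utils.threading rather than tests/test_utils, so run_server.py and the benchmarks share one implementation. Minimal usage, mirroring what run_server.py now does:

from hivemind.utils.threading import increase_file_limit

increase_file_limit()                                # raise RLIMIT_NOFILE soft/hard to at least 2 ** 15
increase_file_limit(new_soft=65536, new_hard=65536)  # or pass explicit targets; failure only logs a warning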

+ 10 - 12
scripts/config.yml

@@ -1,12 +1,10 @@
-listen_on: 0.0.0.0:* #'localhost' for local connections only, '0.0.0.0' for ipv4 '::' for ipv6
-num_experts: 1 #run this many identical experts
-expert_cls: ffn #expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.
-hidden_dim: 1024 #main dimension for expert_cls
-expert_prefix: expert #all expert uids will be {expert_prefix}.{index}
-expert_offset: 0 #expert uid will use indices in range(expert_offset, expert_offset + num_experts)
-max_batch_size: 16384 #total num examples in the same batch will not exceed this value
-optimizer: adam #if specified, all optimizers use learning rate=0
-no_dht: True #if specified, the server will not be attached to a dht
-initial_peers: "[]" #a list of peers that will introduce this node to the dht, e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]
-#dht_port: none #DHT node will listen on this port
-increase_file_limit: True #On *nix, this will increase the max number of processes a server can spawn before hitting "Too many open files"; Use at your own risk.
+listen_on: 0.0.0.0:*
+num_experts: 16
+expert_cls: ffn
+hidden_dim: 1024
+expert_pattern: expert.[0:4].[0:4]
+max_batch_size: 16384
+optimizer: adam
+no_dht: True
+initial_peers: "[]"
+increase_file_limit: True
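
The rewritten defaults are also internally consistent: expert_pattern: expert.[0:4].[0:4] describes 4 × 4 = 16 candidate uids, matching num_experts: 16, and it replaces the stale expert_prefix/expert_offset keys, which do not appear among the parser arguments shown in run_server.py below. Since run_server.py loads config.yml through configargparse (default_config_files=["config.yml"], or an explicit -c path), each key here simply supplies the default for the command-line flag of the same name.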

+ 23 - 21
scripts/run_server.py

@@ -6,17 +6,19 @@ import resource
 import torch
 
 from hivemind.server import Server
+from hivemind.utils.threading import increase_file_limit
 
-if __name__ == '__main__':
+
+def main():
     # fmt:off
     parser = configargparse.ArgParser(default_config_files=["config.yml"])
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
     parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
                         help="'localhost' for local connections only, '0.0.0.0' for ipv4 '::' for ipv6")
-    parser.add_argument('--num_experts', type=int, default=None, required=False, help="run this many experts")
-    parser.add_argument('--expert_pattern', type=str, default=None, required=False, help='all expert uids will follow'
-                        ' this pattern, e.g. "myexpert.[0:256].[0:1024]" will sample random expert uids'
-                        ' between myexpert.0.0 and myexpert.255.1023 . Use either num_experts and this or expert_uids')
+    parser.add_argument('--num_experts', type=int, default=None, required=False, help="The number of experts to serve")
+    parser.add_argument('--expert_pattern', type=str, default=None, required=False,
+                        help='all expert uids will follow this pattern, e.g. "myexpert.[0:256].[0:1024]" will sample random expert uids'
+                             ' between myexpert.0.0 and myexpert.255.1023 . Use either num_experts and this or expert_uids')
     parser.add_argument('--expert_uids', type=str, nargs="*", default=None, required=False,
                         help="specify the exact list of expert uids to create. Use either this or num_experts"
                              " and expert_pattern, not both")
@@ -26,39 +28,39 @@ if __name__ == '__main__':
     parser.add_argument('--num_handlers', type=int, default=None, required=False,
                         help='server will use this many processes to handle incoming requests')
     parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
-                        help='total num examples in the same batch will not exceed this value')
+                        help='The total number of examples in the same batch will not exceed this value')
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all experts will use this device in torch notation; default: cuda if available else cpu')
     parser.add_argument('--optimizer', type=str, default='adam', required=False, help='adam, sgd or none')
     parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
-    parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[], help='one or more peers'
-                        ' that can welcome you to the dht, e.g. 1.2.3.4:1337 192.132.231.4:4321')
+    parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
+                        help='one or more peers that can welcome you to the dht, e.g. 1.2.3.4:1337 192.132.231.4:4321')
     parser.add_argument('--dht_port', type=int, default=None, required=False, help='DHT node will listen on this port')
-    parser.add_argument('--increase_file_limit', action='store_true', help='On *nix, this will increase the max number'
-                        ' of processes a server can spawn before hitting "Too many open files"; Use at your own risk.')
+    parser.add_argument('--increase_file_limit', action='store_true',
+                        help='On *nix, this will increase the max number of processes '
+                             'a server can spawn before hitting "Too many open files"; Use at your own risk.')
     # fmt:on
     args = vars(parser.parse_args())
     args.pop('config', None)
     optimizer = args.pop('optimizer')
     if optimizer == 'adam':
-        Optimizer = torch.optim.Adam
+        optim_cls = torch.optim.Adam
     elif optimizer == 'sgd':
-        Optimizer = partial(torch.optim.SGD, lr=0.01)
+        optim_cls = partial(torch.optim.SGD, lr=0.01)
     elif optimizer == 'none':
-        Optimizer = None
+        optim_cls = None
     else:
-        raise ValueError("Optimizer must be adam, sgd or none")
+        raise ValueError("optim_cls must be adam, sgd or none")
 
     if args.pop('increase_file_limit'):
-        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
-        try:
-            print("Setting open file limit to soft={}, hard={}".format(max(soft, 2 ** 15), max(hard, 2 ** 15)))
-            resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 2 ** 15), max(hard, 2 ** 15)))
-        except:
-            print("Could not increase open file limit, currently at soft={}, hard={}".format(soft, hard))
+        increase_file_limit()
 
     try:
-        server = Server.create(**args, Optimizer=Optimizer, start=True, verbose=True)
+        server = Server.create(**args, optim_cls=optim_cls, start=True, verbose=True)
         server.join()
     finally:
         server.shutdown()
+
+
+if __name__ == '__main__':
+    main()

+ 3 - 0
setup.py

@@ -78,6 +78,9 @@ setup(
         'Topic :: Software Development :: Libraries',
         'Topic :: Software Development :: Libraries :: Python Modules',
     ],
+    entry_points={
+        'console_scripts': ['hivemind-server = scripts.run_server:main', ]
+    },
     # What does your project relate to?
     keywords='pytorch, deep learning, machine learning, gpu, distributed computing, volunteer computing, dht',
 )
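
With run_server.py wrapped in main(), this console_scripts entry is what delivers the "terminal command" from the commit title: after installing the package (for example with pip install -e .), a hivemind-server executable is expected on PATH, forwarding to scripts.run_server:main. In that case, hivemind-server -c config.yml or hivemind-server --num_experts 16 --no_dht should behave the same as invoking python scripts/run_server.py with the same flags, assuming the scripts directory is importable in the installed environment (not verified here).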

+ 4 - 3
tests/benchmark_dht.py

@@ -1,11 +1,12 @@
-import time
 import argparse
 import random
+import time
 from warnings import warn
-import hivemind
+
 from tqdm import trange
 
-from test_utils import increase_file_limit
+import hivemind
+from hivemind.utils.threading import increase_file_limit
 
 
 def random_endpoint() -> hivemind.Endpoint:

+ 3 - 2
tests/benchmark_throughput.py

@@ -5,11 +5,12 @@ import sys
 import time
 
 import torch
-from hivemind.server import layers
-from test_utils import print_device_info, increase_file_limit
+from test_utils import print_device_info
 
 import hivemind
 from hivemind import find_open_port
+from hivemind.server import layers
+from hivemind.utils.threading import increase_file_limit
 
 
 def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):

+ 3 - 3
tests/test_moe.py

@@ -34,7 +34,7 @@ def test_call_many():
     atol = 1e-6
 
     with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=8, hidden_dim=64,
-                           Optimizer=None, no_dht=True) as (server_endpoint, dht_endpoint):
+                           optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
         inputs = torch.randn(4, 64, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
         e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f'expert.{i}', server_endpoint) for i in range(5)]
@@ -76,7 +76,7 @@ def test_call_many():
 
 def test_remote_module_call():
     with background_server(num_experts=1, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=1024,
-                           Optimizer=None, no_dht=True) as (server_endpoint, dht_endpoint):
+                           optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
         real_expert = hivemind.RemoteExpert('expert.0', server_endpoint)
         fake_expert = hivemind.RemoteExpert('oiasfjiasjf', server_endpoint)
 
@@ -128,7 +128,7 @@ def test_determinism():
     mask = torch.randint(0, 1, (32, 1024))
 
     with background_server(num_experts=1, device='cpu', expert_cls='det_dropout', num_handlers=1,
-                           Optimizer=None, no_dht=True) as (server_endpoint, dht_endpoint):
+                           optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
         expert = hivemind.RemoteExpert(uid=f'expert.0', endpoint=server_endpoint)
 
         out = expert(xx, mask)

+ 1 - 1
tests/test_training.py

@@ -14,7 +14,7 @@ def test_training(port: Optional[int] = None, max_steps: int = 100, threshold: f
     X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
     SGD = partial(torch.optim.SGD, lr=0.05)
 
-    with background_server(num_experts=2, device='cpu', Optimzer=SGD, hidden_dim=64) as (server_endpoint, _):
+    with background_server(num_experts=2, device='cpu', optim_cls=SGD, hidden_dim=64) as (server_endpoint, _):
         expert1 = RemoteExpert('expert.0', server_endpoint)
         expert2 = RemoteExpert('expert.1', server_endpoint)
         model = nn.Sequential(expert2, nn.Tanh(), expert1, nn.Linear(64, 10))

+ 0 - 13
tests/test_utils/__init__.py

@@ -1,5 +1,3 @@
-from warnings import warn
-
 import torch
 
 
@@ -14,14 +12,3 @@ def print_device_info(device=None):
         print('Memory Usage:')
         print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
         print('Cached:   ', round(torch.cuda.memory_cached(0) / 1024 ** 3, 1), 'GB')
-
-
-def increase_file_limit(new_soft=2 ** 15, new_hard=2 ** 15):
-    """ Increase the maximum number of open files. On Linux, this allows spawning more processes/threads. """
-    try:
-        import resource  # note: local import to avoid ImportError for those who don't have it
-        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
-        print(f"Increasing file limit - soft {soft}=>{new_soft}, hard {hard}=>{new_hard}")
-        return resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, new_soft), max(hard, new_hard)))
-    except Exception as e:
-        warn(f"Failed to increase file limit: {e}")