
Speed up tests, shutdown threads via Event (#82)

* Speed up tests, shutdown threads via Event
Max Ryabinin, 5 years ago
Commit 53f5ab6971

+ 20 - 17
hivemind/client/moe.py

@@ -47,8 +47,10 @@ class RemoteMixtureOfExperts(nn.Module):
         super().__init__()
         self.dht, self.grid_size, self.uid_prefix = dht, grid_size, uid_prefix
         self.loop = loop or asyncio.new_event_loop()
+        # fmt:off
         assert not self.loop.is_running(), "Event loop is already running. If in jupyter, please apply nest_asyncio " \
             "(pip install nest_asyncio , https://pypi.org/project/nest-asyncio ) and send loop=asyncio.new_event_loop()"
+        # fmt:on
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
         self.timeout_after_k_min = timeout_after_k_min
@@ -59,27 +61,25 @@ class RemoteMixtureOfExperts(nn.Module):
 
     def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
-        Choose k best experts with beam search, then call chosen experts and average their outputs.
-        :param input: a tensor of values that are used to estimate gating function, batch-first
+        Choose k best experts with beam search, then call chosen experts and average their outputs. Input tensor is averaged over all
+        dimensions except first and last (we assume that extra dimensions represent sequence length or image dimensions)
+
+        :param input: a tensor of values that are used to estimate gating function, batch-first.
         :param args: extra positional parameters that will be passed to each expert after input, batch-first
         :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
         :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
         """
-        if self.allow_broadcasting and input.ndim != 2:
-            # flatten extra dimensions, apply the function and then un-flatten them back to normal like nn.Linear does
-            flattened_dims = input.shape[:-1]
-            input_flat = input.view(-1, input.shape[-1])
-            args_flat = [tensor.view(-1, tensor.shape[len(flattened_dims):]) for tensor in args]
-            kwargs_flat = {key: tensor.view(-1, tensor.shape[len(flattened_dims):]) for key, tensor in kwargs.items()}
-            out_flat = self.forward(input_flat, *args_flat, **kwargs_flat)
-            return nested_map(lambda tensor: tensor.view(flattened_dims, tensor.shape[len(flattened_dims):]), out_flat)
+        if input.ndim != 2:
+            input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
+        else:
+            input_for_gating = input
 
         # 1. compute scores and find most appropriate experts with beam search
-        grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1)
+        grid_scores = self.proj(input_for_gating).split_with_sizes(self.grid_size, dim=-1)
 
         async def _search():
             coroutines = [asyncio.create_task(self.beam_search(
-                [dim_scores[i] for dim_scores in grid_scores], self.k_best))
+                [dim_scores[i] for dim_scores in grid_scores], self.k_best), name=f'beam_search_{i}')
                 for i in range(len(input))]
             return list(await asyncio.gather(*coroutines))
 
@@ -184,7 +184,9 @@ class RemoteMixtureOfExperts(nn.Module):
     def outputs_schema(self):
         if self._outputs_schema is None:
             # grab some expert to set ensemble output shape
-            dummy_scores = self.proj(torch.randn(self.proj.in_features)).cpu().split_with_sizes(self.grid_size, dim=-1)
+            proj_device = self.proj.weight.device
+            dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
+            dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.grid_size, dim=-1)
             dummy_experts = self.loop.run_until_complete(self.beam_search(dummy_scores, k_best=1))
             self._outputs_schema = dummy_experts[0].info['outputs_schema']
         return self._outputs_schema
@@ -212,7 +214,7 @@ class _RemoteCallMany(torch.autograd.Function):
         async def _forward():
             # dispatch tasks to all remote experts, await responses
             pending_tasks = {
-                asyncio.create_task(cls._forward_one_expert((i, j), expert, flat_inputs_per_sample[i]))
+                asyncio.create_task(cls._forward_one_expert((i, j), expert, flat_inputs_per_sample[i]), name=f'forward_expert_{j}_for_{i}')
                 for i in range(num_samples) for j, expert in enumerate(experts_per_sample[i])
             }
             alive_grid_indices, alive_flat_outputs = await cls._wait_for_responses(
@@ -259,8 +261,8 @@ class _RemoteCallMany(torch.autograd.Function):
             for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
                                                         inputs_per_expert, grad_outputs_per_expert):
                 pending_tasks.add(asyncio.create_task(
-                    cls._backward_one_expert((i, j), expert_per_sample[i.item()][j.item()], inputs_ij, grad_outputs_ij)
-                ))
+                    cls._backward_one_expert((i, j), expert_per_sample[i.item()][j.item()], inputs_ij, grad_outputs_ij),
+                    name=f'backward_expert_{j}_for_{i}'))
 
             backward_survivor_indices, survivor_grad_inputs = await cls._wait_for_responses(
                 pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
@@ -280,6 +282,7 @@ class _RemoteCallMany(torch.autograd.Function):
                 grad_inputs.append(grad_input_per_expert.sum(dim=1))  # add up gradients from each expert
 
             return (DUMMY, None, None, None, None, None, None, None, *grad_inputs)
+
         return loop.run_until_complete(_backward())
 
     @staticmethod
@@ -308,7 +311,7 @@ class _RemoteCallMany(torch.autograd.Function):
     async def _wait_for_responses(
             pending_tasks: Set[Awaitable[Tuple[Tuple[int, int], Tuple[torch.Tensor, ...]]]],
             num_samples: int, k_min: int, timeout_total: Optional[float], timeout_after_k_min: Optional[float]
-            ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
+    ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
         """ await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers """
         timeout_total = float('inf') if timeout_total is None else timeout_total
         timeout_after_k_min = float('inf') if timeout_after_k_min is None else timeout_after_k_min
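
For context, a minimal standalone sketch of the new gating-input reduction (plain torch, independent of the hivemind classes): any extra sequence or spatial dimensions are averaged away so the gating projection always receives a 2-D [batch, features] tensor, replacing the removed flatten/un-flatten broadcasting path.

```python
import torch

def reduce_for_gating(x: torch.Tensor) -> torch.Tensor:
    # average over every dimension except batch (first) and features (last),
    # mirroring the input_for_gating branch in RemoteMixtureOfExperts.forward
    if x.ndim == 2:
        return x
    return x.mean(dim=tuple(range(1, x.ndim - 1)))

tokens = torch.randn(8, 128, 64)      # [batch, seq_len, features]
images = torch.randn(8, 32, 32, 64)   # [batch, height, width, features]
assert reduce_for_gating(tokens).shape == (8, 64)
assert reduce_for_gating(images).shape == (8, 64)
```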

+ 2 - 1
hivemind/dht/__init__.py

@@ -96,7 +96,8 @@ class DHT(mp.Process):
             node = await DHTNode.create(
                 initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
                 num_workers=self.max_workers or 1, **self.kwargs)
-            self._port.value = node.port
+            if node.port is not None:
+                self._port.value = node.port
             self.ready.set()
 
             while True:
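
A small sketch of why the `node.port` guard matters (illustrative names; this assumes, as in the hunk above, that the port is shared with the parent through a `multiprocessing.Value`): a node that does not listen reports no port, and assigning `None` to an integer-typed shared value raises a `TypeError`.

```python
import multiprocessing as mp
from typing import Optional

shared_port = mp.Value('i', -1)      # -1 acts as a "port not assigned" sentinel

def publish_port(node_port: Optional[int]) -> None:
    # hypothetical helper: only publish a port if the node is actually listening
    if node_port is not None:
        shared_port.value = node_port

publish_port(None)                   # keeps the sentinel instead of raising
publish_port(8080)
assert shared_port.value == 8080
```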

+ 6 - 3
hivemind/server/__init__.py

@@ -10,7 +10,9 @@ from hivemind.server.expert_backend import ExpertBackend
 from hivemind.server.checkpoint_saver import CheckpointSaver
 from hivemind.server.connection_handler import ConnectionHandler
 from hivemind.server.dht_handler import DHTHandlerThread
-from hivemind.utils import Endpoint, get_port, replace_port, find_open_port
+from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
+
+logger = get_logger(__name__)
 
 
 class Server(threading.Thread):
@@ -81,10 +83,10 @@ class Server(threading.Thread):
         for process in self.conn_handlers:
             process.join()
         if self.dht:
-            dht_handler_thread.stop = True
+            dht_handler_thread.stop.set()
             dht_handler_thread.join()
         if self.checkpoint_saver is not None:
-            self.checkpoint_saver.stop = True
+            self.checkpoint_saver.stop.set()
             self.checkpoint_saver.join()
 
     def run_in_background(self, await_ready=True, timeout=None):
@@ -121,5 +123,6 @@ class Server(threading.Thread):
 
         if self.dht is not None:
             self.dht.shutdown()
+            self.dht.join()
 
         self.runtime.shutdown()

+ 2 - 3
hivemind/server/checkpoint_saver.py

@@ -17,12 +17,11 @@ class CheckpointSaver(threading.Thread):
         self.expert_backends = expert_backends
         self.update_period = update_period
         self.checkpoint_dir = checkpoint_dir
-        self.stop = False
+        self.stop = threading.Event()
 
     def run(self) -> None:
-        while not self.stop:
+        while not self.stop.wait(self.update_period):
             store_experts(self.expert_backends, self.checkpoint_dir)
-            time.sleep(self.update_period)
 
 
 def store_experts(experts: Dict[str, ExpertBackend], checkpoints_dir: Path):
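
The point of switching `stop` from a plain boolean to a `threading.Event` is that `Event.wait(timeout)` returns `False` when the timeout elapses and `True` as soon as the event is set, so a shutdown request wakes the thread immediately instead of waiting out the rest of `update_period`. A minimal sketch of the same pattern with hypothetical names:

```python
import threading
import time

class PeriodicWorker(threading.Thread):
    """Hypothetical periodic thread using the same Event-based loop as above."""

    def __init__(self, period: float):
        super().__init__()
        self.period = period
        self.stop = threading.Event()
        self.ticks = 0

    def run(self) -> None:
        # wait() returns False on timeout (keep working) and True once stop is set
        while not self.stop.wait(self.period):
            self.ticks += 1   # periodic work, e.g. saving checkpoints

worker = PeriodicWorker(period=60.0)
worker.start()
time.sleep(0.1)
started = time.monotonic()
worker.stop.set()    # with a bool flag + time.sleep(60) this could take up to a minute
worker.join()
assert time.monotonic() - started < 1.0
```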

+ 3 - 4
hivemind/server/dht_handler.py

@@ -7,15 +7,14 @@ from hivemind.utils import Endpoint, get_port
 
 class DHTHandlerThread(threading.Thread):
     def __init__(self, experts, dht: DHT, endpoint: Endpoint, update_period: int = 5):
-        super(DHTHandlerThread, self).__init__()
+        super().__init__()
         assert get_port(endpoint) is not None
         self.endpoint = endpoint
         self.experts = experts
         self.dht = dht
         self.update_period = update_period
-        self.stop = False
+        self.stop = threading.Event()
 
     def run(self) -> None:
-        while not self.stop:
+        while not self.stop.wait(self.update_period):
             self.dht.declare_experts(self.experts.keys(), self.endpoint)
-            time.sleep(self.update_period)

+ 1 - 0
hivemind/server/runtime.py

@@ -77,6 +77,7 @@ class Runtime(threading.Thread):
         for pool in self.pools:
             if pool.is_alive():
                 pool.terminate()
+                pool.join()
 
     def iterate_minibatches_from_pools(self, timeout=None):
         """

+ 0 - 2
tests/test_moe.py

@@ -14,7 +14,6 @@ def test_moe():
                        for _ in range(20)]
     with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn',
                            num_handlers=1, hidden_dim=16) as (server_endpoint, dht_endpoint):
-
         dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
         # declare expert uids. Server *should* declare them by itself, but it takes time.
         assert all(dht.declare_experts(all_expert_uids, endpoint=server_endpoint))
@@ -38,7 +37,6 @@ def test_call_many():
 
     with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=8, hidden_dim=64,
                            no_optimizer=True, no_dht=True) as (server_endpoint, dht_endpoint):
-
         inputs = torch.randn(4, 64, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
         e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f'expert.{i}', server_endpoint) for i in range(5)]

+ 3 - 5
tests/test_training.py

@@ -1,14 +1,12 @@
-import argparse
 from typing import Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
-from hivemind import RemoteExpert, find_open_port, LOCALHOST
+from sklearn.datasets import load_digits
 from test_utils.run_server import background_server
 
-from sklearn.datasets import load_digits
+from hivemind import RemoteExpert
 
 
 def test_training(port: Optional[int] = None, max_steps: int = 100, threshold: float = 0.9):
@@ -30,7 +28,7 @@ def test_training(port: Optional[int] = None, max_steps: int = 100, threshold: f
             loss.backward()
             opt.step()
 
-            accuracy = (outputs.argmax(dim=1) == y_train).numpy().mean()
+            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
             if accuracy >= threshold:
                 break
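
The accuracy line now stays entirely in torch: the comparison yields a boolean tensor, which must be cast before `.mean()` (torch refuses to average a `bool` tensor), and `.item()` pulls the result out as a Python float. A tiny illustration:

```python
import torch

logits = torch.tensor([[2.0, 0.1], [0.3, 0.9], [1.5, 0.2]])
labels = torch.tensor([0, 1, 1])

correct = logits.argmax(dim=1) == labels      # bool tensor: [True, True, False]
accuracy = correct.float().mean().item()      # cast, average, then extract a float
assert abs(accuracy - 2 / 3) < 1e-6
```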
 

+ 18 - 14
tests/test_utils/run_server.py

@@ -10,6 +10,8 @@ import torch
 import hivemind
 from test_utils.layers import name_to_block, name_to_input
 
+logger = hivemind.get_logger(__name__)
+
 
 def make_dummy_server(listen_on='0.0.0.0:*', num_experts=None, expert_uids=None, expert_cls='ffn', hidden_dim=1024,
                       num_handlers=None, expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None,
@@ -49,20 +51,20 @@ def make_dummy_server(listen_on='0.0.0.0:*', num_experts=None, expert_uids=None,
     dht = None
     if not no_dht:
         if not len(initial_peers):
-            print("No initial peers provided. Starting additional dht as an initial peer.")
+            logger.info("No initial peers provided. Starting additional dht as an initial peer.")
             dht_root = hivemind.DHT(initial_peers=initial_peers, start=True,
                                     listen_on=f"{hivemind.LOCALHOST}:{root_port or hivemind.find_open_port()}")
-            print(f"Initializing DHT with port {dht_root.port}")
+            logger.info(f"Initializing DHT with port {dht_root.port}")
             initial_peers = [f"{hivemind.LOCALHOST}:{dht_root.port}"]
         else:
-            print("Bootstrapping dht with peers:", initial_peers)
+            logger.info("Bootstrapping dht with peers:", initial_peers)
             if root_port is not None:
-                print(f"Warning: root_port={root_port} will not be used since we already have peers.")
+                logger.info(f"Warning: root_port={root_port} will not be used since we already have peers.")
 
         dht = hivemind.DHT(initial_peers=initial_peers, start=True,
                            listen_on=f"{hivemind.LOCALHOST}:{dht_port or hivemind.find_open_port()}")
         if verbose:
-            print(f"Running dht node on port {dht.port}")
+            logger.info(f"Running dht node on port {dht.port}")
 
     sample_input = name_to_input[expert_cls](4, hidden_dim)
     if isinstance(sample_input, tuple):
@@ -93,8 +95,8 @@ def make_dummy_server(listen_on='0.0.0.0:*', num_experts=None, expert_uids=None,
     if start:
         server.run_in_background(await_ready=True)
         if verbose:
-            print(f"Server started at {server.listen_on}")
-            print(f"Got {num_experts} active experts of type {expert_cls}: {list(experts.keys())}")
+            logger.info(f"Server started at {server.listen_on}")
+            logger.info(f"Got {len(experts)} active experts of type {expert_cls}: {list(experts.keys())}")
     return server
 
 
@@ -110,14 +112,13 @@ def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs) -> Tupl
         yield pipe.recv()  # once the server is ready, runner will send us a tuple(hostname, port, dht port)
         pipe.send('SHUTDOWN')  # on exit from context, send shutdown signal
     finally:
-        try:
-            runner.join(timeout=shutdown_timeout)
-        finally:
+        runner.join(timeout=shutdown_timeout)
+        if runner.is_alive():
             if verbose:
-                print("Server failed to shutdown gracefully, terminating it the hard way...")
+                logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
             runner.kill()
             if verbose:
-                print("Server terminated.")
+                logger.info("Server terminated.")
 
 
 def _server_runner(pipe, *args, verbose, **kwargs):
@@ -131,13 +132,15 @@ def _server_runner(pipe, *args, verbose, **kwargs):
         pipe.recv()  # wait for shutdown signal
     finally:
         if verbose:
-            print("Shutting down server...")
+            logger.info("Shutting down server...")
         server.shutdown()
+        server.join()
         if verbose:
-            print("Server shut down successfully.")
+            logger.info("Server shut down successfully.")
 
 
 if __name__ == '__main__':
+    # fmt:off
     parser = argparse.ArgumentParser()
     parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
                         help="'localhost' for local connections only, '0.0.0.0' for ipv4 '::' for ipv6")
@@ -164,6 +167,7 @@ if __name__ == '__main__':
                         ', it will create a virtual dht node on this port. You can then use this node as initial peer.')
     parser.add_argument('--increase_file_limit', action='store_true', help='On *nix, this will increase the max number'
                         ' of processes a server can spawn before hitting "Too many open files"; Use at your own risk.')
+    # fmt:on
 
     args = vars(parser.parse_args())
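
The reworked `background_server` teardown is a graceful-then-forced shutdown: give the runner a bounded `join`, and escalate to `kill()` only if it is still alive afterwards. A self-contained sketch of the same pattern with a hypothetical stubborn process (not the test harness itself):

```python
import multiprocessing as mp
import time

def stubborn_worker() -> None:
    # hypothetical process that ignores the polite shutdown window
    while True:
        time.sleep(1)

if __name__ == '__main__':
    runner = mp.Process(target=stubborn_worker)
    runner.start()

    runner.join(timeout=2)     # polite phase: wait up to shutdown_timeout seconds
    if runner.is_alive():      # escalate only if it did not exit in time
        runner.kill()          # hard stop (SIGKILL on POSIX; Process.kill() exists since Python 3.7)
        runner.join()          # reap it so no zombie process is left behind
```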