
Fix device in Switch-MoE, overhaul Server architecture (#256)

* Set correct device for scores

* Put pipe_awaiter in a context manager

* Pass min_batch_size to ExpertBackend in Server.create

* Remove unneeded variable for exception in generate_uids_from_pattern

* Overhaul server architecture
Max Ryabinin committed 4 years ago
commit 2328ba9262

+ 27 - 27
hivemind/client/averaging/__init__.py

@@ -171,35 +171,34 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
         loop = switch_to_uvloop()
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
         # initialize asyncio synchronization primitives in this event loop
-        pipe_awaiter = ThreadPoolExecutor(max_workers=1)
-
-        async def _run():
-            grpc.aio.init_grpc_aio()
-
-            if self.listen:
-                server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-                averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
-                found_port = server.add_insecure_port(self.listen_on)
-                assert found_port != 0, f"Failed to listen to {self.listen_on}"
-                self._port.value = found_port
-                await server.start()
-            else:
-                logger.info(f"The averager running in an experimental client mode, please report any bugs.")
+        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
+            async def _run():
+                grpc.aio.init_grpc_aio()
+
+                if self.listen:
+                    server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
+                    averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
+                    found_port = server.add_insecure_port(self.listen_on)
+                    assert found_port != 0, f"Failed to listen to {self.listen_on}"
+                    self._port.value = found_port
+                    await server.start()
+                else:
+                    logger.info(f"The averager running in an experimental client mode, please report any bugs.")
 
 
-            self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
-                                            client_mode=not self.listen)
-            if self.listen:
-                asyncio.create_task(self._declare_for_download_periodically())
+                self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
+                                                client_mode=not self.listen)
+                if self.listen:
+                    asyncio.create_task(self._declare_for_download_periodically())
 
 
-            self._pending_group_assembled = asyncio.Event()
-            self._pending_group_assembled.set()
-            self.ready.set()
+                self._pending_group_assembled = asyncio.Event()
+                self._pending_group_assembled.set()
+                self.ready.set()
 
 
-            while True:
-                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
-                asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                while True:
+                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
+                    asyncio.create_task(getattr(self, method)(*args, **kwargs))
 
 
-        loop.run_until_complete(_run())
+            loop.run_until_complete(_run())
 
 
     def run_in_background(self, await_ready=True, timeout=None):
         """
@@ -255,7 +254,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 try:
                     self._pending_group_assembled.clear()
                     data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
-                    group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
+                    group_info = await self._matchmaking.look_for_group(timeout=timeout,
+                                                                        data_for_gather=data_for_gather)
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
                     group_id = group_info.group_id
@@ -294,7 +294,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """ Use a group description found by Matchmaking to form AllreduceRunner """
         """ Use a group description found by Matchmaking to form AllreduceRunner """
         try:
         try:
             weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.endpoints,  map(self.serializer.loads, user_gathered)))
+            user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
 
 
             # compute optimal part sizes from peer throughputs
             incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
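The pipe-awaiter change above follows a simple pattern: the single-thread executor that blocks on the inter-process pipe is now owned by a `with` block, so its worker thread is always cleaned up when the event loop exits, even on errors. A minimal standalone sketch of that pattern (names such as `handle` are placeholders, not the hivemind code):

```python
import asyncio
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor

recv_end, send_end = mp.Pipe(duplex=False)

async def handle(method, *args, **kwargs):
    # stand-in for asyncio.create_task(getattr(self, method)(*args, **kwargs))
    print("dispatched", method, args, kwargs)

async def _serve(loop, pipe_awaiter):
    while True:
        # block on the pipe in the executor thread so the event loop stays responsive
        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, recv_end.recv)
        asyncio.create_task(handle(method, *args, **kwargs))

def run():
    loop = asyncio.new_event_loop()
    # the context manager guarantees the executor is shut down even if _serve() raises
    with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
        loop.run_until_complete(_serve(loop, pipe_awaiter))
```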

+ 7 - 4
hivemind/client/moe.py

@@ -120,8 +120,11 @@ class RemoteMixtureOfExperts(nn.Module):
         batch_size = len(batch_experts)
         max_num_experts = max(expert_counts)
         total_num_experts = sum(expert_counts)
-        expert_index_in_batch = torch.arange(total_num_experts, device=grid_scores[0].device)
-        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_scores[0].device), dim=-1)[:-1]
+
+        device = grid_scores[0].device
+
+        expert_index_in_batch = torch.arange(total_num_experts, device=device)
+        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
         flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
         flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
         flat_experts = [expert for row in batch_experts for expert in row]
@@ -133,11 +136,11 @@ class RemoteMixtureOfExperts(nn.Module):
             grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)

         scores_per_dim = [
-            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
+            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
             for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
         flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
 
 
-        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device)
+        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device)
         scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
         return scores
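The `device` fix above addresses a concrete failure mode: when `grid_scores` live on a GPU, the old `torch.zeros(0)` fallback was created on the CPU, and `torch.stack` refuses to mix devices. A toy reproduction (only runs on a CUDA machine; shapes are made up):

```python
import torch

if torch.cuda.is_available():
    device = torch.device('cuda')
    scores_on_gpu = torch.randn(0, device=device)      # an empty per-dimension score slice
    empty_cpu = torch.zeros(0)                          # old behaviour: implicitly on CPU
    empty_gpu = torch.zeros(0, device=device)           # fixed behaviour: same device as the scores

    torch.stack([scores_on_gpu, empty_gpu], dim=0)      # fine: both tensors share a device
    # torch.stack([scores_on_gpu, empty_cpu], dim=0)    # RuntimeError: tensors on different devices
```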
 
 

+ 7 - 4
hivemind/client/switch_moe.py

@@ -156,8 +156,11 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
         batch_size = len(batch_experts)
         max_num_experts = max(expert_counts)
         total_num_experts = sum(expert_counts)
-        expert_index_in_batch = torch.arange(total_num_experts, device=grid_probs[0].device)
-        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_probs[0].device), dim=-1)[:-1]
+
+        device = grid_probs[0].device
+
+        expert_index_in_batch = torch.arange(total_num_experts, device=device)
+        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
         flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
         flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
         flat_experts = [expert for row in batch_experts for expert in row]
@@ -169,10 +172,10 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
             grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)

         scores_per_dim = [
-            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
+            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
             for dim_scores, dim_indices in zip(grid_probs, grid_indices.T)]
         flat_scores = torch.prod(torch.stack(scores_per_dim, dim=0), dim=0)
 
 
-        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_probs[0].device)
+        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device)
         scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
         return scores

+ 15 - 17
hivemind/dht/__init__.py

@@ -69,25 +69,23 @@ class DHT(mp.Process):
     def run(self) -> None:
         """ Serve DHT forever. This function will not return until DHT node is shut down """
         loop = switch_to_uvloop()
-        pipe_awaiter = ThreadPoolExecutor(max_workers=1)
 
 
-        async def _run():
-            node = await DHTNode.create(
-                initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
-                num_workers=self.max_workers or 1, record_validator=self._record_validator,
-                **self.kwargs)
-            if node.port is not None:
-                self._port.value = node.port
-            self.ready.set()
+        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
+            async def _run():
+                node = await DHTNode.create(
+                    initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
+                    num_workers=self.max_workers or 1, record_validator=self._record_validator,
+                    **self.kwargs)
+                if node.port is not None:
+                    self._port.value = node.port
+                self.ready.set()
 
 
-            while True:
-                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
-                asyncio.create_task(getattr(self, method)(node, *args, **kwargs))
+                while True:
+                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
+                    asyncio.create_task(getattr(self, method)(node, *args, **kwargs))
 
 
-        try:
-            loop.run_until_complete(_run())
-        except KeyboardInterrupt:
-            logger.debug("Caught KeyboardInterrupt, shutting down")
+            coro = _run()
+            loop.run_until_complete(coro)
 
 
     def run_in_background(self, await_ready=True, timeout=None):
         """
@@ -96,7 +94,7 @@ class DHT(mp.Process):
         """
         """
         self.start()
         self.start()
         if await_ready and not self.ready.wait(timeout=timeout):
         if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
+            raise TimeoutError(f"DHT didn't notify .ready in {timeout} seconds")
 
 
     def shutdown(self) -> None:
         """ Shut down a running dht process """

+ 3 - 1
hivemind/hivemind_cli/run_server.py

@@ -32,7 +32,9 @@ def main():
 
 
     parser.add_argument('--num_handlers', type=int, default=None, required=False,
                         help='server will use this many processes to handle incoming requests')
-    parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
+    parser.add_argument('--min_batch_size', type=int, default=1,
+                        help='Minimum required batch size for all expert operations')
+    parser.add_argument('--max_batch_size', type=int, default=16384,
                         help='The total number of examples in the same batch will not exceed this value')
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all experts will use this device in torch notation; default: cuda if available else cpu')
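The new `--min_batch_size` flag is simply forwarded to `Server.create`, which in turn passes it to every `ExpertBackend` (see the server diff below). A hypothetical programmatic equivalent, using only argument names that appear in this commit:

```python
from hivemind.server import Server

# spawn two feed-forward experts that never process batches smaller than min_batch_size
server = Server.create(listen_on='0.0.0.0:*', num_experts=2, expert_cls='ffn', hidden_dim=1024,
                       min_batch_size=1, max_batch_size=16384, start=True)
```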

+ 30 - 22
hivemind/server/__init__.py

@@ -65,16 +65,20 @@ class Server(threading.Thread):
             self.checkpoint_saver = None
         self.runtime = Runtime(self.experts, **kwargs)
 
 
+        if self.dht and self.experts:
+            self.dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht, endpoint=self.listen_on,
+                                                       update_period=self.update_period)
+
         if start:
             self.run_in_background(await_ready=True)
 
 
     @classmethod
     def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
                expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
-               num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, max_batch_size=4096,
-               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
-               compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, custom_module_path=None,
-               *, start: bool, **kwargs) -> Server:
+               num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, min_batch_size=1,
+               max_batch_size=4096, device=None, no_dht=False, initial_peers=(), dht_port=None,
+               checkpoint_dir: Optional[Path] = None, compression=CompressionType.NONE,
+               stats_report_interval: Optional[int] = None, custom_module_path=None, *, start: bool) -> Server:
         """
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -85,6 +89,7 @@ class Server(threading.Thread):
         :param expert_cls: expert type from hivemind.server.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
         :param hidden_dim: main dimension for expert_cls
         :param num_handlers: server will use this many parallel processes to handle incoming requests
+        :param min_batch_size: total num examples in the same batch will be greater than this value
         :param max_batch_size: total num examples in the same batch will not exceed this value
         :param device: all experts will use this device in torch notation; default: cuda if available else cpu
 
 
@@ -112,9 +117,6 @@ class Server(threading.Thread):
         """
         """
         if custom_module_path is not None:
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
             add_custom_models_from_file(custom_module_path)
-
-        if len(kwargs) != 0:
-            logger.info("Ignored kwargs:", kwargs)
         assert expert_cls in name_to_block

         if no_dht:
@@ -172,6 +174,7 @@ class Server(threading.Thread):
                                                          num_warmup_steps=num_warmup_steps,
                                                          num_total_steps=num_total_steps,
                                                          clip_grad_norm=clip_grad_norm,
+                                                         min_batch_size=min_batch_size,
                                                          max_batch_size=max_batch_size)
 
 
         if checkpoint_dir is not None:
@@ -196,9 +199,7 @@ class Server(threading.Thread):
                 self.dht.run_in_background(await_ready=True)

             if self.experts:
-                dht_handler_thread = DHTHandlerThread(
-                    experts=self.experts, dht=self.dht, endpoint=self.listen_on, update_period=self.update_period)
-                dht_handler_thread.start()
+                self.dht_handler_thread.start()
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
 
 
@@ -207,16 +208,10 @@ class Server(threading.Thread):
                 process.start()
             process.ready.wait()
 
 
-        self.runtime.run()
-
-        for process in self.conn_handlers:
-            process.join()
-        if self.dht and self.experts:
-            dht_handler_thread.stop.set()
-            dht_handler_thread.join()
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.stop.set()
-            self.checkpoint_saver.join()
+        try:
+            self.runtime.run()
+        finally:
+            self.shutdown()
 
 
     def run_in_background(self, await_ready=True, timeout=None):
         """
@@ -242,19 +237,32 @@ class Server(threading.Thread):
 
 
     def shutdown(self):
         """
-        Gracefully terminate a hivemind server, process-safe.
+        Gracefully terminate the server, process-safe.
         Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
         self.ready.clear()
+
         for process in self.conn_handlers:
             process.terminate()
+            process.join()
+        logger.debug("Connection handlers terminated")
+
+        if self.dht and self.experts:
+            self.dht_handler_thread.stop.set()
+            self.dht_handler_thread.join()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.stop.set()
+            self.checkpoint_saver.join()
 
 
         if self.dht is not None:
             self.dht.shutdown()
             self.dht.join()
 
 
-        self.runtime.shutdown()
+        logger.debug(f"Shutting down runtime")
+        self.runtime.stop.set()
+        logger.info("Server shutdown succesfully")
 
 
 
 
 @contextmanager
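The server overhaul above boils down to two ideas: `run()` wraps the runtime in `try/finally` so `shutdown()` always executes, and `shutdown()` joins connection handlers and helper threads in a fixed order before stopping the runtime. A simplified sketch of that lifecycle (not the hivemind code):

```python
import threading

class MiniServer(threading.Thread):
    def __init__(self):
        super().__init__(daemon=True)
        self.ready = threading.Event()
        self.workers = []          # stand-ins for conn_handlers / dht_handler_thread / checkpoint_saver

    def run(self):
        self.ready.set()
        try:
            self.serve()           # stand-in for self.runtime.run()
        finally:
            self.shutdown()        # reached on normal exit and on exceptions alike

    def serve(self):
        pass

    def shutdown(self):
        self.ready.clear()
        for worker in self.workers:
            worker.join()          # stop and join each helper before returning
```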

+ 4 - 1
hivemind/server/connection_handler.py

@@ -52,7 +52,10 @@ class ConnectionHandler(mp.context.ForkProcess):
             await server.wait_for_termination()
             logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")
 
 
-        loop.run_until_complete(_run())
+        try:
+            loop.run_until_complete(_run())
+        except KeyboardInterrupt:
+            logger.debug('Caught KeyboardInterrupt, shutting down')
 
 
     async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
         return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))

+ 2 - 2
hivemind/server/expert_backend.py

@@ -74,8 +74,8 @@ class ExpertBackend:
 
 
         self.backward_schema = (self.forward_schema, self.outputs_schema)  # inputs to backward
         self.grad_inputs_schema = self.forward_schema  # outputs from backward
-        self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
-        self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
+        self.forward_pool = TaskPool(self.forward, name=f'{self.name}_forward', **kwargs)
+        self.backward_pool = TaskPool(self.backward, name=f'{self.name}_backward', **kwargs)
 
 
         self.update_count = 0
         self.examples_processed = 0

+ 2 - 2
hivemind/server/expert_uid.py

@@ -62,8 +62,8 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
                     uid.append(str(random.randint(slice_start, slice_end - 1)))
                 else:
                     raise ValueError("Block must be either fixed or a range [from:to]")
-            except KeyboardInterrupt as e:
-                raise e
+            except KeyboardInterrupt:
+                raise
             except Exception as e:
                 raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
         return UID_DELIMITER.join(uid)
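The bare `raise` above re-raises the active exception with its original traceback and avoids binding a name that is never used. A minimal illustration of the idiom (the helper below is made up):

```python
def parse_block(block: str) -> int:
    try:
        return int(block)
    except KeyboardInterrupt:
        raise                                            # propagate as-is, traceback preserved
    except Exception as e:
        raise ValueError(f"invalid block {block!r}") from e
```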

+ 17 - 20
hivemind/server/runtime.py

@@ -48,8 +48,8 @@ class Runtime(threading.Thread):
         self.expert_backends = expert_backends
         self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
         self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
-        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
         self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
+        self.stop = threading.Event()
 
 
         self.stats_report_interval = stats_report_interval
         if self.stats_report_interval is not None:
@@ -72,62 +72,59 @@ class Runtime(threading.Thread):
 
 
                 for pool, batch_index, batch in BackgroundGenerator(
                         self.iterate_minibatches_from_pools(), self.prefetch_batches):
-                    logger.debug(f"Processing batch {batch_index} from pool {pool.uid}")
+                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
 
 
                     start = time()
                     outputs = pool.process_func(*batch)
                     batch_processing_time = time() - start
 
 
                     batch_size = outputs[0].size(0)
-                    logger.debug(f"Pool {pool.uid}: batch {batch_index} processed, size {batch_size}")
+                    logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
 
 
                     if self.stats_report_interval is not None:
-                        self.stats_reporter.report_stats(pool.uid, batch_size, batch_processing_time)
+                        self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
 
 
                     output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
             finally:
-                logger.info("Shutting down")
-
-                if self.stats_report_interval is not None:
-                    self.stats_reporter.stop.set()
-                    self.stats_reporter.join()
-
                 self.shutdown()
 
 
-    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
-
     def shutdown(self):
         """ Gracefully terminate a running runtime. """
-        self.ready.clear()
-        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)  # trigger background thread to shutdown
+        logger.info("Shutting down")
+
+        if self.stats_report_interval is not None:
+            self.stats_reporter.stop.set()
+            self.stats_reporter.join()
+
+        self.stop.set()  # trigger background thread to shutdown
+
+        logger.debug("Terminating pools")
         for pool in self.pools:
             if pool.is_alive():
                 pool.terminate()
                 pool.join()
+        logger.debug("Pools terminated")
 
 
     def iterate_minibatches_from_pools(self, timeout=None):
         """
         Chooses pool according to priority, then copies exposed batch and frees the buffer
         """
         with DefaultSelector() as selector:
-            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
             for pool in self.pools:
                 selector.register(pool.batch_receiver, EVENT_READ, pool)
 
 
-            while True:
+            while not self.stop.is_set():
                 # wait until at least one batch_receiver becomes available
                 logger.debug("Waiting for inputs from task pools")
                 ready_fds = selector.select()
                 ready_objects = {key.data for (key, events) in ready_fds}
-                if self.SHUTDOWN_TRIGGER in ready_objects:
-                    break  # someone asked us to shutdown, break from the loop
 
 
                 logger.debug("Choosing the pool with highest priority")
                 logger.debug("Choosing the pool with highest priority")
                 pool = max(ready_objects, key=lambda pool: pool.priority)
                 pool = max(ready_objects, key=lambda pool: pool.priority)
 
 
-                logger.debug(f"Loading batch from {pool.uid}")
+                logger.debug(f"Loading batch from {pool.name}")
                 batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
-                logger.debug(f"Loaded batch from {pool.uid}")
+                logger.debug(f"Loaded batch from {pool.name}")
                 yield pool, batch_index, batch_tensors
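Instead of a dedicated shutdown pipe registered in the selector, the runtime now checks a `threading.Event` between batches. A rough sketch of the new loop shape (receivers are any objects with a `fileno()`, e.g. pipe connections; names are made up):

```python
import threading
from selectors import DefaultSelector, EVENT_READ

stop = threading.Event()

def iterate_ready_receivers(receivers):
    with DefaultSelector() as selector:
        for receiver in receivers:
            selector.register(receiver, EVENT_READ)
        # the flag is checked once per wakeup, replacing the old SHUTDOWN_TRIGGER sentinel
        while not stop.is_set():
            for key, _events in selector.select():
                yield key.fileobj
```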
 
 
 
 

+ 60 - 70
hivemind/server/task_pool.py

@@ -6,7 +6,6 @@ import multiprocessing as mp
 import os
 import threading
 import time
-import uuid
 from abc import ABCMeta, abstractmethod
 from collections import namedtuple
 from concurrent.futures import Future
@@ -24,8 +23,8 @@ Task = namedtuple("Task", ("future", "args"))
 class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta):
     """ A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime """
 
 
-    def __init__(self, process_func: callable, daemon=True):
-        super().__init__(daemon=daemon)
+    def __init__(self, process_func: callable, daemon=True, **kwargs):
+        super().__init__(daemon=daemon, **kwargs)
         self.process_func = process_func
         self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = the more urgent to process this pool
 
 
@@ -63,19 +62,18 @@ class TaskPool(TaskPoolBase):
     :param process_func: function to be applied to every formed batch; called by Runtime
         Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
     :param max_batch_size: process at most this many inputs in a batch (task contains have one or several inputs)
+    :param name: pool name
     :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
     :param timeout: wait for a subsequent task for at most this many seconds
     :param pool_size: store at most this many unprocessed tasks in a queue
     :param prefetch_batches: prepare up to this many *batches* in background for faster off-loading to runtime
-    :param uid: pool identifier used for shared array allocation
     :param start: if True, start automatically at the end of __init__
     """
 
 
-    def __init__(self, process_func: callable, max_batch_size: int, min_batch_size=1,
-                 timeout=None, pool_size=None, prefetch_batches=1, uid=None, daemon=True, start=False):
-        super().__init__(process_func, daemon=daemon)
+    def __init__(self, process_func: callable, max_batch_size: int, name: str, min_batch_size=1,
+                 timeout=None, pool_size=None, prefetch_batches=1, daemon=True, start=False):
+        super().__init__(process_func, daemon=daemon, name=name)
         self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout
-        self.uid = uid or uuid.uuid4()
         self.prefetch_batches = prefetch_batches

         # interaction with ConnectionHandlers
@@ -112,7 +110,7 @@ class TaskPool(TaskPoolBase):
                 batch = []
                 total_size = 0
             try:
-                logger.debug(f"{self.uid} getting next task")
+                logger.debug(f"{self.name} getting next task")
                 task = self.tasks.get(timeout=self.timeout)
             except Empty:
                 logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
@@ -134,80 +132,72 @@ class TaskPool(TaskPoolBase):
 
 
     def run(self, *args, **kwargs):
         torch.set_num_threads(1)
-        logger.info(f'{self.uid} starting, pid={os.getpid()}')
+        logger.info(f'{self.name} starting, pid={os.getpid()}')
         pending_batches = {}  # Dict[batch uuid, List[MPFuture]] for each batch currently in runtime
+
         output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
-                                         name=f'{self.uid}_output')
+                                         name=f'{self.name}_output')
+
         try:
             output_thread.start()
             self._pool_input_loop(pending_batches, *args, **kwargs)
-        except BaseException as e:
-            # terminate output loop
-            self.outputs_sender.send(e)
+        except KeyboardInterrupt:
+            logger.debug('Caught KeyboardInterrupt, shutting down')
+        finally:
             output_thread.join()
-            raise e
 
 
     def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
         """ Infinite loop: aggregate tasks into batches and send them to runtime """
-        try:
-            prev_num_tasks = 0  # number of tasks currently in shared buffer
-            batch_index = max(pending_batches.keys(), default=0)
-            batch_iterator = self.iterate_minibatches(*args, **kwargs)
-
-            while True:
-                # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
-                # assumes that tasks are processed in the same order as they are created
-                for skip_i in range(prev_num_tasks):
-                    finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
-                    if skip_i == prev_num_tasks - 1:
-                        self.priority = finished_task_timestamp
-
-                logger.debug(f"{self.uid} getting next batch")
-                batch_tasks = next(batch_iterator)
-                # save batch futures, _output_loop will deliver on them later
-                pending_batches[batch_index] = batch_tasks
-
-                logger.debug(f"{self.uid}, batch  {batch_index}: aggregating inputs")
-                # find or create shared arrays for current batch size
-                batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in
-                                range(len(batch_tasks[0].args))]
-                batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
-
-                logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime")
-                self.batch_sender.send((batch_index, batch_inputs))
-                logger.debug(f"{self.uid}, batch {batch_index}: sent to runtime")
-                prev_num_tasks = len(batch_tasks)
-                batch_index += 1
-        except KeyboardInterrupt:
-            logger.debug('Caught KeyboardInterrupt, shutting down')
+
+        prev_num_tasks = 0  # number of tasks currently in shared buffer
+        batch_index = max(pending_batches.keys(), default=0)
+        batch_iterator = self.iterate_minibatches(*args, **kwargs)
+
+        while True:
+            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
+            # assumes that tasks are processed in the same order as they are created
+            for skip_i in range(prev_num_tasks):
+                finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
+                if skip_i == prev_num_tasks - 1:
+                    self.priority = finished_task_timestamp
+
+            logger.debug(f"{self.name} getting next batch")
+            batch_tasks = next(batch_iterator)
+            # save batch futures, _output_loop will deliver on them later
+            pending_batches[batch_index] = batch_tasks
+
+            logger.debug(f"{self.name}, batch  {batch_index}: aggregating inputs")
+            # find or create shared arrays for current batch size
+            batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in
+                            range(len(batch_tasks[0].args))]
+            batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
+
+            logger.debug(f"{self.name}, batch {batch_index}: sending to runtime")
+            self.batch_sender.send((batch_index, batch_inputs))
+            logger.debug(f"{self.name}, batch {batch_index}: sent to runtime")
+            prev_num_tasks = len(batch_tasks)
+            batch_index += 1
 
 
     def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
     def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
         """ Infinite loop: receive results from runtime and dispatch them to task Futures """
 
-        try:
-            while True:
-                logger.debug(f"{self.uid} waiting for results from runtime")
-                payload = self.outputs_receiver.recv()
-                if isinstance(payload, BaseException):
-                    raise payload
-                else:
-                    batch_index, batch_outputs = payload
-                logger.debug(f"{self.uid}, batch {batch_index}: got results")
-
-                # split batch into partitions for individual tasks
-                batch_tasks = pending_batches.pop(batch_index)
-                task_sizes = [self.get_task_size(task) for task in batch_tasks]
-                outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
-                logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers")
-
-                # dispatch results to futures
-                for task, task_outputs in zip(batch_tasks, outputs_per_task):
-                    try:
-                        task.future.set_result(tuple(task_outputs))
-                    except FutureStateError as e:
-                        logger.debug(f"Failed to send task result due to an exception: {e}")
-        except KeyboardInterrupt:
-            logger.debug(f"Caught KeyboardInterrupt, shutting down")
+        while True:
+            logger.debug(f"{self.name} waiting for results from runtime")
+            batch_index, batch_outputs = self.outputs_receiver.recv()
+            logger.debug(f"{self.name}, batch {batch_index}: got results")
+
+            # split batch into partitions for individual tasks
+            batch_tasks = pending_batches.pop(batch_index)
+            task_sizes = [self.get_task_size(task) for task in batch_tasks]
+            outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
+            logger.debug(f"{self.name}, batch {batch_index}: sending outputs to handlers")
+
+            # dispatch results to futures
+            for task, task_outputs in zip(batch_tasks, outputs_per_task):
+                try:
+                    task.future.set_result(tuple(task_outputs))
+                except FutureStateError as e:
+                    logger.debug(f"Failed to send task result due to an exception: {e}")
 
 
     @property
     def empty(self):
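Finally, `TaskPool` drops its separate uuid-based `uid` and reuses the standard `multiprocessing` process name, which is what the runtime now logs. A minimal illustration of the underlying mechanism (the pool name below is made up):

```python
import multiprocessing as mp

class NamedPool(mp.context.ForkProcess):
    def __init__(self, name: str, daemon: bool = True):
        super().__init__(name=name, daemon=daemon)

pool = NamedPool(name='expert.0.0_forward')
print(pool.name)   # 'expert.0.0_forward' -- usable in log messages without an extra identifier
```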