Browse source

Improve error handling, remove deprecated functionality (#261)

- modified .load_state_from_peers to handle network errors and ensure the MPFuture is always set
- removed dht.get_experts/declare_experts from the library, tests, benchmarks and examples (see the migration sketch below)
- removed the default expiration=... parameter from hivemind.DHT
- removed stale TODOs and deprecation notices
- rolled Runtime.stop back to a shutdown pipe -> the server no longer hangs on shutdown (also: repaired benchmark_throughput)
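
A minimal migration sketch for the removed DHT methods; the uids, endpoint, and expiration below are illustrative, not taken from this patch:

```python
import hivemind

# a fresh single-node DHT; note that expiration=... is no longer accepted here
dht = hivemind.DHT(start=True)

# before this patch: dht.declare_experts(uids, endpoint) / dht.get_experts(uids)
# after: module-level functions; expiration (seconds of visibility) is passed per call
ok = hivemind.declare_experts(dht, ["expert.1", "expert.4"],
                              endpoint="127.0.0.1:1337", expiration=300)
assert all(ok.values())
expert1, expert4 = hivemind.get_experts(dht, ["expert.1", "expert.4"])
```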

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
commit
f0c5627139

+ 3 - 2
benchmarks/benchmark_dht.py

@@ -25,7 +25,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     for _ in trange(num_peers):
         neighbors = [f'0.0.0.0:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
         peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout,
-                            expiration=expiration, listen_on=f'0.0.0.0:*')
+                            listen_on=f'0.0.0.0:*')
         peers.append(peer)
 
     store_peer, get_peer = peers[-2:]
@@ -43,7 +43,8 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     for start in trange(0, num_experts, expert_batch_size):
         store_start = time.perf_counter()
         endpoints.append(random_endpoint())
-        store_ok = hivemind.declare_experts(store_peer, expert_uids[start: start + expert_batch_size], endpoints[-1])
+        store_ok = hivemind.declare_experts(store_peer, expert_uids[start: start + expert_batch_size], endpoints[-1],
+                                            expiration=expiration)
         successes = store_ok.values()
         total_store_time += time.perf_counter() - store_start
 

+ 1 - 1
docs/user/quickstart.md

@@ -154,7 +154,7 @@ dht = hivemind.DHT(initial_peers=["localhost:1338"], listen=False, start=True)
 # note: listen=False means that your peer will operate in "client only" mode: 
 # this means that it can request other peers, but will not accept requests in return 
 
-expert1, expert4 = dht.get_experts(["expert.1", "expert.4"])
+expert1, expert4 = hivemind.get_experts(dht, ["expert.1", "expert.4"])
 assert expert1 is not None and expert4 is not None, "server hasn't declared experts (yet?)"
 ```
 

+ 43 - 40
hivemind/client/averaging/__init__.py

@@ -183,7 +183,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     self._port.value = found_port
                     await server.start()
                 else:
-                    logger.info(f"The averager running in an experimental client mode, please report any bugs.")
+                    logger.debug(f"The averager is running in client mode.")
 
                 self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
                                                 client_mode=not self.listen)
@@ -422,47 +422,50 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         return future.result() if wait else future
 
     async def _load_state_from_peers(self, future: MPFuture):
-        key_manager = self._matchmaking.group_key_manager
-        peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
-        peer_priority = {peer: float(info.value) for peer, info in peer_priority.items()
-                         if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))}
-
-        if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
-            logger.info(f"Averager could not load state from peers: peer dict is absent or corrupted {peer_priority}.")
-            future.set_result(None)
-            return
-
-        metadata = None
-        for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
-            if peer != self.endpoint:
-                logger.info(f"Downloading parameters from peer {peer}")
-                stream = None
-                try:
-                    leader_stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
-                    stream = leader_stub.rpc_download_state(averaging_pb2.DownloadRequest())
-                    current_tensor_parts, tensors = [], []
-                    async for message in stream:
-                        if message.metadata:
-                            metadata = self.serializer.loads(message.metadata)
-                        if message.tensor_part.dtype and current_tensor_parts:
-                            # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
+        try:
+            key_manager = self._matchmaking.group_key_manager
+            peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
+            peer_priority = {peer: float(info.value) for peer, info in peer_priority.items()
+                             if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))}
+
+            if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
+                logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}.")
+                future.set_result(None)
+                return
+
+            metadata = None
+            for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
+                if peer != self.endpoint:
+                    logger.info(f"Downloading parameters from peer {peer}")
+                    stream = None
+                    try:
+                        stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+                        stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
+                        current_tensor_parts, tensors = [], []
+                        async for message in stream:
+                            if message.metadata:
+                                metadata = self.serializer.loads(message.metadata)
+                            if message.tensor_part.dtype and current_tensor_parts:
+                                # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
+                                tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
+                                current_tensor_parts = []
+                            current_tensor_parts.append(message.tensor_part)
+                        if current_tensor_parts:
                             tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
-                            current_tensor_parts = []
-                        current_tensor_parts.append(message.tensor_part)
-                    if current_tensor_parts:
-                        tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
-                    future.set_result((metadata, tensors))
-                    self.last_updated = get_dht_time()
-                    return
-                except grpc.aio.AioRpcError as e:
-                    logger.info(f"Failed to download state from {peer} - {e}")
-                finally:
-                    if stream is not None:
-                        await stream.code()
+                        logger.info(f"Finished downloading state from {peer}")
+                        future.set_result((metadata, tensors))
+                        self.last_updated = get_dht_time()
+                        return
+                    except BaseException as e:
+                        logger.exception(f"Failed to download state from {peer} - {repr(e)}")
+                    finally:
+                        if stream is not None:
+                            await stream.code()
 
-        else:
-            logger.warning("Averager could not load state from peers: found no active peers.")
-            future.set_result(None)
+        finally:
+            if not future.done():
+                logger.warning("Averager could not load state from peers: all requests have failed.")
+                future.set_result(None)
 
     def get_group_bits(self, wait: bool = True):
         """

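The key change above is the outer try/finally: whichever code path raises, the caller's future is always resolved. A self-contained sketch of that pattern, using a plain asyncio.Future in place of hivemind's MPFuture and a failing stand-in for the gRPC download:

```python
import asyncio

async def fetch_state(peer: str):
    # stand-in for the real rpc_download_state call; here it always fails
    raise ConnectionError(f"cannot reach {peer}")

async def load_state_from_peers(future: asyncio.Future, peers) -> None:
    # same shape as the patched _load_state_from_peers: the outer try/finally
    # guarantees the future is resolved on every code path, so a caller
    # blocked on future.result() can never hang
    try:
        for peer in peers:
            try:
                future.set_result(await fetch_state(peer))
                return
            except Exception as e:
                print(f"Failed to download state from {peer} - {e!r}")
    finally:
        if not future.done():
            print("could not load state from peers: all requests have failed")
            future.set_result(None)

async def main():
    future = asyncio.get_running_loop().create_future()
    await load_state_from_peers(future, ["peer-a:1337", "peer-b:1337"])
    assert future.result() is None  # resolved even though every peer failed

asyncio.run(main())
```
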
+ 29 - 21
hivemind/client/beam_search.py

@@ -4,7 +4,7 @@ from collections import deque
 from functools import partial
 from typing import Sequence, Optional, List, Tuple, Dict, Deque, Union, Set, Iterator
 
-from hivemind.dht import DHT, DHTNode
+from hivemind.dht import DHT, DHTNode, DHTExpiration
 from hivemind.client.expert import RemoteExpert
 from hivemind.server.expert_uid import (ExpertUID, ExpertPrefix, FLAT_EXPERT, UidEndpoint, Score, Coordinate,
                                         PREFIX_PATTERN, UID_DELIMITER, is_valid_prefix)
@@ -22,7 +22,7 @@ class MoEBeamSearcher:
         * optional prefix that determines expert role, experiment name, etc.
         * one or more integers that determine that expert's position in an N-dimensional grid
 
-    A hivemind.Server can ``DHT.declare_experts(expert_uids: List[str])`` to make its experts visible to everyone.
+    A hivemind.Server can ``declare_experts(dht, expert_uids: List[str])`` to make its experts visible to everyone.
     When declaring experts, DHT will store each expert's uid and all its prefixes until :expiration: (specified at init)
     For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
     ``"ffn_expert.98", "ffn_expert.98.76", "ffn_expert.98.76.54", ..., "ffn_expert.98.76.54.32.10"``
@@ -63,8 +63,8 @@ class MoEBeamSearcher:
          Though, this is a pathological case (e.g. only 90 experts in an oversized 100x100 grid) that should be avoided.
     """
 
-    def __init__(self, dht: DHT, uid_prefix: ExpertPrefix, grid_size: Tuple[int, ...],
-                 num_workers: Optional[int] = None, negative_caching: bool = True, **kwargs):
+    def __init__(self, dht: DHT, uid_prefix: ExpertPrefix, grid_size: Sequence[int], num_workers: Optional[int] = None,
+                 negative_caching: bool = True, cache_expiration: DHTExpiration = 300, **kwargs):
         if not uid_prefix.endswith(UID_DELIMITER):
             uid_prefix += UID_DELIMITER
             logger.info(f"Prefix must end with '{UID_DELIMITER}'. Changing to {uid_prefix}{UID_DELIMITER}")
@@ -72,7 +72,8 @@ class MoEBeamSearcher:
         self.dht = dht
         self.uid_prefix, self.grid_size = uid_prefix, grid_size
         self.total_grid_size = sum(grid_size)
-        self.negative_caching, self.num_workers, self.dht_kwargs = negative_caching, num_workers, kwargs
+        self.negative_caching, self.cache_expiration = negative_caching, cache_expiration
+        self.num_workers, self.dht_kwargs = num_workers, kwargs
 
     def get_initial_beam(self, scores: Sequence[float], beam_size: int, return_future: bool = False
                          ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
@@ -84,12 +85,14 @@ class MoEBeamSearcher:
         """
         return self.dht.run_coroutine(partial(self._get_initial_beam, prefix=self.uid_prefix, beam_size=beam_size,
                                               scores=tuple(scores), negative_caching=self.negative_caching,
-                                              num_workers=self.num_workers), return_future)
+                                              cache_expiration=self.cache_expiration, num_workers=self.num_workers),
+                                      return_future)
 
     @staticmethod
-    async def _get_initial_beam(dht: DHT, node: DHTNode, prefix: ExpertPrefix, beam_size: int,
-                                scores: Tuple[float, ...], negative_caching: bool, num_workers: Optional[int] = None
-                                ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+    async def _get_initial_beam(
+            dht: DHT, node: DHTNode, prefix: ExpertPrefix, beam_size: int, scores: Tuple[float, ...],
+            negative_caching: bool, cache_expiration: DHTExpiration, num_workers: Optional[int] = None,
+    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
         num_workers = num_workers or dht.max_workers or beam_size
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
         unattempted_indices: List[Coordinate] = sorted(range(len(scores)), key=scores.__getitem__)  # from worst to best
@@ -115,7 +118,7 @@ class MoEBeamSearcher:
                 elif maybe_prefix_data is None and negative_caching:
                     logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
                     asyncio.create_task(node.store(pending_best_prefix, subkey=-1, value=None,
-                                                   expiration_time=get_dht_time() + dht.default_expiration))
+                                                   expiration_time=get_dht_time() + cache_expiration))
 
             except asyncio.CancelledError:
                 for _, pending_task in pending_tasks:
@@ -137,12 +140,14 @@ class MoEBeamSearcher:
             assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
         return self.dht.run_coroutine(partial(
             self._get_active_successors, prefixes=list(prefixes), grid_size=grid_size,
-            negative_caching=self.negative_caching, num_workers=self.num_workers), return_future=return_future)
+            negative_caching=self.negative_caching, cache_expiration=self.cache_expiration,
+            num_workers=self.num_workers), return_future=return_future)
 
     @staticmethod
-    async def _get_active_successors(dht: DHT, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int],
-                                     negative_caching: bool, num_workers: Optional[int] = None
-                                     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+    async def _get_active_successors(
+            dht: DHT, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int],
+            negative_caching: bool, cache_expiration: DHTExpiration, num_workers: Optional[int] = None
+    ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
         grid_size = grid_size or float('inf')
         num_workers = num_workers or min(len(prefixes), dht.max_workers or len(prefixes))
         dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
@@ -157,7 +162,7 @@ class MoEBeamSearcher:
                 if found is None and negative_caching:
                     logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
                     asyncio.create_task(node.store(prefix, subkey=-1, value=None,
-                                                   expiration_time=get_dht_time() + dht.default_expiration))
+                                                   expiration_time=get_dht_time() + cache_expiration))
         return successors
 
     def find_best_experts(self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
@@ -176,14 +181,16 @@ class MoEBeamSearcher:
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
         assert len(grid_scores) == len(self.grid_size) and beam_size > 0
-        return self.dht.run_coroutine(partial(self._find_best_experts, prefix=self.uid_prefix, beam_size=beam_size,
-                                              grid_scores=list(grid_scores), negative_caching=self.negative_caching,
-                                              num_workers=self.num_workers), return_future)
+        return self.dht.run_coroutine(partial(
+            self._find_best_experts, prefix=self.uid_prefix, beam_size=beam_size, grid_scores=list(grid_scores),
+            negative_caching=self.negative_caching, cache_expiration=self.cache_expiration,
+            num_workers=self.num_workers), return_future)
 
     @classmethod
     async def _find_best_experts(
             cls, dht: DHT, node: DHTNode, prefix: str, grid_scores: List[Tuple[float]], beam_size: int,
-            negative_caching: bool, num_workers: Optional[int] = None) -> List[RemoteExpert]:
+            negative_caching: bool, cache_expiration: DHTExpiration, num_workers: Optional[int] = None
+    ) -> List[RemoteExpert]:
         num_workers = num_workers or min(beam_size, dht.max_workers or beam_size)
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
@@ -209,8 +216,9 @@ class MoEBeamSearcher:
             _, best_uid_prefixes = zip(*best_active_pairs)
 
             # search DHT for next step suffixes
-            successors = await cls._get_active_successors(dht, node, best_uid_prefixes, grid_size=None,
-                                                          negative_caching=negative_caching, num_workers=num_workers)
+            successors = await cls._get_active_successors(
+                dht, node, best_uid_prefixes, grid_size=None, negative_caching=negative_caching,
+                cache_expiration=cache_expiration, num_workers=num_workers)
             beam = [(score, prefix, successors[prefix]) for score, prefix in best_active_pairs if successors[prefix]]
             if not beam:
                 logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim 0)")

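With dht.default_expiration gone, the negative-cache lifetime is configured per searcher. A short usage sketch under that change (grid size and expiration are illustrative):

```python
import hivemind
from hivemind.client.beam_search import MoEBeamSearcher

dht = hivemind.DHT(start=True)
# "no such prefix" entries found during beam search are cached in the DHT
# for cache_expiration seconds, replacing the removed DHT-wide default
searcher = MoEBeamSearcher(dht, uid_prefix='expert.', grid_size=(4, 4),
                           cache_expiration=60)
```
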
+ 1 - 12
hivemind/dht/__init__.py

@@ -51,13 +51,11 @@ class DHT(mp.Process):
 
     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
                  daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
-                 expiration: float = 300, record_validators: Iterable[RecordValidatorBase] = (),
-                 **kwargs):
+                 record_validators: Iterable[RecordValidatorBase] = (), **kwargs):
         super().__init__()
         assert not isinstance(initial_peers, str), "please specify a list/tuple of initial peers (even if there's one)"
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
         self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
-        self.default_expiration = expiration
         self._record_validator = CompositeValidator(record_validators)
         self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
         self._pipe, self.pipe = mp.Pipe(duplex=True)
@@ -257,12 +255,3 @@ class DHT(mp.Process):
         else:
             future.set_exception(ValueError(f"Can't get address: DHT node has no peers and no public endpoint."
                                             f" Please ensure the node is connected or specify peers=... manually."))
-
-    def declare_experts(self, uids, endpoint, wait: bool = True):
-        logger.warning("dht.declare_experts is scheduled for removal in 0.9.8, please use hivemind.declare_experts.")
-        return hivemind.declare_experts(self, uids, endpoint, wait=wait)
-
-    def get_experts(self, uids, expiration_time: Optional[DHTExpiration] = None,
-                    return_future: bool = False) -> List[Optional[RemoteExpert]]:
-        logger.warning("dht.get_experts is scheduled for removal in 0.9.8, please use hivemind.get_experts.")
-        return hivemind.get_experts(self, uids, expiration_time, return_future)

+ 3 - 2
hivemind/server/__init__.py

@@ -67,7 +67,7 @@ class Server(threading.Thread):
 
         if self.dht and self.experts:
             self.dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht, endpoint=self.listen_on,
-                                                       update_period=self.update_period)
+                                                       update_period=self.update_period, daemon=True)
 
         if start:
             self.run_in_background(await_ready=True)
@@ -261,7 +261,8 @@ class Server(threading.Thread):
             self.dht.join()
 
         logger.debug(f"Shutting down runtime")
-        self.runtime.stop.set()
+
+        self.runtime.shutdown()
         logger.info("Server shutdown succesfully")
 
 

+ 9 - 7
hivemind/server/dht_handler.py

@@ -10,8 +10,8 @@ from hivemind.utils import Endpoint, get_dht_time, get_port
 
 
 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, endpoint: Endpoint, update_period: int = 5):
-        super().__init__()
+    def __init__(self, experts, dht: DHT, endpoint: Endpoint, update_period: int = 5, **kwargs):
+        super().__init__(**kwargs)
         assert get_port(endpoint) is not None
         self.endpoint = endpoint
         self.experts = experts
@@ -25,7 +25,7 @@ class DHTHandlerThread(threading.Thread):
             declare_experts(self.dht, self.experts.keys(), self.endpoint)
 
 
-def declare_experts(dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint,
+def declare_experts(dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration = 300,
                     wait: bool = True) -> Dict[ExpertUID, bool]:
     """
     Make experts visible to all DHT peers; update timestamps if declared previously.
@@ -33,18 +33,20 @@ def declare_experts(dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint,
     :param uids: a list of expert ids to update
     :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
     :param wait: if True, awaits for declaration to finish, otherwise runs in background
-    :param timeout: waits for the procedure to finish for up to this long, None means wait indefinitely
+    :param expiration: experts will be visible for this many seconds
     :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     for uid in uids:
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
-    return dht.run_coroutine(partial(_declare_experts, uids=list(uids), endpoint=endpoint), return_future=not wait)
+    return dht.run_coroutine(partial(_declare_experts, uids=list(uids), endpoint=endpoint, expiration=expiration),
+                             return_future=not wait)
 
 
-async def _declare_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint) -> Dict[ExpertUID, bool]:
+async def _declare_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint,
+                           expiration: DHTExpiration) -> Dict[ExpertUID, bool]:
     num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
-    expiration_time = get_dht_time() + dht.default_expiration  # TODO use local expiration
+    expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     for uid in uids:
         data_to_store[uid, None] = endpoint

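declare_experts now takes the visibility window explicitly, and DHTHandlerThread forwards **kwargs to threading.Thread. A small sketch under those changes (endpoint and uids are illustrative); a server that re-declares every update_period seconds should keep that period well below expiration:

```python
import hivemind
from hivemind.server.dht_handler import DHTHandlerThread, declare_experts

dht = hivemind.DHT(start=True)

# daemon=True now reaches threading.Thread, so the handler no longer keeps
# the process alive after Server.shutdown()
handler = DHTHandlerThread(experts={}, dht=dht, endpoint='127.0.0.1:1337',
                           update_period=5, daemon=True)
assert handler.daemon

# each declaration keeps the experts visible for `expiration` seconds
assert all(declare_experts(dht, ['expert.0'], endpoint='127.0.0.1:1337',
                           expiration=300).values())
```
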
+ 2 - 1
hivemind/server/expert_uid.py

@@ -2,6 +2,7 @@ import random
 import re
 from typing import NamedTuple, Union, Tuple, Optional, List
 
+import hivemind
 from hivemind.dht import DHT
 from hivemind.utils import Endpoint, get_logger
 
@@ -81,7 +82,7 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
 
         # 2. look into DHT (if given) and remove duplicates
         if dht:
-            existing_expert_uids = {found_expert.uid for found_expert in dht.get_experts(new_uids)
+            existing_expert_uids = {found_expert.uid for found_expert in hivemind.get_experts(dht, new_uids)
                                     if found_expert is not None}
             new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
 

+ 14 - 5
hivemind/server/runtime.py

@@ -41,6 +41,7 @@ class Runtime(threading.Thread):
 
     :param stats_report_interval: interval to collect and log statistics about runtime performance
     """
+    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
 
     def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1,
                  device: torch.device = None, stats_report_interval: Optional[int] = None):
@@ -48,8 +49,9 @@ class Runtime(threading.Thread):
         self.expert_backends = expert_backends
         self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
         self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
+        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
+        self.shutdown_trigger = mp.Event()
         self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
-        self.stop = threading.Event()
 
         self.stats_report_interval = stats_report_interval
         if self.stats_report_interval is not None:
@@ -86,18 +88,18 @@ class Runtime(threading.Thread):
 
                     output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
             finally:
-                self.shutdown()
+                if not self.shutdown_trigger.is_set():
+                    self.shutdown()
 
     def shutdown(self):
         """ Gracefully terminate a running runtime. """
         logger.info("Shutting down")
+        self.ready.clear()
 
         if self.stats_report_interval is not None:
             self.stats_reporter.stop.set()
             self.stats_reporter.join()
 
-        self.stop.set()  # trigger background thread to shutdown
-
         logger.debug("Terminating pools")
         for pool in self.pools:
             if pool.is_alive():
@@ -105,6 +107,10 @@ class Runtime(threading.Thread):
                 pool.join()
         logger.debug("Pools terminated")
 
+        # trigger background thread to shutdown
+        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
+        self.shutdown_trigger.set()
+
     def iterate_minibatches_from_pools(self, timeout=None):
         """
         Chooses pool according to priority, then copies exposed batch and frees the buffer
@@ -112,12 +118,15 @@ class Runtime(threading.Thread):
         with DefaultSelector() as selector:
             for pool in self.pools:
                 selector.register(pool.batch_receiver, EVENT_READ, pool)
+            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
 
-            while not self.stop.is_set():
+            while True:
                 # wait until at least one batch_receiver becomes available
                 logger.debug("Waiting for inputs from task pools")
                 ready_fds = selector.select()
                 ready_objects = {key.data for (key, events) in ready_fds}
+                if self.SHUTDOWN_TRIGGER in ready_objects:
+                    break  # someone asked us to shutdown, break from the loop
 
                 logger.debug("Choosing the pool with highest priority")
                 pool = max(ready_objects, key=lambda pool: pool.priority)

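Why the pipe fixes the hang: the old loop only checked self.stop between selector.select() calls, so with no incoming batches it blocked in select() forever; the read end of a pipe, by contrast, is just another selectable fd. A self-contained sketch of the mechanism, mirroring the SHUTDOWN_TRIGGER code above:

```python
import multiprocessing as mp
from selectors import DefaultSelector, EVENT_READ

SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
shutdown_recv, shutdown_send = mp.Pipe(duplex=False)

with DefaultSelector() as selector:
    # in Runtime, each pool.batch_receiver is registered the same way;
    # the shutdown pipe simply becomes one more event source
    selector.register(shutdown_recv, EVENT_READ, SHUTDOWN_TRIGGER)

    shutdown_send.send(SHUTDOWN_TRIGGER)  # what Runtime.shutdown() now does

    ready_objects = {key.data for key, _ in selector.select()}
    if SHUTDOWN_TRIGGER in ready_objects:
        print("shutdown requested: breaking out of the minibatch loop")
```
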
+ 1 - 1
hivemind/server/task_pool.py

@@ -136,7 +136,7 @@ class TaskPool(TaskPoolBase):
         pending_batches = {}  # Dict[batch uuid, List[MPFuture]] for each batch currently in runtime
 
         output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
-                                         name=f'{self.name}_output')
+                                         name=f'{self.name}_output', daemon=True)
 
         try:
             output_thread.start()

+ 4 - 1
hivemind/utils/timed_storage.py

@@ -9,10 +9,11 @@ from dataclasses import dataclass
 KeyType = TypeVar('KeyType')
 ValueType = TypeVar('ValueType')
 get_dht_time = time.time  # a global (weakly synchronized) time
-MAX_DHT_TIME_DISCREPANCY_SECONDS = 3  # max allowed difference between get_dht_time for two DHT nodes. Enforced when joining DHT.(TODO)
+MAX_DHT_TIME_DISCREPANCY_SECONDS = 3  # max allowed difference between get_dht_time for two DHT nodes
 DHTExpiration = float
 ROOT = 0
 
+
 @dataclass(init=True, repr=True, frozen=True)
 class ValueWithExpiration(Generic[ValueType]):
     value: ValueType
@@ -37,11 +38,13 @@ class ValueWithExpiration(Generic[ValueType]):
         else:
             return False
 
+
 @dataclass(init=True, repr=True, order=True, frozen=True)
 class HeapEntry(Generic[KeyType]):
     expiration_time: DHTExpiration
     key: KeyType
 
+
 class TimedStorage(Generic[KeyType, ValueType]):
     """ A dictionary that maintains up to :maxsize: key-value-expiration tuples until their expiration_time """
     frozen = False  # can be set to True. If true, do not remove outdated elements

+ 13 - 14
tests/test_dht_experts.py

@@ -6,7 +6,7 @@ import pytest
 
 import hivemind
 import hivemind.server.expert_uid
-from hivemind import LOCALHOST
+from hivemind import LOCALHOST, declare_experts, get_experts
 from hivemind.client.beam_search import MoEBeamSearcher
 from hivemind.server.expert_uid import UidEndpoint, is_valid_uid, is_valid_prefix, split_uid
 
@@ -26,13 +26,13 @@ def test_store_get_experts():
     for batch_start in range(0, len(expert_uids), batch_size):
         hivemind.declare_experts(first_peer, expert_uids[batch_start: batch_start + batch_size], 'localhost:1234')
 
-    found = other_peer.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
+    found = get_experts(other_peer, random.sample(expert_uids, 5) + ['foo', 'bar'])
     assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
     assert all(res is None for res in found[-2:]), "Found non-existing experts"
 
     other_expert, other_port = "my_other_expert.1337", random.randint(1000, 9999)
     hivemind.declare_experts(other_peer, [other_expert], f'that_host:{other_port}')
-    first_notfound, first_found = hivemind.get_experts(first_peer, ['foobar', other_expert])
+    first_notfound, first_found = get_experts(first_peer, ['foobar', other_expert])
     assert isinstance(first_found, hivemind.RemoteExpert)
     assert first_found.endpoint == f'that_host:{other_port}'
 
@@ -46,19 +46,18 @@ def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peer
     dht = []
     for i in range(dht_size):
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
-        dht.append(hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc))
+        dht.append(hivemind.DHT(start=True, initial_peers=neighbors_i, parallel_rpc=parallel_rpc))
 
     real_experts = sorted({
         'expert.' + '.'.join([str(random.randint(0, dim - 1)) for dim in grid_dims])
         for _ in range(total_experts)
     })
     for batch_start in range(0, len(real_experts), batch_size):
-        random.choice(dht).declare_experts(
-            real_experts[batch_start: batch_start + batch_size], wait=True,
-            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}")
+        declare_experts(random.choice(dht), real_experts[batch_start: batch_start + batch_size], wait=True,
+                        endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}")
 
     neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
-    you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
+    you = hivemind.DHT(start=True, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
     beam_search = MoEBeamSearcher(you, 'expert.', grid_dims)
 
     for i in range(10):
@@ -76,17 +75,17 @@ def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peer
 
 @pytest.mark.forked
 def test_dht_single_node():
-    node = hivemind.DHT(start=True, expiration=999)
+    node = hivemind.DHT(start=True)
     beam_search = MoEBeamSearcher(node, 'expert.', grid_size=(10,))
 
-    assert all(node.declare_experts(['expert.1', 'expert.2', 'expert.3'], f"{hivemind.LOCALHOST}:1337").values())
-    assert len(node.declare_experts(["ffn.1", "ffn.2"], endpoint="that_place")) == 4
-    assert len(node.declare_experts(['e.1.2.3', 'e.1.2.5', 'e.2.0'], f"{hivemind.LOCALHOST}:42")) == 7
+    assert all(declare_experts(node, ['expert.1', 'expert.2', 'expert.3'], f"{hivemind.LOCALHOST}:1337").values())
+    assert len(declare_experts(node, ["ffn.1", "ffn.2"], endpoint="that_place")) == 4
+    assert len(declare_experts(node, ['e.1.2.3', 'e.1.2.5', 'e.2.0'], f"{hivemind.LOCALHOST}:42")) == 7
 
-    for expert in node.get_experts(['expert.3', 'expert.2']):
+    for expert in get_experts(node, ['expert.3', 'expert.2']):
         assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
 
-    assert all(node.declare_experts(['expert.5', 'expert.2'], f"{hivemind.LOCALHOST}:1337").values())
+    assert all(declare_experts(node, ['expert.5', 'expert.2'], f"{hivemind.LOCALHOST}:1337").values())
     found_experts = beam_search.find_best_experts([(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
     assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ['expert.5', 'expert.3']
 

+ 5 - 5
tests/test_moe.py

@@ -15,7 +15,7 @@ def test_moe():
                        for _ in range(10)]
     with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn', num_handlers=1,
                            hidden_dim=16) as (server_endpoint, dht_endpoint):
-        dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+        dht = hivemind.DHT(start=True, initial_peers=[dht_endpoint])
 
         dmoe = hivemind.RemoteMixtureOfExperts(
             in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix='ffn.')
@@ -31,7 +31,7 @@ def test_no_experts():
                        for _ in range(10)]
     with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='nop_delay', num_handlers=1,
                            hidden_dim=16) as (server_endpoint, dht_endpoint):
-        dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+        dht = hivemind.DHT(start=True, initial_peers=[dht_endpoint])
 
         dmoe = hivemind.RemoteSwitchMixtureOfExperts(
             in_features=16, grid_size=(4, 4, 4), dht=dht, uid_prefix='expert.', forward_timeout=0.1,
@@ -119,8 +119,8 @@ def test_remote_module_call(hidden_dim=16):
 @pytest.mark.forked
 def test_beam_search_correctness():
     all_expert_uids = [f'ffn.{5 + i}.{10 + j}.{15 + k}' for i in range(10) for j in range(10) for k in range(10)]
-    dht = hivemind.DHT(start=True, expiration=999)
-    assert all(dht.declare_experts(all_expert_uids, endpoint='fake-endpoint'))
+    dht = hivemind.DHT(start=True)
+    assert all(hivemind.declare_experts(dht, all_expert_uids, endpoint='fake-endpoint'))
 
     dmoe = hivemind.RemoteMixtureOfExperts(
         in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix='ffn.')
@@ -208,7 +208,7 @@ def test_client_anomaly_detection():
 
     experts['expert.3'].expert.ffn.weight.data[0, 0] = float('nan')
 
-    dht = hivemind.DHT(start=True, expiration=999)
+    dht = hivemind.DHT(start=True)
     server = hivemind.Server(dht, experts, num_connection_handlers=1)
     server.start()
     try:

+ 2 - 2
tests/test_training.py

@@ -48,7 +48,7 @@ def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=
     all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
     with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64, num_handlers=1) \
             as (server_endpoint, dht_endpoint):
-        dht = DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+        dht = DHT(start=True, initial_peers=[dht_endpoint])
 
         moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix='expert.', k_best=2)
         model = nn.Sequential(moe, nn.Linear(64, 2))
@@ -91,7 +91,7 @@ def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_expert
     all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
     with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64,
                            num_handlers=1) as (server_endpoint, dht_endpoint):
-        dht = DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+        dht = DHT(start=True, initial_peers=[dht_endpoint])
 
         model = SwitchNetwork(dht, 64, 2, num_experts)
         opt = SGD(model.parameters(), lr=0.05)