
Improve error handling, remove deprecated functionality (#261)

- modified .load_state_from_peers to handle network errors and ensure the MPFuture is always set
- removed dht.get_experts / dht.declare_experts from the library, tests, benchmarks and examples in favor of the module-level hivemind.get_experts / hivemind.declare_experts (usage sketch below)
- removed the default expiration=... parameter from hivemind.DHT
- removed stale TODOs and notifications
- rolled back Runtime.stop to a shutdown pipe -> the server no longer hangs on shutdown (also: repaired benchmark_throughput)

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
commit f0c5627139
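
For orientation (not part of the commit itself): a minimal sketch of the post-change API, assuming a hivemind build that includes this commit. The expert uids and the endpoint below are made-up placeholders.

```python
import hivemind

# hivemind.DHT no longer accepts a default expiration=... argument
dht = hivemind.DHT(start=True)

# declare_experts now takes the expiration (in seconds) explicitly
# instead of reading dht.default_expiration
store_ok = hivemind.declare_experts(dht, ['expert.1', 'expert.2'],
                                    endpoint='127.0.0.1:1337', expiration=300)
assert all(store_ok.values())

# dht.get_experts was removed; use the module-level function instead
expert1, expert2 = hivemind.get_experts(dht, ['expert.1', 'expert.2'])
```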

+ 3 - 2
benchmarks/benchmark_dht.py

@@ -25,7 +25,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     for _ in trange(num_peers):
         neighbors = [f'0.0.0.0:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
         peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout,
-                            expiration=expiration, listen_on=f'0.0.0.0:*')
+                            listen_on=f'0.0.0.0:*')
         peers.append(peer)
 
     store_peer, get_peer = peers[-2:]
@@ -43,7 +43,8 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     for start in trange(0, num_experts, expert_batch_size):
         store_start = time.perf_counter()
         endpoints.append(random_endpoint())
-        store_ok = hivemind.declare_experts(store_peer, expert_uids[start: start + expert_batch_size], endpoints[-1])
+        store_ok = hivemind.declare_experts(store_peer, expert_uids[start: start + expert_batch_size], endpoints[-1],
+                                            expiration=expiration)
         successes = store_ok.values()
         total_store_time += time.perf_counter() - store_start
 

+ 1 - 1
docs/user/quickstart.md

@@ -154,7 +154,7 @@ dht = hivemind.DHT(initial_peers=["localhost:1338"], listen=False, start=True)
 # note: listen=False means that your peer will operate in "client only" mode: 
 # this means that it can request other peers, but will not accept requests in return 
 
-expert1, expert4 = dht.get_experts(["expert.1", "expert.4"])
+expert1, expert4 = hivemind.get_experts(dht, ["expert.1", "expert.4"])
 assert expert1 is not None and expert4 is not None, "server hasn't declared experts (yet?)"
 ```
 

+ 43 - 40
hivemind/client/averaging/__init__.py

@@ -183,7 +183,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     self._port.value = found_port
                     await server.start()
                 else:
-                    logger.info(f"The averager running in an experimental client mode, please report any bugs.")
+                    logger.debug(f"The averager is running in client mode.")
 
                 self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
                                                 client_mode=not self.listen)
@@ -422,47 +422,50 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         return future.result() if wait else future
 
     async def _load_state_from_peers(self, future: MPFuture):
-        key_manager = self._matchmaking.group_key_manager
-        peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
-        peer_priority = {peer: float(info.value) for peer, info in peer_priority.items()
-                         if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))}
-
-        if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
-            logger.info(f"Averager could not load state from peers: peer dict is absent or corrupted {peer_priority}.")
-            future.set_result(None)
-            return
-
-        metadata = None
-        for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
-            if peer != self.endpoint:
-                logger.info(f"Downloading parameters from peer {peer}")
-                stream = None
-                try:
-                    leader_stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
-                    stream = leader_stub.rpc_download_state(averaging_pb2.DownloadRequest())
-                    current_tensor_parts, tensors = [], []
-                    async for message in stream:
-                        if message.metadata:
-                            metadata = self.serializer.loads(message.metadata)
-                        if message.tensor_part.dtype and current_tensor_parts:
-                            # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
+        try:
+            key_manager = self._matchmaking.group_key_manager
+            peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
+            peer_priority = {peer: float(info.value) for peer, info in peer_priority.items()
+                             if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))}
+
+            if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
+                logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}.")
+                future.set_result(None)
+                return
+
+            metadata = None
+            for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
+                if peer != self.endpoint:
+                    logger.info(f"Downloading parameters from peer {peer}")
+                    stream = None
+                    try:
+                        stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+                        stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
+                        current_tensor_parts, tensors = [], []
+                        async for message in stream:
+                            if message.metadata:
+                                metadata = self.serializer.loads(message.metadata)
+                            if message.tensor_part.dtype and current_tensor_parts:
+                                # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
+                                tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
+                                current_tensor_parts = []
+                            current_tensor_parts.append(message.tensor_part)
+                        if current_tensor_parts:
                             tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
-                            current_tensor_parts = []
-                        current_tensor_parts.append(message.tensor_part)
-                    if current_tensor_parts:
-                        tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
-                    future.set_result((metadata, tensors))
-                    self.last_updated = get_dht_time()
-                    return
-                except grpc.aio.AioRpcError as e:
-                    logger.info(f"Failed to download state from {peer} - {e}")
-                finally:
-                    if stream is not None:
-                        await stream.code()
+                        logger.info(f"Finished downloading state from {peer}")
+                        future.set_result((metadata, tensors))
+                        self.last_updated = get_dht_time()
+                        return
+                    except BaseException as e:
+                        logger.exception(f"Failed to download state from {peer} - {repr(e)}")
+                    finally:
+                        if stream is not None:
+                            await stream.code()
 
 
-        else:
-            logger.warning("Averager could not load state from peers: found no active peers.")
-            future.set_result(None)
+        finally:
+            if not future.done():
+                logger.warning("Averager could not load state from peers: all requests have failed.")
+                future.set_result(None)
 
 
     def get_group_bits(self, wait: bool = True):
         """

+ 29 - 21
hivemind/client/beam_search.py

@@ -4,7 +4,7 @@ from collections import deque
 from functools import partial
 from typing import Sequence, Optional, List, Tuple, Dict, Deque, Union, Set, Iterator
 
-from hivemind.dht import DHT, DHTNode
+from hivemind.dht import DHT, DHTNode, DHTExpiration
 from hivemind.client.expert import RemoteExpert
 from hivemind.server.expert_uid import (ExpertUID, ExpertPrefix, FLAT_EXPERT, UidEndpoint, Score, Coordinate,
                                         PREFIX_PATTERN, UID_DELIMITER, is_valid_prefix)
@@ -22,7 +22,7 @@ class MoEBeamSearcher:
         * optional prefix that determines expert role, experiment name, etc.
         * one or more integers that determine that expert's position in an N-dimensional grid
 
-    A hivemind.Server can ``DHT.declare_experts(expert_uids: List[str])`` to make its experts visible to everyone.
+    A hivemind.Server can ``declare_experts(dht, expert_uids: List[str])`` to make its experts visible to everyone.
     When declaring experts, DHT will store each expert's uid and all its prefixes until :expiration: (specified at init)
     For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
     ``"ffn_expert.98", "ffn_expert.98.76", "ffn_expert.98.76.54", ..., "ffn_expert.98.76.54.32.10"``
@@ -63,8 +63,8 @@ class MoEBeamSearcher:
          Though, this is a pathological case (e.g. only 90 experts in an oversized 100x100 grid) that should be avoided.
     """
 
-    def __init__(self, dht: DHT, uid_prefix: ExpertPrefix, grid_size: Tuple[int, ...],
-                 num_workers: Optional[int] = None, negative_caching: bool = True, **kwargs):
+    def __init__(self, dht: DHT, uid_prefix: ExpertPrefix, grid_size: Sequence[int], num_workers: Optional[int] = None,
+                 negative_caching: bool = True, cache_expiration: DHTExpiration = 300, **kwargs):
         if not uid_prefix.endswith(UID_DELIMITER):
             uid_prefix += UID_DELIMITER
             logger.info(f"Prefix must end with '{UID_DELIMITER}'. Changing to {uid_prefix}{UID_DELIMITER}")
@@ -72,7 +72,8 @@ class MoEBeamSearcher:
         self.dht = dht
         self.uid_prefix, self.grid_size = uid_prefix, grid_size
         self.total_grid_size = sum(grid_size)
-        self.negative_caching, self.num_workers, self.dht_kwargs = negative_caching, num_workers, kwargs
+        self.negative_caching, self.cache_expiration = negative_caching, cache_expiration
+        self.num_workers, self.dht_kwargs = num_workers, kwargs
 
     def get_initial_beam(self, scores: Sequence[float], beam_size: int, return_future: bool = False
                          ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
@@ -84,12 +85,14 @@ class MoEBeamSearcher:
         """
         """
         return self.dht.run_coroutine(partial(self._get_initial_beam, prefix=self.uid_prefix, beam_size=beam_size,
         return self.dht.run_coroutine(partial(self._get_initial_beam, prefix=self.uid_prefix, beam_size=beam_size,
                                               scores=tuple(scores), negative_caching=self.negative_caching,
                                               scores=tuple(scores), negative_caching=self.negative_caching,
-                                              num_workers=self.num_workers), return_future)
+                                              cache_expiration=self.cache_expiration, num_workers=self.num_workers),
+                                      return_future)
 
 
     @staticmethod
     @staticmethod
-    async def _get_initial_beam(dht: DHT, node: DHTNode, prefix: ExpertPrefix, beam_size: int,
-                                scores: Tuple[float, ...], negative_caching: bool, num_workers: Optional[int] = None
-                                ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+    async def _get_initial_beam(
+            dht: DHT, node: DHTNode, prefix: ExpertPrefix, beam_size: int, scores: Tuple[float, ...],
+            negative_caching: bool, cache_expiration: DHTExpiration, num_workers: Optional[int] = None,
+    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
         num_workers = num_workers or dht.max_workers or beam_size
         num_workers = num_workers or dht.max_workers or beam_size
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
         unattempted_indices: List[Coordinate] = sorted(range(len(scores)), key=scores.__getitem__)  # from worst to best
         unattempted_indices: List[Coordinate] = sorted(range(len(scores)), key=scores.__getitem__)  # from worst to best
@@ -115,7 +118,7 @@ class MoEBeamSearcher:
                 elif maybe_prefix_data is None and negative_caching:
                     logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
                     asyncio.create_task(node.store(pending_best_prefix, subkey=-1, value=None,
-                                                   expiration_time=get_dht_time() + dht.default_expiration))
+                                                   expiration_time=get_dht_time() + cache_expiration))
 
             except asyncio.CancelledError:
                 for _, pending_task in pending_tasks:
@@ -137,12 +140,14 @@ class MoEBeamSearcher:
             assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
         return self.dht.run_coroutine(partial(
             self._get_active_successors, prefixes=list(prefixes), grid_size=grid_size,
-            negative_caching=self.negative_caching, num_workers=self.num_workers), return_future=return_future)
+            negative_caching=self.negative_caching, cache_expiration=self.cache_expiration,
+            num_workers=self.num_workers), return_future=return_future)
 
     @staticmethod
-    async def _get_active_successors(dht: DHT, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int],
-                                     negative_caching: bool, num_workers: Optional[int] = None
-                                     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+    async def _get_active_successors(
+            dht: DHT, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int],
+            negative_caching: bool, cache_expiration: DHTExpiration, num_workers: Optional[int] = None
+    ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
         grid_size = grid_size or float('inf')
         num_workers = num_workers or min(len(prefixes), dht.max_workers or len(prefixes))
         dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
@@ -157,7 +162,7 @@ class MoEBeamSearcher:
                 if found is None and negative_caching:
                     logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
                     asyncio.create_task(node.store(prefix, subkey=-1, value=None,
-                                                   expiration_time=get_dht_time() + dht.default_expiration))
+                                                   expiration_time=get_dht_time() + cache_expiration))
         return successors
 
     def find_best_experts(self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
@@ -176,14 +181,16 @@ class MoEBeamSearcher:
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
         assert len(grid_scores) == len(self.grid_size) and beam_size > 0
-        return self.dht.run_coroutine(partial(self._find_best_experts, prefix=self.uid_prefix, beam_size=beam_size,
-                                              grid_scores=list(grid_scores), negative_caching=self.negative_caching,
-                                              num_workers=self.num_workers), return_future)
+        return self.dht.run_coroutine(partial(
+            self._find_best_experts, prefix=self.uid_prefix, beam_size=beam_size, grid_scores=list(grid_scores),
+            negative_caching=self.negative_caching, cache_expiration=self.cache_expiration,
+            num_workers=self.num_workers), return_future)
 
     @classmethod
     async def _find_best_experts(
             cls, dht: DHT, node: DHTNode, prefix: str, grid_scores: List[Tuple[float]], beam_size: int,
-            negative_caching: bool, num_workers: Optional[int] = None) -> List[RemoteExpert]:
+            negative_caching: bool, cache_expiration: DHTExpiration, num_workers: Optional[int] = None
+    ) -> List[RemoteExpert]:
         num_workers = num_workers or min(beam_size, dht.max_workers or beam_size)
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
@@ -209,8 +216,9 @@ class MoEBeamSearcher:
             _, best_uid_prefixes = zip(*best_active_pairs)
 
             # search DHT for next step suffixes
-            successors = await cls._get_active_successors(dht, node, best_uid_prefixes, grid_size=None,
-                                                          negative_caching=negative_caching, num_workers=num_workers)
+            successors = await cls._get_active_successors(
+                dht, node, best_uid_prefixes, grid_size=None, negative_caching=negative_caching,
+                cache_expiration=cache_expiration, num_workers=num_workers)
             beam = [(score, prefix, successors[prefix]) for score, prefix in best_active_pairs if successors[prefix]]
             if not beam:
                 logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim 0)")

+ 1 - 12
hivemind/dht/__init__.py

@@ -51,13 +51,11 @@ class DHT(mp.Process):
 
 
     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
                  daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
-                 expiration: float = 300, record_validators: Iterable[RecordValidatorBase] = (),
-                 **kwargs):
+                 record_validators: Iterable[RecordValidatorBase] = (), **kwargs):
         super().__init__()
         assert not isinstance(initial_peers, str), "please specify a list/tuple of initial peers (even if there's one)"
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
         self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
-        self.default_expiration = expiration
         self._record_validator = CompositeValidator(record_validators)
         self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
         self._pipe, self.pipe = mp.Pipe(duplex=True)
@@ -257,12 +255,3 @@ class DHT(mp.Process):
         else:
             future.set_exception(ValueError(f"Can't get address: DHT node has no peers and no public endpoint."
                                             f" Please ensure the node is connected or specify peers=... manually."))
-
-    def declare_experts(self, uids, endpoint, wait: bool = True):
-        logger.warning("dht.declare_experts is scheduled for removal in 0.9.8, please use hivemind.declare_experts.")
-        return hivemind.declare_experts(self, uids, endpoint, wait=wait)
-
-    def get_experts(self, uids, expiration_time: Optional[DHTExpiration] = None,
-                    return_future: bool = False) -> List[Optional[RemoteExpert]]:
-        logger.warning("dht.get_experts is scheduled for removal in 0.9.8, please use hivemind.get_experts.")
-        return hivemind.get_experts(self, uids, expiration_time, return_future)

+ 3 - 2
hivemind/server/__init__.py

@@ -67,7 +67,7 @@ class Server(threading.Thread):
 
 
         if self.dht and self.experts:
             self.dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht, endpoint=self.listen_on,
-                                                       update_period=self.update_period)
+                                                       update_period=self.update_period, daemon=True)
 
         if start:
             self.run_in_background(await_ready=True)
@@ -261,7 +261,8 @@ class Server(threading.Thread):
             self.dht.join()
 
         logger.debug(f"Shutting down runtime")
-        self.runtime.stop.set()
+
+        self.runtime.shutdown()
         logger.info("Server shutdown succesfully")
 
 
 
 

+ 9 - 7
hivemind/server/dht_handler.py

@@ -10,8 +10,8 @@ from hivemind.utils import Endpoint, get_dht_time, get_port
 
 
 
 
 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, endpoint: Endpoint, update_period: int = 5):
-        super().__init__()
+    def __init__(self, experts, dht: DHT, endpoint: Endpoint, update_period: int = 5, **kwargs):
+        super().__init__(**kwargs)
         assert get_port(endpoint) is not None
         self.endpoint = endpoint
         self.experts = experts
@@ -25,7 +25,7 @@ class DHTHandlerThread(threading.Thread):
             declare_experts(self.dht, self.experts.keys(), self.endpoint)
 
 
-def declare_experts(dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint,
+def declare_experts(dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration = 300,
                     wait: bool = True) -> Dict[ExpertUID, bool]:
     """
     Make experts visible to all DHT peers; update timestamps if declared previously.
@@ -33,18 +33,20 @@ def declare_experts(dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint,
     :param uids: a list of expert ids to update
     :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
     :param wait: if True, awaits for declaration to finish, otherwise runs in background
-    :param timeout: waits for the procedure to finish for up to this long, None means wait indefinitely
+    :param expiration: experts will be visible for this many seconds
     :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     for uid in uids:
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
-    return dht.run_coroutine(partial(_declare_experts, uids=list(uids), endpoint=endpoint), return_future=not wait)
+    return dht.run_coroutine(partial(_declare_experts, uids=list(uids), endpoint=endpoint, expiration=expiration),
+                             return_future=not wait)
 
 
-async def _declare_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint) -> Dict[ExpertUID, bool]:
+async def _declare_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint,
+                           expiration: DHTExpiration) -> Dict[ExpertUID, bool]:
     num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
-    expiration_time = get_dht_time() + dht.default_expiration  # TODO use local expiration
+    expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     for uid in uids:
         data_to_store[uid, None] = endpoint

+ 2 - 1
hivemind/server/expert_uid.py

@@ -2,6 +2,7 @@ import random
 import re
 from typing import NamedTuple, Union, Tuple, Optional, List
 
+import hivemind
 from hivemind.dht import DHT
 from hivemind.utils import Endpoint, get_logger
 
 
@@ -81,7 +82,7 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
 
 
         # 2. look into DHT (if given) and remove duplicates
         if dht:
-            existing_expert_uids = {found_expert.uid for found_expert in dht.get_experts(new_uids)
+            existing_expert_uids = {found_expert.uid for found_expert in hivemind.get_experts(dht, new_uids)
                                     if found_expert is not None}
             new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
 
 

+ 14 - 5
hivemind/server/runtime.py

@@ -41,6 +41,7 @@ class Runtime(threading.Thread):
 
 
     :param stats_report_interval: interval to collect and log statistics about runtime performance
     """
+    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
 
     def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1,
                  device: torch.device = None, stats_report_interval: Optional[int] = None):
@@ -48,8 +49,9 @@ class Runtime(threading.Thread):
         self.expert_backends = expert_backends
         self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
         self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
+        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
+        self.shutdown_trigger = mp.Event()
         self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
-        self.stop = threading.Event()
 
         self.stats_report_interval = stats_report_interval
         if self.stats_report_interval is not None:
@@ -86,18 +88,18 @@ class Runtime(threading.Thread):
 
 
                     output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
             finally:
-                self.shutdown()
+                if not self.shutdown_trigger.is_set():
+                    self.shutdown()
 
     def shutdown(self):
         """ Gracefully terminate a running runtime. """
         logger.info("Shutting down")
+        self.ready.clear()
 
         if self.stats_report_interval is not None:
             self.stats_reporter.stop.set()
             self.stats_reporter.join()
 
-        self.stop.set()  # trigger background thread to shutdown
-
         logger.debug("Terminating pools")
         for pool in self.pools:
             if pool.is_alive():
@@ -105,6 +107,10 @@ class Runtime(threading.Thread):
                 pool.join()
         logger.debug("Pools terminated")
 
+        # trigger background thread to shutdown
+        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
+        self.shutdown_trigger.set()
+
     def iterate_minibatches_from_pools(self, timeout=None):
         """
         Chooses pool according to priority, then copies exposed batch and frees the buffer
@@ -112,12 +118,15 @@ class Runtime(threading.Thread):
         with DefaultSelector() as selector:
             for pool in self.pools:
                 selector.register(pool.batch_receiver, EVENT_READ, pool)
+            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
 
 
-            while not self.stop.is_set():
+            while True:
                 # wait until at least one batch_receiver becomes available
                 logger.debug("Waiting for inputs from task pools")
                 ready_fds = selector.select()
                 ready_objects = {key.data for (key, events) in ready_fds}
+                if self.SHUTDOWN_TRIGGER in ready_objects:
+                    break  # someone asked us to shutdown, break from the loop
 
                 logger.debug("Choosing the pool with highest priority")
                 pool = max(ready_objects, key=lambda pool: pool.priority)
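
For context (not part of the commit; all names below are placeholders): a minimal standalone sketch of the shutdown-pipe pattern that Runtime uses above. One selector multiplexes the task pipes and a dedicated shutdown pipe, so a single send() wakes the blocked loop and ends it.

```python
import multiprocessing as mp
import threading
import time
from selectors import DefaultSelector, EVENT_READ

SHUTDOWN_TRIGGER = "SHUTDOWN"
task_recv, task_send = mp.Pipe(duplex=False)          # stand-in for a pool's batch_receiver
shutdown_recv, shutdown_send = mp.Pipe(duplex=False)  # written once, when shutting down

def worker_loop():
    # works on Unix, where Connection.fileno() can be registered with a selector
    with DefaultSelector() as selector:
        selector.register(task_recv, EVENT_READ, "task_pool")
        selector.register(shutdown_recv, EVENT_READ, SHUTDOWN_TRIGGER)
        while True:
            ready = {key.data for key, _ in selector.select()}
            if SHUTDOWN_TRIGGER in ready:
                break  # woken up by shutdown_send.send(...)
            print("processing", task_recv.recv())

worker = threading.Thread(target=worker_loop, daemon=True)
worker.start()
task_send.send("batch_0")
time.sleep(0.1)                       # give the loop a moment to pick up the batch
shutdown_send.send(SHUTDOWN_TRIGGER)  # unblocks selector.select() and stops the loop
worker.join()
```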

+ 1 - 1
hivemind/server/task_pool.py

@@ -136,7 +136,7 @@ class TaskPool(TaskPoolBase):
         pending_batches = {}  # Dict[batch uuid, List[MPFuture]] for each batch currently in runtime
 
         output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
-                                         name=f'{self.name}_output')
+                                         name=f'{self.name}_output', daemon=True)
 
         try:
             output_thread.start()

+ 4 - 1
hivemind/utils/timed_storage.py

@@ -9,10 +9,11 @@ from dataclasses import dataclass
 KeyType = TypeVar('KeyType')
 ValueType = TypeVar('ValueType')
 get_dht_time = time.time  # a global (weakly synchronized) time
-MAX_DHT_TIME_DISCREPANCY_SECONDS = 3  # max allowed difference between get_dht_time for two DHT nodes. Enforced when joining DHT.(TODO)
+MAX_DHT_TIME_DISCREPANCY_SECONDS = 3  # max allowed difference between get_dht_time for two DHT nodes
 DHTExpiration = float
 ROOT = 0
 
+
 @dataclass(init=True, repr=True, frozen=True)
 class ValueWithExpiration(Generic[ValueType]):
     value: ValueType
@@ -37,11 +38,13 @@ class ValueWithExpiration(Generic[ValueType]):
         else:
             return False
 
+
 @dataclass(init=True, repr=True, order=True, frozen=True)
 class HeapEntry(Generic[KeyType]):
     expiration_time: DHTExpiration
     key: KeyType
 
+
 class TimedStorage(Generic[KeyType, ValueType]):
     """ A dictionary that maintains up to :maxsize: key-value-expiration tuples until their expiration_time """
     frozen = False  # can be set to True. If true, do not remove outdated elements

+ 13 - 14
tests/test_dht_experts.py

@@ -6,7 +6,7 @@ import pytest
 
 
 import hivemind
 import hivemind.server.expert_uid
-from hivemind import LOCALHOST
+from hivemind import LOCALHOST, declare_experts, get_experts
 from hivemind.client.beam_search import MoEBeamSearcher
 from hivemind.server.expert_uid import UidEndpoint, is_valid_uid, is_valid_prefix, split_uid
 
 
@@ -26,13 +26,13 @@ def test_store_get_experts():
     for batch_start in range(0, len(expert_uids), batch_size):
         hivemind.declare_experts(first_peer, expert_uids[batch_start: batch_start + batch_size], 'localhost:1234')
 
-    found = other_peer.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
+    found = get_experts(other_peer, random.sample(expert_uids, 5) + ['foo', 'bar'])
     assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
     assert all(res is None for res in found[-2:]), "Found non-existing experts"
 
     other_expert, other_port = "my_other_expert.1337", random.randint(1000, 9999)
     hivemind.declare_experts(other_peer, [other_expert], f'that_host:{other_port}')
-    first_notfound, first_found = hivemind.get_experts(first_peer, ['foobar', other_expert])
+    first_notfound, first_found = get_experts(first_peer, ['foobar', other_expert])
     assert isinstance(first_found, hivemind.RemoteExpert)
     assert first_found.endpoint == f'that_host:{other_port}'
 
 
@@ -46,19 +46,18 @@ def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peer
     dht = []
     for i in range(dht_size):
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
-        dht.append(hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc))
+        dht.append(hivemind.DHT(start=True, initial_peers=neighbors_i, parallel_rpc=parallel_rpc))
 
     real_experts = sorted({
         'expert.' + '.'.join([str(random.randint(0, dim - 1)) for dim in grid_dims])
         for _ in range(total_experts)
     })
     for batch_start in range(0, len(real_experts), batch_size):
-        random.choice(dht).declare_experts(
-            real_experts[batch_start: batch_start + batch_size], wait=True,
-            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}")
+        declare_experts(random.choice(dht), real_experts[batch_start: batch_start + batch_size], wait=True,
+                        endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}")
 
     neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
-    you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
+    you = hivemind.DHT(start=True, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
     beam_search = MoEBeamSearcher(you, 'expert.', grid_dims)
 
     for i in range(10):
@@ -76,17 +75,17 @@ def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peer
 
 
 @pytest.mark.forked
 def test_dht_single_node():
-    node = hivemind.DHT(start=True, expiration=999)
+    node = hivemind.DHT(start=True)
     beam_search = MoEBeamSearcher(node, 'expert.', grid_size=(10,))
 
-    assert all(node.declare_experts(['expert.1', 'expert.2', 'expert.3'], f"{hivemind.LOCALHOST}:1337").values())
-    assert len(node.declare_experts(["ffn.1", "ffn.2"], endpoint="that_place")) == 4
-    assert len(node.declare_experts(['e.1.2.3', 'e.1.2.5', 'e.2.0'], f"{hivemind.LOCALHOST}:42")) == 7
+    assert all(declare_experts(node, ['expert.1', 'expert.2', 'expert.3'], f"{hivemind.LOCALHOST}:1337").values())
+    assert len(declare_experts(node, ["ffn.1", "ffn.2"], endpoint="that_place")) == 4
+    assert len(declare_experts(node, ['e.1.2.3', 'e.1.2.5', 'e.2.0'], f"{hivemind.LOCALHOST}:42")) == 7
 
-    for expert in node.get_experts(['expert.3', 'expert.2']):
+    for expert in get_experts(node, ['expert.3', 'expert.2']):
         assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
 
-    assert all(node.declare_experts(['expert.5', 'expert.2'], f"{hivemind.LOCALHOST}:1337").values())
+    assert all(declare_experts(node, ['expert.5', 'expert.2'], f"{hivemind.LOCALHOST}:1337").values())
     found_experts = beam_search.find_best_experts([(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
     assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ['expert.5', 'expert.3']
 
 

+ 5 - 5
tests/test_moe.py

@@ -15,7 +15,7 @@ def test_moe():
                        for _ in range(10)]
     with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn', num_handlers=1,
                            hidden_dim=16) as (server_endpoint, dht_endpoint):
-        dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+        dht = hivemind.DHT(start=True, initial_peers=[dht_endpoint])
 
         dmoe = hivemind.RemoteMixtureOfExperts(
             in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix='ffn.')
@@ -31,7 +31,7 @@ def test_no_experts():
                        for _ in range(10)]
     with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='nop_delay', num_handlers=1,
                            hidden_dim=16) as (server_endpoint, dht_endpoint):
-        dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+        dht = hivemind.DHT(start=True, initial_peers=[dht_endpoint])
 
         dmoe = hivemind.RemoteSwitchMixtureOfExperts(
             in_features=16, grid_size=(4, 4, 4), dht=dht, uid_prefix='expert.', forward_timeout=0.1,
@@ -119,8 +119,8 @@ def test_remote_module_call(hidden_dim=16):
 @pytest.mark.forked
 def test_beam_search_correctness():
     all_expert_uids = [f'ffn.{5 + i}.{10 + j}.{15 + k}' for i in range(10) for j in range(10) for k in range(10)]
-    dht = hivemind.DHT(start=True, expiration=999)
-    assert all(dht.declare_experts(all_expert_uids, endpoint='fake-endpoint'))
+    dht = hivemind.DHT(start=True)
+    assert all(hivemind.declare_experts(dht, all_expert_uids, endpoint='fake-endpoint'))
 
     dmoe = hivemind.RemoteMixtureOfExperts(
         in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix='ffn.')
@@ -208,7 +208,7 @@ def test_client_anomaly_detection():
 
 
     experts['expert.3'].expert.ffn.weight.data[0, 0] = float('nan')
 
-    dht = hivemind.DHT(start=True, expiration=999)
+    dht = hivemind.DHT(start=True)
     server = hivemind.Server(dht, experts, num_connection_handlers=1)
     server.start()
     try:

+ 2 - 2
tests/test_training.py

@@ -48,7 +48,7 @@ def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=
     all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
     with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64, num_handlers=1) \
             as (server_endpoint, dht_endpoint):
-        dht = DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+        dht = DHT(start=True, initial_peers=[dht_endpoint])
 
         moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix='expert.', k_best=2)
         model = nn.Sequential(moe, nn.Linear(64, 2))
@@ -91,7 +91,7 @@ def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_expert
     all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
     with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64,
                            num_handlers=1) as (server_endpoint, dht_endpoint):
-        dht = DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+        dht = DHT(start=True, initial_peers=[dht_endpoint])
 
         model = SwitchNetwork(dht, 64, 2, num_experts)
         opt = SGD(model.parameters(), lr=0.05)