
Logging + got rid of batch_received (#52)

* Fix empty list of neighbors

* Add basic logger creation

* Introduce better logging, adjust tests a bit

* Add empty line

* Remove aiologger from requirements

* Revert test_dht

* Update tests

* More debug logging for backend
Max Ryabinin, 5 years ago
parent commit fe68aa1050

+ 1 - 1
hivemind/__init__.py

@@ -1,6 +1,6 @@
 from .client import *
 from .dht import *
-from .server import *
+from .server import Server
 from .utils import *
 from .runtime import *
 

+ 23 - 10
hivemind/dht/protocol.py

@@ -1,23 +1,30 @@
 from __future__ import annotations
-import os
-import heapq
+
 import asyncio
-import logging
+import heapq
+import os
 import urllib.parse
 from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union
 from warnings import warn
+
+import grpc
+import grpc.experimental.aio
+
 from .routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
-from ..utils import Endpoint, compile_grpc
-import grpc, grpc.experimental.aio
+from ..utils import Endpoint, compile_grpc, get_logger
+
+logger = get_logger(__name__)
 
 with open(os.path.join(os.path.dirname(__file__), 'dht.proto'), 'r') as f_proto:
     dht_pb2, dht_grpc = compile_grpc(f_proto.read())
 
 
 class DHTProtocol(dht_grpc.DHTServicer):
+    # fmt:off
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
     channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.experimental.aio.Server
     storage: LocalStorage; cache: LocalStorage; routing_table: RoutingTable
+    # fmt:on
 
     @classmethod
     async def create(cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
@@ -87,7 +94,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         try:
             peer_info = await self._get(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
         except grpc.experimental.aio.AioRpcError as error:
-            logging.info(f"DHTProtocol failed to ping {peer}: {error.code()}")
+            logger.warning(f"DHTProtocol failed to ping {peer}: {error.code()}")
             peer_info = None
         responded = bool(peer_info and peer_info.node_id)
         peer_id = DHTID.from_bytes(peer_info.node_id) if responded else None
@@ -134,7 +141,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
             return response.store_ok
         except grpc.experimental.aio.AioRpcError as error:
-            logging.info(f"DHTProtocol failed to store at {peer}: {error.code()}")
+            logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get_id(peer), peer, responded=False))
             return [False] * len(keys)
 
@@ -180,7 +187,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 output[key] = (value, expiration, nearest)
             return output
         except grpc.experimental.aio.AioRpcError as error:
-            logging.info(f"DHTProtocol failed to store at {peer}: {error.code()}")
+            logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get_id(peer), peer, responded=False))
 
     async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
@@ -197,8 +204,13 @@ class DHTProtocol(dht_grpc.DHTServicer):
             cached_value, cached_expiration = self.cache.get(key_id)
             if (cached_expiration or -float('inf')) > (maybe_expiration or -float('inf')):
                 maybe_value, maybe_expiration = cached_value, cached_expiration
-            peer_ids, endpoints = zip(*self.routing_table.get_nearest_neighbors(
-                key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)))
+
+            nearest_neighbors = self.routing_table.get_nearest_neighbors(
+                key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id))
+            if nearest_neighbors:
+                peer_ids, endpoints = zip(*nearest_neighbors)
+            else:
+                peer_ids, endpoints = [], []
 
             response.values.append(maybe_value if maybe_value is not None else _NOT_FOUND_VALUE)
             response.expiration.append(maybe_expiration if maybe_expiration is not None else _NOT_FOUND_EXPIRATION)
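
The hunk above is the "Fix empty list of neighbors" change: when the routing table knows no peers other than the requester, get_nearest_neighbors returns an empty list and the old zip(*...) unpacking fails. A minimal Python sketch of the failure mode and of the guard now used (illustration only, not code from the repository):

neighbors = []  # what get_nearest_neighbors returns when no other peers are known
try:
    peer_ids, endpoints = zip(*neighbors)  # old code path
except ValueError as error:
    print(error)  # not enough values to unpack (expected 2, got 0)

# guarded version, mirroring the new code above
peer_ids, endpoints = zip(*neighbors) if neighbors else ([], [])
print(peer_ids, endpoints)  # [] []
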
@@ -247,6 +259,7 @@ _NOT_FOUND_VALUE, _NOT_FOUND_EXPIRATION = b'', -float('inf')  # internal values
 
 class LocalStorage:
     """ Local dictionary that maintains up to :maxsize: tuples of (key, value, expiration) """
+
     def __init__(self, maxsize: Optional[int] = None):
         self.cache_size = maxsize or float("inf")
         self.data = dict()

+ 12 - 3
hivemind/runtime/__init__.py

@@ -10,6 +10,9 @@ from prefetch_generator import BackgroundGenerator
 
 from .expert_backend import ExpertBackend
 from .task_pool import TaskPool, TaskPoolBase
+from hivemind.utils import get_logger
+
+logger = get_logger(__name__)
 
 
 class Runtime(threading.Thread):
@@ -34,6 +37,7 @@ class Runtime(threading.Thread):
     :param device: if specified, moves all experts and data to this device via .to(device=device).
       If you want to manually specify devices for each expert (in their forward pass), leave device=None (default)
     """
+
     def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1,
                  device: torch.device = None):
         super().__init__()
@@ -44,7 +48,6 @@ class Runtime(threading.Thread):
         self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
 
     def run(self):
-        progress = tqdm.tqdm(bar_format='{desc}, {rate_fmt}')
         for pool in self.pools:
             if not pool.is_alive():
                 pool.start()
@@ -55,13 +58,15 @@ class Runtime(threading.Thread):
         with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
             try:
                 self.ready.set()
+                logger.info("Started")
                 for pool, batch_index, batch in BackgroundGenerator(
                         self.iterate_minibatches_from_pools(), self.prefetch_batches):
+                    logger.debug(f"Processing batch {batch_index} from pool {pool.uid}")
                     outputs = pool.process_func(*batch)
+                    logger.info(f"Pool {pool.uid}: batch {batch_index} processed, size {outputs[0].size(0)}")
                     output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
-                    progress.update(len(outputs[0]))
-                    progress.desc = f'pool.uid={pool.uid} batch_size={len(outputs[0])}'
             finally:
+                logger.info("Shutting down")
                 self.shutdown()
 
     SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
@@ -85,12 +90,16 @@ class Runtime(threading.Thread):
 
             while True:
                 # wait until at least one batch_receiver becomes available
+                logger.debug("Waiting for inputs from task pools")
                 ready_fds = selector.select()
                 ready_objects = {key.data for (key, events) in ready_fds}
                 if self.SHUTDOWN_TRIGGER in ready_objects:
                     break  # someone asked us to shutdown, break from the loop
 
+                logger.debug("Choosing the pool with highest priority")
                 pool = max(ready_objects, key=lambda pool: pool.priority)
 
+                logger.debug(f"Loading batch from {pool.uid}")
                 batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
+                logger.debug(f"Loaded batch from {pool.uid}")
                 yield pool, batch_index, batch_tensors
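
For context on the debug messages added above: iterate_minibatches_from_pools registers each pool's batch_receiver pipe with the pool object attached as data, so selector.select() returns exactly the pools that already have a batch waiting, and the runtime then takes the one with the highest priority. A self-contained toy sketch of that pattern (ToyPool and the uids are invented for illustration; it polls multiprocessing pipes, so POSIX only):

import multiprocessing as mp
import selectors

class ToyPool:
    def __init__(self, uid: str, priority: float):
        self.uid, self.priority = uid, priority
        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)

pools = [ToyPool("expert.0", priority=1.0), ToyPool("expert.1", priority=5.0)]
selector = selectors.DefaultSelector()
for pool in pools:
    # the pool itself rides along as `data` and comes back from select()
    selector.register(pool.batch_receiver, selectors.EVENT_READ, data=pool)

for pool in pools:
    pool.batch_sender.send((0, ["dummy batch"]))  # both pools now have input ready

ready_pools = {key.data for key, events in selector.select()}
chosen = max(ready_pools, key=lambda pool: pool.priority)
print("highest-priority ready pool:", chosen.uid)  # expert.1
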

+ 14 - 13
hivemind/runtime/task_pool.py

@@ -14,8 +14,9 @@ from typing import List, Tuple, Dict, Any, Generator
 
 import torch
 
-from ..utils import SharedFuture
+from hivemind.utils import SharedFuture, get_logger
 
+logger = get_logger(__name__)
 Task = namedtuple("Task", ("future", "args"))
 
 
@@ -78,7 +79,6 @@ class TaskPool(TaskPoolBase):
 
         # interaction with Runtime
         self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain batch inputs
-        self.batch_received = mp.Event()  # runtime can notify pool that it can send next batch
         self.outputs_receiver, self.outputs_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain outputs
 
         if start:
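
This hunk and the related ones further down implement the "got rid of batch_received" part of the commit: the mp.Event that acknowledged every batch is removed, leaving the duplex=False pipe as the only handoff between pool and runtime. A toy sketch, not taken from the repository, of why a one-way pipe by itself is enough to deliver batches in order:

import multiprocessing as mp

def pool_input_loop(batch_sender):
    # producer side: just send batches; no separate acknowledgement event is needed
    for batch_index in range(3):
        batch_sender.send((batch_index, [f"tensor_{batch_index}"]))
    batch_sender.send(None)  # sentinel so the toy consumer can stop

if __name__ == "__main__":
    batch_receiver, batch_sender = mp.Pipe(duplex=False)
    producer = mp.Process(target=pool_input_loop, args=(batch_sender,))
    producer.start()
    while (payload := batch_receiver.recv()) is not None:
        batch_index, batch_inputs = payload  # arrives in send() order
        print("runtime received batch", batch_index, batch_inputs)
    producer.join()
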
@@ -107,12 +107,11 @@ class TaskPool(TaskPoolBase):
                 batch = []
                 total_size = 0
             try:
+                logger.debug(f"{self.uid} getting next task")
                 task = self.tasks.get(timeout=self.timeout)
             except Empty:
-                exc = TimeoutError(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet.")
-                for task in batch:
-                    task.future.set_exception(exc)
-                raise exc
+                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
+                continue
 
             task_size = self.get_task_size(task)
 
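
The change above makes iterate_minibatches wait for more work instead of failing: on a queue timeout it now logs a warning and retries, rather than raising TimeoutError into every future of the partially collected batch. A stripped-down sketch of the new control flow (toy names, with a plain queue.Queue standing in for the task queue):

import queue

def iterate_minibatches_toy(tasks: queue.Queue, min_batch_size: int = 2, timeout: float = 0.1):
    """Yield batches of at least min_batch_size tasks, retrying on timeout instead of failing."""
    batch = []
    while True:
        try:
            task = tasks.get(timeout=timeout)
        except queue.Empty:
            print("timeout reached, batch still too small; keep waiting")  # new behaviour
            continue
        batch.append(task)
        if len(batch) >= min_batch_size:
            yield batch
            batch = []

q = queue.Queue()
for i in range(4):
    q.put(i)
gen = iterate_minibatches_toy(q)
print(next(gen), next(gen))  # [0, 1] [2, 3]
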
@@ -126,10 +125,10 @@ class TaskPool(TaskPoolBase):
                 total_size += task_size
 
     def run(self, *args, **kwargs):
-        print(f'Starting pool, pid={os.getpid()}')
+        logger.info(f'{self.uid} starting, pid={os.getpid()}')
         pending_batches = {}  # Dict[batch uuid, List[SharedFuture]] for each batch currently in runtime
         output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
-                                         name=f'{self.uid}-pool_output_loop')
+                                         name=f'{self.uid}_output')
         try:
             output_thread.start()
             self._pool_input_loop(pending_batches, *args, **kwargs)
@@ -144,11 +143,8 @@ class TaskPool(TaskPoolBase):
         prev_num_tasks = 0  # number of tasks currently in shared buffer
         batch_index = max(pending_batches.keys(), default=0)
         batch_iterator = self.iterate_minibatches(*args, **kwargs)
-        self.batch_received.set()  # initial state: no batches/outputs pending
 
         while True:
-            self.batch_received.wait()  # wait for runtime to receive (copy) previous batch
-
             # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
             # assumes that tasks are processed in the same order as they are created
             for skip_i in range(prev_num_tasks):
@@ -156,18 +152,21 @@ class TaskPool(TaskPoolBase):
                 if skip_i == prev_num_tasks - 1:
                     self.priority = finished_task_timestamp
 
+            logger.debug(f"{self.uid} getting next batch")
             batch_tasks = next(batch_iterator)
             # save batch futures, _output_loop will deliver on them later
             pending_batches[batch_index] = batch_tasks
 
+            logger.debug(f"{self.uid}, batch  {batch_index}: aggregating inputs")
             # find or create shared arrays for current batch size
             batch_inputs = [
                 torch.cat([task.args[i] for task in batch_tasks]).share_memory_()
                 for i in range(len(batch_tasks[0].args))
             ]
 
-            self.batch_received.clear()  # sending next batch...
+            logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime")
             self.batch_sender.send((batch_index, batch_inputs))
+            logger.debug(f"{self.uid}, batch {batch_index}: sent to runtime")
             prev_num_tasks = len(batch_tasks)
             batch_index += 1
 
@@ -175,16 +174,19 @@ class TaskPool(TaskPoolBase):
         """ Infinite loop: receive results from runtime and dispatch them to task Futures """
         """ Infinite loop: receive results from runtime and dispatch them to task Futures """
 
 
         while True:
         while True:
+            logger.debug(f"{self.uid} waiting for results from runtime")
             payload = self.outputs_receiver.recv()
             payload = self.outputs_receiver.recv()
             if isinstance(payload, BaseException):
             if isinstance(payload, BaseException):
                 raise payload
                 raise payload
             else:
             else:
                 batch_index, batch_outputs = payload
                 batch_index, batch_outputs = payload
+            logger.debug(f"{self.uid}, batch {batch_index}: got results")
 
 
             # split batch into partitions for individual tasks
             # split batch into partitions for individual tasks
             batch_tasks = pending_batches.pop(batch_index)
             batch_tasks = pending_batches.pop(batch_index)
             task_sizes = [self.get_task_size(task) for task in batch_tasks]
             task_sizes = [self.get_task_size(task) for task in batch_tasks]
             outputs_per_task = zip(*(torch.split_with_sizes(array, task_sizes, dim=0) for array in batch_outputs))
             outputs_per_task = zip(*(torch.split_with_sizes(array, task_sizes, dim=0) for array in batch_outputs))
+            logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers")
 
 
             # dispatch results to futures
             # dispatch results to futures
             for task, task_outputs in zip(batch_tasks, outputs_per_task):
             for task, task_outputs in zip(batch_tasks, outputs_per_task):
@@ -200,7 +202,6 @@ class TaskPool(TaskPoolBase):
             raise TimeoutError()
 
         batch_index, batch_inputs = self.batch_receiver.recv()
-        self.batch_received.set()  # pool can now prepare next batch
         batch_inputs = [tensor.to(device, non_blocking=True) for tensor in batch_inputs]
         return batch_index, batch_inputs
 

+ 1 - 0
hivemind/utils/__init__.py

@@ -7,3 +7,4 @@ from .shared_future import *
 from .threading import *
 from .autograd import *
 from .grpc import *
+from .logging import get_logger

+ 18 - 0
hivemind/utils/logging.py

@@ -0,0 +1,18 @@
+import logging
+import os
+
+
+def get_logger(module_name: str) -> logging.Logger:
+    # trim package name
+    name_without_prefix = '.'.join(module_name.split('.')[1:])
+    loglevel = os.getenv('LOGLEVEL', 'INFO')
+
+    logging.addLevelName(logging.WARNING, 'WARN')
+    formatter = logging.Formatter(fmt='[{asctime}.{msecs:03.0f}][{levelname}][{name}.{funcName}:{lineno}] {message}', style='{',
+                                  datefmt='%Y/%m/%d %H:%M:%S')
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+    logger = logging.getLogger(name_without_prefix)
+    logger.setLevel(loglevel)
+    logger.addHandler(handler)
+    return logger
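
A short usage sketch for the new helper (illustrative only; the printed line is an example of the configured format, and verbosity comes from the LOGLEVEL environment variable, e.g. LOGLEVEL=DEBUG, read when get_logger is called):

from hivemind.utils import get_logger

# the leading package component is stripped, so this logger is named 'dht.protocol'
logger = get_logger('hivemind.dht.protocol')

logger.info("peer joined the routing table")
logger.warning("ping timed out")  # WARNING is displayed as 'WARN' via addLevelName
# example output:
# [2020/07/28 12:34:56.789][WARN][dht.protocol.<module>:7] ping timed out
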

+ 1 - 2
requirements.txt

@@ -7,5 +7,4 @@ prefetch_generator>=1.0.1
 pytest
 umsgpack
 grpcio
-grpcio-tools>=1.30.0
-aiologger>=0.5.0
+grpcio-tools>=1.30.0

+ 8 - 7
tests/test_dht.py

@@ -31,7 +31,7 @@ def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event
     print(f"Finished peer id={protocol.node_id} port={port}", flush=True)
 
 
-def test_kademlia_protocol():
+def test_dht_protocol():
     # create the first peer
     peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
     peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
@@ -72,7 +72,7 @@ def test_kademlia_protocol():
                 f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
 
             assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
-                f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
+                                                                          f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
 
             # peer 2 must know about peer 1, but not have a *random* nonexistent value
             dummy_key = DHTID.generate()
@@ -250,15 +250,15 @@ def test_hivemind_dht():
 
 def test_store():
     d = LocalStorage()
-    d.store(DHTID.generate("key"), b"val", get_dht_time() + 10)
+    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
     assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
     print("Test store passed")
 
 
 def test_get_expired():
     d = LocalStorage()
-    d.store(DHTID.generate("key"), b"val", get_dht_time() + 1)
-    time.sleep(2)
+    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
+    time.sleep(0.5)
     assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
     print("Test get expired passed")
 
@@ -271,9 +271,10 @@ def test_get_empty():
 
 def test_change_expiration_time():
     d = LocalStorage()
-    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 2)
+    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
+    assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
     d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
-    time.sleep(4)
+    time.sleep(1)
     assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
     print("Test change expiration time passed")
 

+ 2 - 2
tests/test_moe.py

@@ -18,7 +18,7 @@ def test_remote_module_call():
     logits = torch.randn(3, requires_grad=True)
     random_proj = torch.randn_like(xx)
 
-    with background_server(num_experts=num_experts, device='cpu',
+    with background_server(num_experts=num_experts, device='cpu', num_handlers=1,
                            no_optimizer=True, no_dht=True) as (localhost, server_port, dht_port):
         experts = [hivemind.RemoteExpert(uid=f'expert.{i}', port=server_port) for i in range(num_experts)]
         moe_output, = hivemind.client.moe._RemoteMoECall.apply(
@@ -50,7 +50,7 @@ def test_determinism():
     xx = torch.randn(32, 1024, requires_grad=True)
     mask = torch.randint(0, 1, (32, 1024))
 
-    with background_server(num_experts=1, device='cpu', expert_cls='det_dropout',
+    with background_server(num_experts=1, device='cpu', expert_cls='det_dropout', num_handlers=1,
                            no_optimizer=True, no_dht=True) as (interface, server_port, dht_port):
         expert = hivemind.RemoteExpert(uid=f'expert.0', port=server_port)