Browse Source

delete trailing dot, fix naming in test

Pavel Samygin 3 years ago
parent
commit
8aa0b904a6
2 changed files with 11 additions and 9 deletions
  1. 2 2
      hivemind/moe/server/server.py
  2. 9 7
      tests/test_dht_experts.py

+ 2 - 2
hivemind/moe/server/server.py

@@ -321,7 +321,7 @@ def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[Endpoint, Li
         if runner.is_alive():
             logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
             runner.kill()
-            logger.info("Server terminated.")
+            logger.info("Server terminated")
 
 
 def _server_runner(pipe, *args, **kwargs):
@@ -341,7 +341,7 @@ def _server_runner(pipe, *args, **kwargs):
         logger.info("Shutting down server...")
         server.shutdown()
         server.join()
-        logger.info("Server shut down.")
+        logger.info("Server shut down")
 
 
 def _generate_uids(

+ 9 - 7
tests/test_dht_experts.py

@@ -52,22 +52,24 @@ def test_store_get_experts(n_peers=10):
 def test_beam_search(
     n_peers=20, total_experts=128, batch_size=32, beam_size=4, parallel_rpc=4, grid_dims=(32, 32, 32)
 ):
-    dht = [hivemind.DHT(start=True)]
-    initial_peers = dht[0].get_visible_maddrs()
-    dht += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
+    dht_instances = [hivemind.DHT(start=True)]
+    initial_peers = dht_instances[0].get_visible_maddrs()
+    dht_instances += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
 
     real_experts = sorted(
         {"expert." + ".".join([str(random.randint(0, dim - 1)) for dim in grid_dims]) for _ in range(total_experts)}
     )
     for batch_start in range(0, len(real_experts), batch_size):
-        dht_ = random.choice(dht)
+        dht = random.choice(dht_instances)
         declare_experts(
-            dht_,
+            dht,
             real_experts[batch_start : batch_start + batch_size],
-            peer_id=dht_.peer_id,
+            peer_id=dht.peer_id,
         )
 
-    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(dht, min(3, len(dht)))], [])
+    neighbors = sum(
+        [peer.get_visible_maddrs() for peer in random.sample(dht_instances, min(3, len(dht_instances)))], []
+    )
     you = hivemind.DHT(start=True, initial_peers=neighbors, parallel_rpc=parallel_rpc)
     beam_search = MoEBeamSearcher(you, "expert.", grid_dims)