Selaa lähdekoodia

Continue fix test_averaging.py

Aleksandr Borzunov 4 vuotta sitten
vanhempi
commit
36282f81cc
2 muutettua tiedostoa jossa 7 lisäystä ja 2 poistoa
  1. 5 2
      tests/test_averaging.py
  2. 2 0
      tests/test_utils/dht_swarms.py

+ 5 - 2
tests/test_averaging.py

@@ -245,7 +245,7 @@ def test_allreduce_grid():
             initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
             start=True,
         )
-        for dht_instance in dhts
+        for i, dht_instance in enumerate(dhts)
     ]
 
     [means0], [stds0] = compute_mean_std(averagers)
@@ -289,7 +289,10 @@ def test_allgather():
     for i, averager in enumerate(averagers):
         futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))
 
-    assert len(set(repr(sorted(future.result())) for future in futures)) == 2
+    gathered_data = [future.result() for future in futures]
+    gathered_data_reprs = [repr(sorted({peer_id.to_base58(): data for peer_id, data in result.items()}))
+                           for result in gathered_data]
+    assert len(set(gathered_data_reprs)) == 2
 
     reference_metadata = {
         averager.endpoint: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)

+ 2 - 0
tests/test_utils/dht_swarms.py

@@ -90,6 +90,8 @@ async def launch_star_shaped_swarm(n_peers: int, **kwargs) -> List[DHTNode]:
 
 
 def launch_dht_instances(n_peers: int, **kwargs) -> List[DHT]:
+    # TODO: Do it in parallel
+
     instances = [DHT(start=True, **kwargs)]
     initial_peers = instances[0].get_visible_maddrs()
     instances.extend(DHT(initial_peers=initial_peers, start=True, **kwargs) for _ in range(n_peers - 1))