|
@@ -1,3 +1,4 @@
|
|
|
+import asyncio
|
|
|
import random
|
|
|
import time
|
|
|
|
|
@@ -7,6 +8,7 @@ import torch
|
|
|
|
|
|
import hivemind
|
|
|
import hivemind.averaging.averager
|
|
|
+from conftest import cleanup_children
|
|
|
from hivemind.averaging.allreduce import AveragingMode
|
|
|
from hivemind.averaging.key_manager import GroupKeyManager
|
|
|
from hivemind.averaging.load_balancing import load_balance_peers
|
|
@@ -543,3 +545,20 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
|
|
|
|
|
|
for instance in [averager1, averager2] + dht_instances:
|
|
|
instance.shutdown()
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ with cleanup_children():
|
|
|
+ loop = asyncio.new_event_loop()
|
|
|
+ loop.run_until_complete(test_key_manager())
|
|
|
+ print(f"test_key_manager()")
|
|
|
+ del loop
|
|
|
+ for n_clients in [0, 1, 2]:
|
|
|
+ for n_aux in [0, 1, 2]:
|
|
|
+ with cleanup_children():
|
|
|
+ _test_allreduce_once(n_clients, n_aux)
|
|
|
+ print(f"_test_allreduce_once({n_clients}, {n_aux})")
|
|
|
+ for n_clients, n_aux in [(0, 4), (1, 3), (0, 3)]:
|
|
|
+ with cleanup_children():
|
|
|
+ _test_allreduce_once(n_clients, n_aux)
|
|
|
+ print(f"_test_allreduce_once({n_clients}, {n_aux})")
|
|
|
+ print("DONE!")
|