
import asyncio
import contextlib
import gc
import random
import time
from contextlib import suppress

import numpy as np
import psutil
import pytest
import torch

import hivemind
import hivemind.averaging.averager
from hivemind import MPFuture, get_logger
from hivemind.averaging.allreduce import AveragingMode
from hivemind.averaging.key_manager import GroupKeyManager
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.p2p import PeerID
from hivemind.proto.runtime_pb2 import CompressionType

from test_utils.dht_swarms import launch_dht_instances

logger = get_logger(__name__)

@contextlib.contextmanager
def cleanup_children():
    yield

    gc.collect()  # Call .__del__() for removed objects

    children = psutil.Process().children(recursive=True)
    if children:
        logger.info(f"Cleaning up {len(children)} leftover child processes")
        for child in children:
            with suppress(psutil.NoSuchProcess):
                child.terminate()
        psutil.wait_procs(children, timeout=1)

    # Broken code or killing of child processes may leave the MPFuture backend corrupted
    MPFuture.reset_backend()

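# Sanity checks for GroupKeyManager: declare averagers under a group key and query them back,
# with and without the only_active filter and for a key nobody declared.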
@pytest.mark.forked
@pytest.mark.asyncio
async def test_key_manager():
    dht = hivemind.DHT(start=True)
    key_manager = GroupKeyManager(
        dht,
        prefix="test_averaging",
        initial_group_bits="10110",
        target_group_size=2,
    )
    alice = dht.peer_id
    bob = PeerID(b"bob")

    t = hivemind.get_dht_time()
    key = key_manager.current_key
    await key_manager.declare_averager(key, alice, expiration_time=t + 60)
    await key_manager.declare_averager(key, bob, expiration_time=t + 61)
    q1 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, alice, expiration_time=t + 66)
    q2 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, bob, expiration_time=t + 61, looking_for_group=False)
    q3 = await key_manager.get_averagers(key, only_active=True)
    q4 = await key_manager.get_averagers(key, only_active=False)

    q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False)

    assert len(q1) == 2 and (alice, t + 60) in q1 and (bob, t + 61) in q1
    assert len(q2) == 2 and (alice, t + 66) in q2 and (bob, t + 61) in q2
    assert len(q3) == 1 and (alice, t + 66) in q3
    assert len(q4) == 2 and (alice, t + 66) in q4 and (bob, t + 61) in q4
    assert len(q5) == 0

    dht.shutdown()

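# Run a single all-reduce round across 4 peers with a shuffled mix of NODE, CLIENT, and AUX modes,
# then compare the averaged tensors against a locally computed reference. AUX peers do not
# contribute tensors of their own and are excluded from the reference average.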
def _test_allreduce_once(n_clients, n_aux):
    n_peers = 4
    modes = (
        [AveragingMode.CLIENT] * n_clients
        + [AveragingMode.AUX] * n_aux
        + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
    )
    random.shuffle(modes)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
    peer_tensors = [tensors1, tensors2, tensors3, tensors4]

    reference = [
        sum(tensors[i] for tensors, mode in zip(peer_tensors, modes) if mode != AveragingMode.AUX)
        / max(1, n_peers - n_aux)
        for i in range(len(tensors1))
    ]

    dht_instances = launch_dht_instances(len(peer_tensors))
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            tensors,
            dht=dht,
            target_group_size=4,
            averaging_expiration=15,
            prefix="mygroup",
            client_mode=mode == AveragingMode.CLIENT,
            auxiliary=mode == AveragingMode.AUX,
            start=True,
        )
        for tensors, dht, mode in zip(peer_tensors, dht_instances, modes)
    ]

    futures = []
    for averager in averagers:
        futures.append(averager.step(wait=False))
    for future in futures:
        result = future.result()
        for averager in averagers:
            assert averager.endpoint in result

    for averager in averagers:
        if averager.mode != AveragingMode.AUX:
            with averager.get_tensors() as averaged_tensors:
                for ref, our in zip(reference, averaged_tensors):
                    assert torch.allclose(ref, our, atol=1e-6)

    for instance in averagers + dht_instances:
        instance.shutdown()

@pytest.mark.forked
@pytest.mark.parametrize("n_clients", [0, 1, 2])
@pytest.mark.parametrize("n_aux", [0, 1, 2])
def test_allreduce_once(n_clients, n_aux):
    _test_allreduce_once(n_clients, n_aux)


@pytest.mark.forked
@pytest.mark.parametrize("n_clients, n_aux", [(0, 4), (1, 3), (0, 3)])
def test_allreduce_once_edge_cases(n_clients, n_aux):
    _test_allreduce_once(n_clients, n_aux)

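# Weighted averaging: each peer passes its own weight to step(), and the averaged tensors must
# match the locally computed weighted mean of the original tensors.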
@pytest.mark.forked
def test_allreduce_weighted(n_client_mode_peers: int = 2):
    n_peers = 4
    client_modes = [True] * n_client_mode_peers + [False] * (n_peers - n_client_mode_peers)
    random.shuffle(client_modes)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]

    dht_instances = launch_dht_instances(4)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            tensors,
            dht=dht,
            target_group_size=4,
            averaging_expiration=15,
            prefix="mygroup",
            client_mode=client_mode,
            start=True,
        )
        for tensors, dht, client_mode in zip([tensors1, tensors2, tensors3, tensors4], dht_instances, client_modes)
    ]

    weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
    reference = [
        (tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2] + tensors4[i] * weights[3])
        / sum(weights)
        for i in range(len(tensors1))
    ]

    futures = []
    for averager, weight in zip(averagers, weights):
        futures.append(averager.step(weight=weight, wait=False))
    for future in futures:
        future.result()

    for future, averager in zip(futures, averagers):
        with averager.get_tensors() as averaged_tensors:
            for ref, our in zip(reference, averaged_tensors):
                assert torch.allclose(ref, our, atol=1e-6)

    for instance in averagers + dht_instances:
        instance.shutdown()

@pytest.mark.forked
def test_allreduce_compression():
    """This test ensures that compression works correctly when multiple tensors have different compression types."""
    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
    results = {}

    FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT

    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
        dht_instances = launch_dht_instances(2)
        averager1 = hivemind.averaging.DecentralizedAverager(
            [x.clone() for x in tensors1],
            dht=dht_instances[0],
            compression_type=compression_type_pair,
            client_mode=True,
            target_group_size=2,
            prefix="mygroup",
            start=True,
        )
        averager2 = hivemind.averaging.DecentralizedAverager(
            [x.clone() for x in tensors2],
            dht=dht_instances[1],
            compression_type=compression_type_pair,
            target_group_size=2,
            prefix="mygroup",
            start=True,
        )

        for future in averager1.step(wait=False), averager2.step(wait=False):
            future.result()

        with averager1.get_tensors() as averaged_tensors:
            results[compression_type_pair] = averaged_tensors

        for instance in [averager1, averager2] + dht_instances:
            instance.shutdown()

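    # compression_type is applied per tensor: pair[0] covers the first tensor, pair[1] the second,
    # so runs that used the same compression for a given tensor must agree on it, and runs that
    # used different compression for that tensor must not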
    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])

    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])

    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
    for i in range(2):
        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2

def compute_mean_std(averagers, unbiased=True):
    results = []
    for averager in averagers:
        with averager.get_tensors() as tensors:
            results.append([tensor.clone() for tensor in tensors])

    results_stacked_per_tensor = list(map(torch.stack, zip(*results)))
    means = [stack.mean(dim=0) for stack in results_stacked_per_tensor]
    stds = [stack.std(dim=0, unbiased=unbiased) for stack in results_stacked_per_tensor]
    return means, stds

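# Peers start in 4 different buckets of 2 (via initial_group_bits) with target_group_size=2.
# Across rounds, the global mean must stay constant while the spread between peers shrinks,
# eventually reaching (near) zero once everyone has mixed with everyone else.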
@pytest.mark.forked
def test_allreduce_grid():
    dht_instances = launch_dht_instances(8)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            prefix="mygroup",
            initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
            start=True,
        )
        for i, dht in enumerate(dht_instances)
    ]

    [means0], [stds0] = compute_mean_std(averagers)
    assert not torch.allclose(stds0, torch.zeros_like(stds0))

    prev_means, prev_stds = means0, stds0

    for i in range(5):
        step_futures = [averager.step(wait=False) for averager in averagers]
        groups = [future.result() for future in step_futures]
        [means], [stds] = compute_mean_std(averagers)
        assert torch.allclose(means, prev_means, atol=1e-6, rtol=0)
        assert all(len(group) == 2 for group in groups)

        if i <= 2:
            assert torch.all(torch.le(stds, prev_stds))
        else:
            assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)

        prev_means, prev_stds = means, stds

    for averager in averagers + dht_instances:
        averager.shutdown()

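# All 8 peers share one bucket with target_group_size=4, so they split into two groups of four;
# each peer must receive the gather=... metadata from everyone in its own group.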
@pytest.mark.forked
def test_allgather():
    dht_instances = launch_dht_instances(8)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            [torch.ones(1)],
            dht=dht,
            target_group_size=4,
            averaging_expiration=15,
            prefix="mygroup",
            initial_group_bits="000",
            start=True,
        )
        for dht in dht_instances
    ]

    futures = []
    for i, averager in enumerate(averagers):
        futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))

    gathered_data = [future.result() for future in futures]
    gathered_data_reprs = [
        repr(sorted({peer_id.to_base58(): data for peer_id, data in result.items()})) for result in gathered_data
    ]
    assert len(set(gathered_data_reprs)) == 2

    reference_metadata = {
        averager.endpoint: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
    }
    for future in futures:
        gathered = future.result()
        assert len(gathered) == 4
        for endpoint in gathered:
            assert gathered[endpoint] == reference_metadata[endpoint]

    for averager in averagers + dht_instances:
        averager.shutdown()

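# Cost model used to judge load_balance_peers: peer i handles its non-local data
# (vector_size - partitions[i]) plus (n - 1) copies of the partition it reduces for others,
# scaled by its bandwidth; the round is limited by the slowest peer.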
def get_cost(vector_size, partitions, bandwidths):
    return max(
        (vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(bandwidths[i], 1e-9)
        for i in range(len(partitions))
    )


def check_optimality(vector_size, bandwidths, ref_partitions):
    partitions = list(load_balance_peers(vector_size, bandwidths))
    assert get_cost(vector_size, partitions, bandwidths) <= get_cost(vector_size, ref_partitions, bandwidths)

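# load_balance_peers must never be worse than the hand-picked reference partitions, respect
# min_size, split evenly among peers whose bandwidth is None, and (in the randomized checks)
# always assign the whole vector with non-negative parts.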
@pytest.mark.forked
def test_load_balancing():
    check_optimality(60, np.array([0.25, 0.25, 0.25, 0.25]), [15, 15, 15, 15])
    check_optimality(1024, np.array([0.3, 0.5, 0.9]), [0, 255, 769])
    check_optimality(60, np.array([0.44, 0.33, 0.22]), [42, 18, 0])
    check_optimality(60, np.array([0.55, 0.44, 0.40]), [35, 16, 9])
    check_optimality(1024 * 1024, np.array([0.3, 0.5, 0.9, 0.6]), [0, 169327, 602629, 276620])
    check_optimality(1024 * 1024, np.array([0.0, 0.5, 0.0, 0.6]), [0, 428963, 0, 619613])

    assert load_balance_peers(60, np.array([0.55, 0.44, 0.40]), min_size=10) == (41, 19, 0)
    assert load_balance_peers(60, np.array([0.32, 0.55, 0.44]), min_size=10) == (0, 40, 20)
    assert load_balance_peers(2, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 1)
    assert load_balance_peers(1, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 0)

    assert load_balance_peers(100, (None, None)) == (50, 50)
    assert load_balance_peers(100, (None, None, None, None, None)) == (20, 20, 20, 20, 20)
    assert load_balance_peers(100, (0, 0, 0, None, None)) == (0, 0, 0, 50, 50)

    with pytest.raises(AssertionError):
        load_balance_peers(100, (0, 0, 0))

    for i in range(10):
        vector_size = np.random.randint(1, 1024 ** 3)
        num_peers = np.random.randint(1, 256)
        scale = 1e-9 + np.random.rand() * 1e5
        bandwidths = np.random.rand(num_peers) * scale + 1e-6
        min_size = np.random.choice([0, np.random.randint(0, vector_size // 10)])

        assignment = load_balance_peers(vector_size, bandwidths, min_size)
        assert np.sum(assignment) == vector_size
        assert np.min(assignment) >= 0

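# Each of the 4 peers starts with a distinct 3-bit group key, so no bucket has enough peers;
# matchmaking must still pair everyone up (every step() result contains exactly 2 peers)
# despite the short averaging_expiration and request_timeout.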
@pytest.mark.forked
def test_too_few_peers():
    dht_instances = launch_dht_instances(4)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            averaging_expiration=1,
            request_timeout=0.5,
            prefix="mygroup",
            initial_group_bits=bin(i)[2:].rjust(3, "0"),
            start=True,
        )
        for i, dht in enumerate(dht_instances)
    ]
    step_futures = [averager.step(wait=False) for averager in averagers]
    for future in step_futures:
        assert len(future.result()) == 2

    for averager in averagers + dht_instances:
        averager.shutdown()

@pytest.mark.skip(
    reason="The current implementation of elasticity (multi-stage averaging for the case when "
    "num_peers > ~3 * target_group_size) is incorrect (TODO @justheuristic)"
)
@pytest.mark.forked
def test_overcrowded(num_peers=16):
    dht_instances = launch_dht_instances(num_peers)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            averaging_expiration=1,
            request_timeout=0.5,
            prefix="mygroup",
            initial_group_bits="",
            start=True,
        )
        for dht in dht_instances
    ]
    for _ in range(5):
        step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
        assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1

    for averager in averagers + dht_instances:
        averager.shutdown()

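# load_state_from_peers() should call a peer's get_current_state() exactly once per request,
# pick up later changes to that state on the next call, and return None while state sharing
# is disabled on the only available donor.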
@pytest.mark.forked
def test_load_state_from_peers():
    num_calls = 0
    super_metadata = dict(x=123)
    super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))

    class TestAverager(hivemind.averaging.DecentralizedAverager):
        def get_current_state(self):
            """
            Get the current state and send it to a peer. Executed in the host process. Meant to be overridden.
            :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
            """
            nonlocal num_calls, super_metadata, super_tensors
            num_calls += 1
            return super_metadata, super_tensors

    dht_instances = launch_dht_instances(2)
    averager1 = TestAverager(
        [torch.randn(3), torch.rand(5)],
        dht=dht_instances[0],
        start=True,
        prefix="demo-run",
        target_group_size=2,
    )

    dht_instances[1].get("demo-run.all_averagers")
    averager2 = TestAverager(
        [torch.randn(3), torch.rand(5)],
        dht=dht_instances[1],
        start=True,
        prefix="demo-run",
        target_group_size=2,
    )

    assert num_calls == 0
    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 1
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))

    super_metadata["y"] = 123
    super_tensors[1][2] = 9
    assert num_calls == 1
    assert got_metadata != super_metadata
    assert not all(map(torch.allclose, got_tensors, super_tensors))

    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 2
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))

    averager1.allow_state_sharing = False
    assert averager2.load_state_from_peers() is None
    averager1.allow_state_sharing = True

    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 3
    assert got_metadata == super_metadata

    for instance in [averager1, averager2] + dht_instances:
        instance.shutdown()

@pytest.mark.forked
def test_getset_bits():
    dht = hivemind.DHT(start=True)
    averager = hivemind.averaging.DecentralizedAverager(
        [torch.randn(3)],
        dht=dht,
        start=True,
        prefix="test_prefix",
        target_group_size=2,
    )
    averager.set_group_bits("00101011101010")
    assert averager.get_group_bits() == "00101011101010"

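# Two peers optimize the same quadratic objective from different starting points; after every
# TrainingAverager step, their parameters, gradients, and Adam exp_avg_sq statistics must equal
# the pairwise averages computed locally.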
@pytest.mark.forked
def test_training_averager(n_steps: int = 10, n_dims: int = 16):
    torch.manual_seed(42)

    dht_instances = launch_dht_instances(2)
    common_kwargs = {
        "start": True,
        "prefix": "demo-run",
        "target_group_size": 2,
    }

    x1 = torch.randn(n_dims, requires_grad=True)
    opt1 = torch.optim.Adam([x1], lr=0.05)
    averager1 = hivemind.averaging.TrainingAverager(
        opt1,
        average_gradients=True,
        average_parameters=True,
        average_opt_statistics=["exp_avg_sq"],
        dht=dht_instances[0],
        **common_kwargs,
    )

    x2 = torch.randn(n_dims, requires_grad=True)
    opt2 = torch.optim.Adam([x2], lr=0.05)
    averager2 = hivemind.averaging.TrainingAverager(
        opt2,
        average_gradients=True,
        average_parameters=True,
        average_opt_statistics=["exp_avg_sq"],
        dht=dht_instances[1],
        **common_kwargs,
    )

    a = torch.ones(n_dims)

    for i in range(n_steps):
        opt1.zero_grad()
        opt2.zero_grad()
        (x1 - a).pow(2).sum().backward()
        (x2 - a).pow(2).sum().backward()
        opt1.step()
        opt2.step()

        with torch.no_grad():
            x_avg = 0.5 * (x1 + x2)
            grad_avg = 0.5 * (x1.grad + x2.grad)
            stats_avg = 0.5 * (opt1.state[x1]["exp_avg_sq"] + opt2.state[x2]["exp_avg_sq"])

        # use wait=False to avoid a deadlock where averager1 holds its lock while waiting for averager2
        f1 = averager1.step(wait=False)
        f2 = averager2.step(wait=False)
        f1.result()
        f2.result()

        assert torch.allclose(x1, x_avg)
        assert torch.allclose(x2, x_avg)
        assert torch.allclose(x1.grad, grad_avg)
        assert torch.allclose(x2.grad, grad_avg)
        assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
        assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)

    for instance in [averager1, averager2] + dht_instances:
        instance.shutdown()

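# Manual entry point: run the tests sequentially without pytest. cleanup_children() terminates
# any leftover DHT/averager child processes and resets the MPFuture backend between tests.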
if __name__ == "__main__":
    with cleanup_children():
        print("BEGAN test_key_manager()")
        asyncio.run(test_key_manager())
        print("PASSED test_key_manager()")

    for n_clients in [0, 1, 2]:
        for n_aux in [0, 1, 2]:
            with cleanup_children():
                print(f"BEGAN _test_allreduce_once({n_clients}, {n_aux})")
                _test_allreduce_once(n_clients, n_aux)
                print(f"PASSED _test_allreduce_once({n_clients}, {n_aux})")

    for n_clients, n_aux in [(0, 4), (1, 3), (0, 3)]:
        with cleanup_children():
            print(f"BEGAN _test_allreduce_once({n_clients}, {n_aux})")
            _test_allreduce_once(n_clients, n_aux)
            print(f"PASSED _test_allreduce_once({n_clients}, {n_aux})")

    with cleanup_children():
        print("BEGAN test_allreduce_weighted()")
        test_allreduce_weighted()
        print("PASSED test_allreduce_weighted()")

    with cleanup_children():
        print("BEGAN test_allreduce_grid()")
        test_allreduce_grid()
        print("PASSED test_allreduce_grid()")

    with cleanup_children():
        print("BEGAN test_allreduce_compression()")
        test_allreduce_compression()
        print("PASSED test_allreduce_compression()")

    with cleanup_children():
        test_allgather()
        print("PASSED test_allgather()")

    with cleanup_children():
        test_load_balancing()
        print("PASSED test_load_balancing()")

    with cleanup_children():
        test_too_few_peers()
        print("PASSED test_too_few_peers()")

    with cleanup_children():
        test_load_state_from_peers()
        print("PASSED test_load_state_from_peers()")

    with cleanup_children():
        test_getset_bits()
        print("PASSED test_getset_bits()")

    with cleanup_children():
        test_training_averager()
        print("PASSED test_training_averager()")

    print("DONE!")