# test_averaging.py

import random
import time

import numpy as np
import pytest
import torch

import hivemind
import hivemind.averaging.averager
from hivemind.averaging.allreduce import AveragingMode
from hivemind.averaging.control import AveragingStage
from hivemind.averaging.key_manager import GroupKeyManager
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.averaging.partition import AllreduceException
from hivemind.p2p import PeerID

from test_utils.dht_swarms import launch_dht_instances


@pytest.mark.forked
@pytest.mark.asyncio
async def test_key_manager():
    dht = hivemind.DHT(start=True)
    key_manager = GroupKeyManager(
        dht,
        prefix="test_averaging",
        initial_group_bits="10110",
        target_group_size=2,
    )
    alice = dht.peer_id
    bob = PeerID(b"bob")

    t = hivemind.get_dht_time()
    key = key_manager.current_key
    await key_manager.declare_averager(key, alice, expiration_time=t + 60)
    await key_manager.declare_averager(key, bob, expiration_time=t + 61)
    q1 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, alice, expiration_time=t + 66)
    q2 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, bob, expiration_time=t + 61, looking_for_group=False)
    q3 = await key_manager.get_averagers(key, only_active=True)
    q4 = await key_manager.get_averagers(key, only_active=False)

    q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False)

    assert len(q1) == 2 and (alice, t + 60) in q1 and (bob, t + 61) in q1
    assert len(q2) == 2 and (alice, t + 66) in q2 and (bob, t + 61) in q2
    assert len(q3) == 1 and (alice, t + 66) in q3
    assert len(q4) == 2 and (alice, t + 66) in q4 and (bob, t + 61) in q4
    assert len(q5) == 0

    dht.shutdown()


def _test_allreduce_once(n_clients, n_aux):
    n_peers = 4
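    # Peer roles for this run: CLIENT peers join averaging but do not accept incoming requests,
    # AUX peers assist the reduction without contributing tensors of their own,
    # and NODE peers are regular participants.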
    modes = (
        [AveragingMode.CLIENT] * n_clients
        + [AveragingMode.AUX] * n_aux
        + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
    )
    random.shuffle(modes)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
    peer_tensors = [tensors1, tensors2, tensors3, tensors4]
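    # Expected result: AUX peers hold no tensors, so the reference is the mean over
    # non-AUX peers only (the max(1, ...) guards the all-AUX edge case).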
    reference = [
        sum(tensors[i] for tensors, mode in zip(peer_tensors, modes) if mode != AveragingMode.AUX)
        / max(1, n_peers - n_aux)
        for i in range(len(tensors1))
    ]

    dht_instances = launch_dht_instances(len(peer_tensors))
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            tensors,
            dht=dht,
            target_group_size=4,
            averaging_expiration=15,
            prefix="mygroup",
            client_mode=mode == AveragingMode.CLIENT,
            auxiliary=mode == AveragingMode.AUX,
            start=True,
        )
        for tensors, dht, mode in zip(peer_tensors, dht_instances, modes)
    ]

    futures = []
    for averager in averagers:
        futures.append(averager.step(wait=False))
    for future in futures:
        result = future.result()
        for averager in averagers:
            assert averager.peer_id in result

    for averager in averagers:
        if averager.mode != AveragingMode.AUX:
            with averager.get_tensors() as averaged_tensors:
                for ref, our in zip(reference, averaged_tensors):
                    assert torch.allclose(ref, our, atol=1e-6)

    for process in averagers + dht_instances:
        process.shutdown()


@pytest.mark.forked
@pytest.mark.parametrize("n_clients", [0, 1, 2])
@pytest.mark.parametrize("n_aux", [0, 1, 2])
def test_allreduce_once(n_clients, n_aux):
    _test_allreduce_once(n_clients, n_aux)


@pytest.mark.forked
@pytest.mark.parametrize("n_clients, n_aux", [(0, 4), (1, 3), (0, 3)])
def test_allreduce_once_edge_cases(n_clients, n_aux):
    _test_allreduce_once(n_clients, n_aux)


@pytest.mark.forked
def test_allreduce_weighted(n_client_mode_peers: int = 2):
    n_peers = 4
    client_modes = [True] * n_client_mode_peers + [False] * (n_peers - n_client_mode_peers)
    random.shuffle(client_modes)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]

    dht_instances = launch_dht_instances(4)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            tensors,
            dht=dht,
            target_group_size=4,
            averaging_expiration=15,
            prefix="mygroup",
            client_mode=client_mode,
            start=True,
        )
        for tensors, dht, client_mode in zip([tensors1, tensors2, tensors3, tensors4], dht_instances, client_modes)
    ]

    weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
    reference = [
        (tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2] + tensors4[i] * weights[3])
        / sum(weights)
        for i in range(len(tensors1))
    ]

    futures = []
    for averager, weight in zip(averagers, weights):
        futures.append(averager.step(weight=weight, wait=False))
    for future in futures:
        future.result()

    for future, averager in zip(futures, averagers):
        with averager.get_tensors() as averaged_tensors:
            for ref, our in zip(reference, averaged_tensors):
                assert torch.allclose(ref, our, atol=1e-6)

    for process in averagers + dht_instances:
        process.shutdown()


def compute_mean_std(averagers, unbiased=True):
    results = []
    for averager in averagers:
        with averager.get_tensors() as tensors:
            results.append([tensor.clone() for tensor in tensors])

    results_stacked_per_tensor = list(map(torch.stack, zip(*results)))
    means = [stack.mean(dim=0) for stack in results_stacked_per_tensor]
    stds = [stack.std(dim=0, unbiased=unbiased) for stack in results_stacked_per_tensor]
    return means, stds


@pytest.mark.forked
def test_allreduce_grid():
    dht_instances = launch_dht_instances(8)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            prefix="mygroup",
            initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
            start=True,
        )
        for i, dht in enumerate(dht_instances)
    ]
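
    # The eight peers start out paired by their two group bits ("00", "00", "01", "01", ...).
    # Each round averages within groups of two; as groups are re-formed across rounds, the values
    # should mix over the whole grid: the global mean stays fixed while the per-coordinate
    # std shrinks toward zero.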
    [means0], [stds0] = compute_mean_std(averagers)
    assert not torch.allclose(stds0, torch.zeros_like(stds0))

    prev_means, prev_stds = means0, stds0
    for i in range(5):
        step_futures = [averager.step(wait=False) for averager in averagers]
        groups = [future.result() for future in step_futures]
        [means], [stds] = compute_mean_std(averagers)
        assert torch.allclose(means, prev_means, atol=1e-6, rtol=0)
        assert all(len(group) == 2 for group in groups)

        if i <= 2:
            assert torch.all(torch.le(stds, prev_stds))
        else:
            assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)

    for process in averagers + dht_instances:
        process.shutdown()


@pytest.mark.forked
def test_allgather(n_averagers=8, target_group_size=4):
    dht_instances = launch_dht_instances(n_averagers)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            [torch.ones(1)],
            dht=dht,
            target_group_size=target_group_size,
            averaging_expiration=15,
            prefix="mygroup",
            initial_group_bits="000",
            start=True,
        )
        for dht in dht_instances
    ]

    futures = []
    for i, averager in enumerate(averagers):
        futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))

    reference_metadata = {
        averager.peer_id: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
    }
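
    # Each peer passes its own metadata via gather=...; after the step, every group member
    # should see the gathered metadata of all target_group_size peers in its group.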
    for future in futures:
        gathered = future.result()
        assert len(gathered) == target_group_size
        for peer_id in gathered:
            assert gathered[peer_id] == reference_metadata[peer_id]

    for process in averagers + dht_instances:
        process.shutdown()
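

# get_cost below appears to model the per-peer communication time of the all-reduce:
# peer i sends everything outside its own partition (vector_size - partitions[i]) and exchanges
# its partition with each of the remaining len(partitions) - 1 peers, all limited by its own
# bandwidth; the overall cost is that of the slowest peer. check_optimality then verifies that
# load_balance_peers does at least as well as a hand-picked reference partition under this cost.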
def get_cost(vector_size, partitions, bandwidths):
    return max(
        (vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(bandwidths[i], 1e-9)
        for i in range(len(partitions))
    )


def check_optimality(vector_size, bandwidths, ref_partitions):
    partitions = list(load_balance_peers(vector_size, bandwidths))
    assert get_cost(vector_size, partitions, bandwidths) <= get_cost(vector_size, ref_partitions, bandwidths)


@pytest.mark.forked
def test_load_balancing():
    check_optimality(60, np.array([0.25, 0.25, 0.25, 0.25]), [15, 15, 15, 15])
    check_optimality(1024, np.array([0.3, 0.5, 0.9]), [0, 255, 769])
    check_optimality(60, np.array([0.44, 0.33, 0.22]), [42, 18, 0])
    check_optimality(60, np.array([0.55, 0.44, 0.40]), [35, 16, 9])
    check_optimality(1024 * 1024, np.array([0.3, 0.5, 0.9, 0.6]), [0, 169327, 602629, 276620])
    check_optimality(1024 * 1024, np.array([0.0, 0.5, 0.0, 0.6]), [0, 428963, 0, 619613])
    assert load_balance_peers(60, np.array([0.55, 0.44, 0.40]), min_size=10) == (41, 19, 0)
    assert load_balance_peers(60, np.array([0.32, 0.55, 0.44]), min_size=10) == (0, 40, 20)
    assert load_balance_peers(2, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 1)
    assert load_balance_peers(1, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 0)
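
    # A bandwidth of None appears to mean "unknown": such peers split the load equally among
    # themselves, zero-bandwidth peers receive nothing, and if no peer can accept data at all
    # the call is expected to fail.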
    assert load_balance_peers(100, (None, None)) == (50, 50)
    assert load_balance_peers(100, (None, None, None, None, None)) == (20, 20, 20, 20, 20)
    assert load_balance_peers(100, (0, 0, 0, None, None)) == (0, 0, 0, 50, 50)

    with pytest.raises(AssertionError):
        load_balance_peers(100, (0, 0, 0))

    for i in range(10):
        vector_size = np.random.randint(1, 1024**3)
        num_peers = np.random.randint(1, 256)
        scale = 1e-9 + np.random.rand() * 1e5
        bandwidths = np.random.rand(num_peers) * scale + 1e-6
        min_size = np.random.choice([0, np.random.randint(0, vector_size // 10)])

        assignment = load_balance_peers(vector_size, bandwidths, min_size)
        assert np.sum(assignment) == vector_size
        assert np.min(assignment) >= 0


@pytest.mark.forked
def test_too_few_peers():
    dht_instances = launch_dht_instances(4)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            averaging_expiration=1,
            request_timeout=0.5,
            prefix="mygroup",
            initial_group_bits=bin(i)[2:].rjust(3, "0"),
            start=True,
        )
        for i, dht in enumerate(dht_instances)
    ]
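
    # Every peer gets a distinct 3-bit group prefix ("000", "001", "010", "011"), so nobody can
    # find a partner before the short expiration runs out and every step should fail.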
    step_futures = [averager.step(wait=False, timeout=2) for averager in averagers]
    for future in step_futures:
        with pytest.raises(AllreduceException):
            future.result()

    for process in averagers + dht_instances:
        process.shutdown()


@pytest.mark.skip(
    reason="The current implementation of elasticity (multi-stage averaging when num_peers > ~3 * target_group_size) "
    "is incorrect (TODO @justheuristic)"
)
@pytest.mark.forked
def test_overcrowded(num_peers=16):
    dht_instances = launch_dht_instances(num_peers)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            averaging_expiration=1,
            request_timeout=0.5,
            prefix="mygroup",
            initial_group_bits="",
            start=True,
        )
        for dht in dht_instances
    ]
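
    # All 16 peers compete for groups of size 2 under an empty group prefix; the (currently
    # skipped) test expects that in every round at least all but one of them still form a pair.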
    for _ in range(5):
        step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
        assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1

    for process in averagers + dht_instances:
        process.shutdown()


@pytest.mark.forked
def test_load_state_from_peers():
    num_calls = 0
    super_metadata = dict(x=123)
    super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))

    class TestAverager(hivemind.averaging.DecentralizedAverager):
        def get_current_state(self):
            """
            Get current state and send it to a peer. Executed in the host process. Meant to be overridden.
            :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
            """
            nonlocal num_calls, super_metadata, super_tensors
            num_calls += 1
            return super_metadata, super_tensors

    dht_instances = launch_dht_instances(2)
    averager1 = TestAverager(
        [torch.randn(3), torch.rand(5)],
        dht=dht_instances[0],
        start=True,
        prefix="demo-run",
        target_group_size=2,
    )
    averager2 = TestAverager(
        [torch.randn(3), torch.rand(5)],
        dht=dht_instances[1],
        start=True,
        prefix="demo-run",
        target_group_size=2,
    )
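
    # load_state_from_peers should fetch metadata and tensors from the other peer by calling its
    # get_current_state exactly once per successful request; num_calls tracks those calls.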
    time.sleep(0.5)

    assert num_calls == 0
    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 1
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))

    super_metadata["y"] = 123
    super_tensors[1][2] = 9
    assert num_calls == 1
    assert got_metadata != super_metadata
    assert not all(map(torch.allclose, got_tensors, super_tensors))

    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 2
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))

    averager1.allow_state_sharing = False
    assert averager2.load_state_from_peers() is None
    averager1.allow_state_sharing = True
    time.sleep(0.5)
    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 3
    assert got_metadata == super_metadata

    for instance in [averager1, averager2] + dht_instances:
        instance.shutdown()


@pytest.mark.forked
def test_load_state_priority():
    dht_instances = launch_dht_instances(4)

    averagers = []
    for i in range(4):
        averager = hivemind.DecentralizedAverager(
            [torch.randn(3), torch.rand(5), torch.tensor([i], dtype=torch.float32)],
            dht=dht_instances[i],
            start=True,
            prefix="demo-run",
            target_group_size=2,
            allow_state_sharing=i != 1,
        )
        averager.state_sharing_priority = 5 - abs(2 - i)
        averagers.append(averager)

    time.sleep(0.5)
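
    # Sharing priorities are 5 - |2 - i| = [3, 4, 5, 4], with peer 1 not sharing at all.
    # The last tensor of each averager stores its own index, so the asserts below identify
    # which peer the state was loaded from (the highest-priority sharing peer other than the caller).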
    metadata, tensors = averagers[0].load_state_from_peers(timeout=1)
    assert tensors[-1].item() == 2

    metadata, tensors = averagers[2].load_state_from_peers(timeout=1)
    assert tensors[-1].item() == 3

    averagers[0].state_sharing_priority = 10
    time.sleep(0.2)

    metadata, tensors = averagers[2].load_state_from_peers(timeout=1)
    assert tensors[-1].item() == 0

    averagers[1].allow_state_sharing = False
    averagers[2].allow_state_sharing = False
    metadata, tensors = averagers[0].load_state_from_peers(timeout=1)
    assert tensors[-1].item() == 3

    for averager in averagers:
        averager.shutdown()
    for dht in dht_instances:
        dht.shutdown()


@pytest.mark.forked
def test_getset_bits():
    dht = hivemind.DHT(start=True)
    averager = hivemind.averaging.DecentralizedAverager(
        [torch.randn(3)],
        dht=dht,
        start=True,
        prefix="test_prefix",
        target_group_size=2,
    )
    averager.set_group_bits("00101011101010")
    assert averager.get_group_bits() == "00101011101010"


@pytest.mark.forked
def test_averaging_trigger():
    averagers = tuple(
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            min_matchmaking_time=0.5,
            request_timeout=0.3,
            prefix="mygroup",
            initial_group_bits="",
            start=True,
        )
        for dht in launch_dht_instances(4)
    )

    controls = []
    for i, averager in enumerate(averagers):
        controls.append(
            averager.step(
                wait=False,
                scheduled_time=hivemind.get_dht_time() + 0.5,
                weight=1.0,
                require_trigger=i in (1, 2),
            )
        )

    time.sleep(0.6)
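
    # Peers 1 and 2 were started with require_trigger=True, so after the scheduled time they
    # should be parked at AWAITING_TRIGGER until allow_allreduce() is called, while peers 0 and 3
    # proceed straight to the all-reduce.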
    c0, c1, c2, c3 = controls
    assert not any(c.done() for c in controls)
    assert c0.stage == AveragingStage.RUNNING_ALLREDUCE
    assert c1.stage == AveragingStage.AWAITING_TRIGGER
    assert c2.stage == AveragingStage.AWAITING_TRIGGER
    assert c3.stage == AveragingStage.RUNNING_ALLREDUCE

    c1.allow_allreduce()
    c2.allow_allreduce()

    time.sleep(0.5)
    assert all(c.stage == AveragingStage.FINISHED for c in controls)
    assert all(c.done() for c in controls)

    # check that setting trigger twice does not raise error
    c0.allow_allreduce()


@pytest.mark.forked
def test_averaging_cancel():
    averagers = tuple(
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            min_matchmaking_time=0.5,
            request_timeout=0.3,
            client_mode=(i % 2 == 0),
            prefix="mygroup",
            start=True,
        )
        for i, dht in enumerate(launch_dht_instances(4))
    )

    step_controls = [averager.step(wait=False, scheduled_time=hivemind.get_dht_time() + 1) for averager in averagers]

    time.sleep(0.1)
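
    # Cancel two of the four scheduled steps; the remaining two peers should still match with
    # each other and finish an all-reduce in a group of size 2.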
    step_controls[0].cancel()
    step_controls[1].cancel()

    for i, control in enumerate(step_controls):
        if i in (0, 1):
            assert control.cancelled()
        else:
            assert control.result() is not None and len(control.result()) == 2

    for averager in averagers:
        averager.shutdown()


@pytest.mark.forked
def test_training_averager(n_steps: int = 10, n_dims: int = 16):
    torch.manual_seed(42)

    dht_instances = launch_dht_instances(2)
    common_kwargs = {
        "start": True,
        "prefix": "demo-run",
        "target_group_size": 2,
    }

    x1 = torch.randn(n_dims, requires_grad=True)
    opt1 = torch.optim.Adam([x1], lr=0.05)
    averager1 = hivemind.TrainingAverager(
        opt1,
        average_gradients=True,
        average_parameters=True,
        average_opt_statistics=["exp_avg_sq"],
        dht=dht_instances[0],
        **common_kwargs
    )

    x2 = torch.randn(n_dims, requires_grad=True)
    opt2 = torch.optim.Adam([x2], lr=0.05)
    averager2 = hivemind.TrainingAverager(
        opt2,
        average_gradients=True,
        average_parameters=True,
        average_opt_statistics=["exp_avg_sq"],
        dht=dht_instances[1],
        **common_kwargs
    )
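
    # Both peers optimize toward a vector of ones independently, then average parameters,
    # gradients, and Adam's exp_avg_sq statistics after every local step, so the two replicas
    # should hold identical (averaged) values at the end of each iteration.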
    a = torch.ones(n_dims)

    for i in range(n_steps):
        opt1.zero_grad()
        opt2.zero_grad()
        (x1 - a).pow(2).sum().backward()
        (x2 - a).pow(2).sum().backward()
        opt1.step()
        opt2.step()

        with torch.no_grad():
            x_avg = 0.5 * (x1 + x2)
            grad_avg = 0.5 * (x1.grad + x2.grad)
            stats_avg = 0.5 * (opt1.state[x1]["exp_avg_sq"] + opt2.state[x2]["exp_avg_sq"])

        # we set wait=False to prevent a deadlock where averager1 locks and waits for averager2
        f1 = averager1.step(wait=False)
        f2 = averager2.step(wait=False)
        f1.result()
        f2.result()

        assert torch.allclose(x1, x_avg)
        assert torch.allclose(x2, x_avg)
        assert torch.allclose(x1.grad, grad_avg)
        assert torch.allclose(x2.grad, grad_avg)
        assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
        assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)

    for instance in [averager1, averager2] + dht_instances:
        instance.shutdown()