import random

import numpy as np
import pytest
import torch

import hivemind
import hivemind.averaging.averager
from hivemind.averaging.allreduce import AveragingMode
from hivemind.averaging.key_manager import GroupKeyManager
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.averaging.partition import AllreduceException
from hivemind.p2p import PeerID

from test_utils.dht_swarms import launch_dht_instances

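# GroupKeyManager: averagers declared under a group key should be returned by get_averagers(),
# and only_active=True should hide peers that declared looking_for_group=False.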
@pytest.mark.forked
@pytest.mark.asyncio
async def test_key_manager():
    dht = hivemind.DHT(start=True)
    key_manager = GroupKeyManager(
        dht,
        prefix="test_averaging",
        initial_group_bits="10110",
        target_group_size=2,
    )
    alice = dht.peer_id
    bob = PeerID(b"bob")
    t = hivemind.get_dht_time()
    key = key_manager.current_key

    await key_manager.declare_averager(key, alice, expiration_time=t + 60)
    await key_manager.declare_averager(key, bob, expiration_time=t + 61)
    q1 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, alice, expiration_time=t + 66)
    q2 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, bob, expiration_time=t + 61, looking_for_group=False)
    q3 = await key_manager.get_averagers(key, only_active=True)
    q4 = await key_manager.get_averagers(key, only_active=False)

    q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False)

    assert len(q1) == 2 and (alice, t + 60) in q1 and (bob, t + 61) in q1
    assert len(q2) == 2 and (alice, t + 66) in q2 and (bob, t + 61) in q2
    assert len(q3) == 1 and (alice, t + 66) in q3
    assert len(q4) == 2 and (alice, t + 66) in q4 and (bob, t + 61) in q4
    assert len(q5) == 0

    dht.shutdown()

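# Shared body for the all-reduce tests: four peers with a random mix of NODE / CLIENT / AUX roles
# run one averaging step; AUX peers take part in the reduction but their tensors are excluded
# from the reference average.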
def _test_allreduce_once(n_clients, n_aux):
    n_peers = 4
    modes = (
        [AveragingMode.CLIENT] * n_clients
        + [AveragingMode.AUX] * n_aux
        + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
    )
    random.shuffle(modes)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
    peer_tensors = [tensors1, tensors2, tensors3, tensors4]

    reference = [
        sum(tensors[i] for tensors, mode in zip(peer_tensors, modes) if mode != AveragingMode.AUX)
        / max(1, n_peers - n_aux)
        for i in range(len(tensors1))
    ]

    dht_instances = launch_dht_instances(len(peer_tensors))
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            tensors,
            dht=dht,
            target_group_size=4,
            averaging_expiration=15,
            prefix="mygroup",
            client_mode=mode == AveragingMode.CLIENT,
            auxiliary=mode == AveragingMode.AUX,
            start=True,
        )
        for tensors, dht, mode in zip(peer_tensors, dht_instances, modes)
    ]

    futures = []
    for averager in averagers:
        futures.append(averager.step(wait=False))
    for future in futures:
        result = future.result()
        for averager in averagers:
            assert averager.peer_id in result

    for averager in averagers:
        if averager.mode != AveragingMode.AUX:
            with averager.get_tensors() as averaged_tensors:
                for ref, our in zip(reference, averaged_tensors):
                    assert torch.allclose(ref, our, atol=1e-6)

    for process in averagers + dht_instances:
        process.shutdown()

@pytest.mark.forked
@pytest.mark.parametrize("n_clients", [0, 1, 2])
@pytest.mark.parametrize("n_aux", [0, 1, 2])
def test_allreduce_once(n_clients, n_aux):
    _test_allreduce_once(n_clients, n_aux)

@pytest.mark.forked
@pytest.mark.parametrize("n_clients, n_aux", [(0, 4), (1, 3), (0, 3)])
def test_allreduce_once_edge_cases(n_clients, n_aux):
    _test_allreduce_once(n_clients, n_aux)

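# Weighted all-reduce: each peer calls step(weight=...), and the result should match the
# weighted average of the original tensors computed locally.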
@pytest.mark.forked
def test_allreduce_weighted(n_client_mode_peers: int = 2):
    n_peers = 4
    client_modes = [True] * n_client_mode_peers + [False] * (n_peers - n_client_mode_peers)
    random.shuffle(client_modes)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]

    dht_instances = launch_dht_instances(4)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            tensors,
            dht=dht,
            target_group_size=4,
            averaging_expiration=15,
            prefix="mygroup",
            client_mode=client_mode,
            start=True,
        )
        for tensors, dht, client_mode in zip([tensors1, tensors2, tensors3, tensors4], dht_instances, client_modes)
    ]

    weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
    reference = [
        (tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2] + tensors4[i] * weights[3])
        / sum(weights)
        for i in range(len(tensors1))
    ]

    futures = []
    for averager, weight in zip(averagers, weights):
        futures.append(averager.step(weight=weight, wait=False))
    for future in futures:
        future.result()

    for future, averager in zip(futures, averagers):
        with averager.get_tensors() as averaged_tensors:
            for ref, our in zip(reference, averaged_tensors):
                assert torch.allclose(ref, our, atol=1e-6)

    for process in averagers + dht_instances:
        process.shutdown()

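# Helper: snapshot every averager's tensors and return per-tensor mean and std across peers.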
def compute_mean_std(averagers, unbiased=True):
    results = []
    for averager in averagers:
        with averager.get_tensors() as tensors:
            results.append([tensor.clone() for tensor in tensors])

    results_stacked_per_tensor = list(map(torch.stack, zip(*results)))
    means = [stack.mean(dim=0) for stack in results_stacked_per_tensor]
    stds = [stack.std(dim=0, unbiased=unbiased) for stack in results_stacked_per_tensor]
    return means, stds

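# Eight peers on a 2-bit group grid (pairs share initial_group_bits): repeated averaging rounds
# must keep the global mean unchanged while the spread across peers (std) shrinks towards zero.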
@pytest.mark.forked
def test_allreduce_grid():
    dht_instances = launch_dht_instances(8)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            prefix="mygroup",
            initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
            start=True,
        )
        for i, dht in enumerate(dht_instances)
    ]

    [means0], [stds0] = compute_mean_std(averagers)
    assert not torch.allclose(stds0, torch.zeros_like(stds0))

    prev_means, prev_stds = means0, stds0

    for i in range(5):
        step_futures = [averager.step(wait=False) for averager in averagers]
        groups = [future.result() for future in step_futures]
        [means], [stds] = compute_mean_std(averagers)
        assert torch.allclose(means, prev_means, atol=1e-6, rtol=0)
        assert all(len(group) == 2 for group in groups)

        if i <= 2:
            assert torch.all(torch.le(stds, prev_stds))
        else:
            assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)

    for process in averagers + dht_instances:
        process.shutdown()

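# Metadata passed via step(gather=...) should be delivered to every member of the formed group,
# keyed by peer_id.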
@pytest.mark.forked
def test_allgather(n_averagers=8, target_group_size=4):
    dht_instances = launch_dht_instances(n_averagers)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            [torch.ones(1)],
            dht=dht,
            target_group_size=target_group_size,
            averaging_expiration=15,
            prefix="mygroup",
            initial_group_bits="000",
            start=True,
        )
        for dht in dht_instances
    ]

    futures = []
    for i, averager in enumerate(averagers):
        futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))

    reference_metadata = {
        averager.peer_id: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
    }
    for future in futures:
        gathered = future.result()
        assert len(gathered) == target_group_size
        for peer_id in gathered:
            assert gathered[peer_id] == reference_metadata[peer_id]

    for process in averagers + dht_instances:
        process.shutdown()

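# Helpers for the load-balancing test: get_cost is a rough communication-time proxy in which peer i
# handles (vector_size - partitions[i]) + (n_peers - 1) * partitions[i] elements at its bandwidth,
# and the slowest peer determines the overall cost.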
def get_cost(vector_size, partitions, bandwidths):
    return max(
        (vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(bandwidths[i], 1e-9)
        for i in range(len(partitions))
    )


def check_optimality(vector_size, bandwidths, ref_partitions):
    partitions = list(load_balance_peers(vector_size, bandwidths))
    assert get_cost(vector_size, partitions, bandwidths) <= get_cost(vector_size, ref_partitions, bandwidths)

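# load_balance_peers should never be worse than the hand-picked reference partitions under get_cost,
# honor min_size, give nothing to zero-bandwidth peers, split the load equally among peers with
# bandwidth=None, and always return a non-negative assignment that sums to vector_size.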
@pytest.mark.forked
def test_load_balancing():
    check_optimality(60, np.array([0.25, 0.25, 0.25, 0.25]), [15, 15, 15, 15])
    check_optimality(1024, np.array([0.3, 0.5, 0.9]), [0, 255, 769])
    check_optimality(60, np.array([0.44, 0.33, 0.22]), [42, 18, 0])
    check_optimality(60, np.array([0.55, 0.44, 0.40]), [35, 16, 9])
    check_optimality(1024 * 1024, np.array([0.3, 0.5, 0.9, 0.6]), [0, 169327, 602629, 276620])
    check_optimality(1024 * 1024, np.array([0.0, 0.5, 0.0, 0.6]), [0, 428963, 0, 619613])

    assert load_balance_peers(60, np.array([0.55, 0.44, 0.40]), min_size=10) == (41, 19, 0)
    assert load_balance_peers(60, np.array([0.32, 0.55, 0.44]), min_size=10) == (0, 40, 20)
    assert load_balance_peers(2, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 1)
    assert load_balance_peers(1, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 0)

    assert load_balance_peers(100, (None, None)) == (50, 50)
    assert load_balance_peers(100, (None, None, None, None, None)) == (20, 20, 20, 20, 20)
    assert load_balance_peers(100, (0, 0, 0, None, None)) == (0, 0, 0, 50, 50)

    with pytest.raises(AssertionError):
        load_balance_peers(100, (0, 0, 0))

    for i in range(10):
        vector_size = np.random.randint(1, 1024 ** 3)
        num_peers = np.random.randint(1, 256)
        scale = 1e-9 + np.random.rand() * 1e5
        bandwidths = np.random.rand(num_peers) * scale + 1e-6
        min_size = np.random.choice([0, np.random.randint(0, vector_size // 10)])

        assignment = load_balance_peers(vector_size, bandwidths, min_size)
        assert np.sum(assignment) == vector_size
        assert np.min(assignment) >= 0

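# Every peer starts with a distinct 3-bit group prefix and a short expiration/timeout, so no group
# of size 2 can form and each step() is expected to raise AllreduceException.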
@pytest.mark.forked
def test_too_few_peers():
    dht_instances = launch_dht_instances(4)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            averaging_expiration=1,
            request_timeout=0.5,
            prefix="mygroup",
            initial_group_bits=bin(i)[2:].rjust(3, "0"),
            start=True,
        )
        for i, dht in enumerate(dht_instances)
    ]
    step_futures = [averager.step(wait=False, timeout=2) for averager in averagers]

    for future in step_futures:
        with pytest.raises(AllreduceException):
            future.result()

    for process in averagers + dht_instances:
        process.shutdown()

@pytest.mark.skip(
    reason="The current implementation of elasticity (multi-stage averaging when num_peers > ~3 * target_group_size) "
    "is incorrect (TODO @justheuristic)"
)
@pytest.mark.forked
def test_overcrowded(num_peers=16):
    dht_instances = launch_dht_instances(num_peers)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            averaging_expiration=1,
            request_timeout=0.5,
            prefix="mygroup",
            initial_group_bits="",
            start=True,
        )
        for dht in dht_instances
    ]

    for _ in range(5):
        step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
        assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1

    for process in averagers + dht_instances:
        process.shutdown()

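# load_state_from_peers should call the donor's get_current_state() once per request, return the
# metadata and tensors as of that call, and return None while allow_state_sharing is disabled.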
@pytest.mark.forked
def test_load_state_from_peers():
    num_calls = 0
    super_metadata = dict(x=123)
    super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))

    class TestAverager(hivemind.averaging.DecentralizedAverager):
        def get_current_state(self):
            """
            Get the current state and send it to a peer. Executed in the host process. Meant to be overridden.
            :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
            """
            nonlocal num_calls, super_metadata, super_tensors
            num_calls += 1
            return super_metadata, super_tensors

    dht_instances = launch_dht_instances(2)
    averager1 = TestAverager(
        [torch.randn(3), torch.rand(5)],
        dht=dht_instances[0],
        start=True,
        prefix="demo-run",
        target_group_size=2,
    )

    dht_instances[1].get("demo-run.all_averagers")
    averager2 = TestAverager(
        [torch.randn(3), torch.rand(5)],
        dht=dht_instances[1],
        start=True,
        prefix="demo-run",
        target_group_size=2,
    )

    assert num_calls == 0
    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 1
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))

    super_metadata["y"] = 123
    super_tensors[1][2] = 9
    assert num_calls == 1
    assert got_metadata != super_metadata
    assert not all(map(torch.allclose, got_tensors, super_tensors))

    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 2
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))

    averager1.allow_state_sharing = False
    assert averager2.load_state_from_peers() is None
    averager1.allow_state_sharing = True
    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 3
    assert got_metadata == super_metadata

    for instance in [averager1, averager2] + dht_instances:
        instance.shutdown()

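# Group bits set via set_group_bits() should be readable back via get_group_bits().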
@pytest.mark.forked
def test_getset_bits():
    dht = hivemind.DHT(start=True)
    averager = hivemind.averaging.DecentralizedAverager(
        [torch.randn(3)],
        dht=dht,
        start=True,
        prefix="test_prefix",
        target_group_size=2,
    )
    averager.set_group_bits("00101011101010")
    assert averager.get_group_bits() == "00101011101010"

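# TrainingAverager over two peers: after each local Adam step, one averaging round should make the
# parameters, gradients, and exp_avg_sq statistics equal to the local 50/50 averages.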
@pytest.mark.forked
def test_training_averager(n_steps: int = 10, n_dims: int = 16):
    torch.manual_seed(42)

    dht_instances = launch_dht_instances(2)
    common_kwargs = {
        "start": True,
        "prefix": "demo-run",
        "target_group_size": 2,
    }

    x1 = torch.randn(n_dims, requires_grad=True)
    opt1 = torch.optim.Adam([x1], lr=0.05)
    averager1 = hivemind.averaging.TrainingAverager(
        opt1,
        average_gradients=True,
        average_parameters=True,
        average_opt_statistics=["exp_avg_sq"],
        dht=dht_instances[0],
        **common_kwargs
    )

    x2 = torch.randn(n_dims, requires_grad=True)
    opt2 = torch.optim.Adam([x2], lr=0.05)
    averager2 = hivemind.averaging.TrainingAverager(
        opt2,
        average_gradients=True,
        average_parameters=True,
        average_opt_statistics=["exp_avg_sq"],
        dht=dht_instances[1],
        **common_kwargs
    )

    a = torch.ones(n_dims)

    for i in range(n_steps):
        opt1.zero_grad()
        opt2.zero_grad()
        (x1 - a).pow(2).sum().backward()
        (x2 - a).pow(2).sum().backward()
        opt1.step()
        opt2.step()

        with torch.no_grad():
            x_avg = 0.5 * (x1 + x2)
            grad_avg = 0.5 * (x1.grad + x2.grad)
            stats_avg = 0.5 * (opt1.state[x1]["exp_avg_sq"] + opt2.state[x2]["exp_avg_sq"])

        # we set wait=False to avoid a deadlock where averager1 acquires its lock and waits for averager2
        f1 = averager1.step(wait=False)
        f2 = averager2.step(wait=False)
        f1.result()
        f2.result()

        assert torch.allclose(x1, x_avg)
        assert torch.allclose(x2, x_avg)
        assert torch.allclose(x1.grad, grad_avg)
        assert torch.allclose(x2.grad, grad_avg)
        assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
        assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)

    for instance in [averager1, averager2] + dht_instances:
        instance.shutdown()