import random
import time

import numpy as np
import pytest
import torch

import hivemind
import hivemind.averaging.averager
from hivemind.averaging.allreduce import AveragingMode
from hivemind.averaging.key_manager import GroupKeyManager
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.averaging.partition import AllreduceException
from hivemind.p2p import PeerID
from hivemind.proto.runtime_pb2 import CompressionType

from test_utils.dht_swarms import launch_dht_instances


@pytest.mark.forked
@pytest.mark.asyncio
async def test_key_manager():
    dht = hivemind.DHT(start=True)
    key_manager = GroupKeyManager(
        dht,
        prefix="test_averaging",
        initial_group_bits="10110",
        target_group_size=2,
    )
    alice = dht.peer_id
    bob = PeerID(b"bob")

    t = hivemind.get_dht_time()
    key = key_manager.current_key
    await key_manager.declare_averager(key, alice, expiration_time=t + 60)
    await key_manager.declare_averager(key, bob, expiration_time=t + 61)
    q1 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, alice, expiration_time=t + 66)
    q2 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, bob, expiration_time=t + 61, looking_for_group=False)
    q3 = await key_manager.get_averagers(key, only_active=True)
    q4 = await key_manager.get_averagers(key, only_active=False)
    q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False)

    assert len(q1) == 2 and (alice, t + 60) in q1 and (bob, t + 61) in q1
    assert len(q2) == 2 and (alice, t + 66) in q2 and (bob, t + 61) in q2
    assert len(q3) == 1 and (alice, t + 66) in q3
    assert len(q4) == 2 and (alice, t + 66) in q4 and (bob, t + 61) in q4
    assert len(q5) == 0

    dht.shutdown()


def _test_allreduce_once(n_clients, n_aux):
    n_peers = 4
    modes = (
        [AveragingMode.CLIENT] * n_clients
        + [AveragingMode.AUX] * n_aux
        + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
    )
    random.shuffle(modes)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
    peer_tensors = [tensors1, tensors2, tensors3, tensors4]

    reference = [
        sum(tensors[i] for tensors, mode in zip(peer_tensors, modes) if mode != AveragingMode.AUX)
        / max(1, n_peers - n_aux)
        for i in range(len(tensors1))
    ]
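    # Auxiliary peers contribute no tensors of their own, so the expected result is the plain
    # mean over the non-aux peers; max(1, ...) guards the degenerate case where every peer is aux.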

    dht_instances = launch_dht_instances(len(peer_tensors))
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            tensors,
            dht=dht,
            target_group_size=4,
            averaging_expiration=15,
            prefix="mygroup",
            client_mode=mode == AveragingMode.CLIENT,
            auxiliary=mode == AveragingMode.AUX,
            start=True,
        )
        for tensors, dht, mode in zip(peer_tensors, dht_instances, modes)
    ]

    futures = []
    for averager in averagers:
        futures.append(averager.step(wait=False))
    for future in futures:
        result = future.result()
        for averager in averagers:
            assert averager.peer_id in result

    for averager in averagers:
        if averager.mode != AveragingMode.AUX:
            with averager.get_tensors() as averaged_tensors:
                for ref, our in zip(reference, averaged_tensors):
                    assert torch.allclose(ref, our, atol=1e-6)

    for process in averagers + dht_instances:
        process.shutdown()


@pytest.mark.forked
@pytest.mark.parametrize("n_clients", [0, 1, 2])
@pytest.mark.parametrize("n_aux", [0, 1, 2])
def test_allreduce_once(n_clients, n_aux):
    _test_allreduce_once(n_clients, n_aux)


@pytest.mark.forked
@pytest.mark.parametrize("n_clients, n_aux", [(0, 4), (1, 3), (0, 3)])
def test_allreduce_once_edge_cases(n_clients, n_aux):
    _test_allreduce_once(n_clients, n_aux)


@pytest.mark.forked
def test_allreduce_weighted(n_client_mode_peers: int = 2):
    n_peers = 4
    client_modes = [True] * n_client_mode_peers + [False] * (n_peers - n_client_mode_peers)
    random.shuffle(client_modes)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]

    dht_instances = launch_dht_instances(4)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            tensors,
            dht=dht,
            target_group_size=4,
            averaging_expiration=15,
            prefix="mygroup",
            client_mode=client_mode,
            start=True,
        )
        for tensors, dht, client_mode in zip([tensors1, tensors2, tensors3, tensors4], dht_instances, client_modes)
    ]

    weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
    reference = [
        (tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2] + tensors4[i] * weights[3])
        / sum(weights)
        for i in range(len(tensors1))
    ]
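    # The expected result of a weighted all-reduce is the weighted mean
    # sum_i(weights[i] * tensors_i) / sum(weights), which is what the reference above computes.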

    futures = []
    for averager, weight in zip(averagers, weights):
        futures.append(averager.step(weight=weight, wait=False))
    for future in futures:
        future.result()

    for future, averager in zip(futures, averagers):
        with averager.get_tensors() as averaged_tensors:
            for ref, our in zip(reference, averaged_tensors):
                assert torch.allclose(ref, our, atol=1e-6)

    for process in averagers + dht_instances:
        process.shutdown()


@pytest.mark.forked
def test_allreduce_compression():
    """this test ensures that compression works correctly when multiple tensors have different compression types"""
    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
    results = {}

    FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
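
    # Each pair assigns one compression type per tensor (first entry for tensors[0], second for
    # tensors[1]), so the cross-checks below compare only tensors that were compressed the same way.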
    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
        dht_instances = launch_dht_instances(2)
        averager1 = hivemind.averaging.DecentralizedAverager(
            [x.clone() for x in tensors1],
            dht=dht_instances[0],
            compression_type=compression_type_pair,
            client_mode=True,
            target_group_size=2,
            prefix="mygroup",
            start=True,
        )
        averager2 = hivemind.averaging.DecentralizedAverager(
            [x.clone() for x in tensors2],
            dht=dht_instances[1],
            compression_type=compression_type_pair,
            target_group_size=2,
            prefix="mygroup",
            start=True,
        )

        for future in averager1.step(wait=False), averager2.step(wait=False):
            future.result()

        with averager1.get_tensors() as averaged_tensors:
            results[compression_type_pair] = averaged_tensors

        for instance in [averager1, averager2] + dht_instances:
            instance.shutdown()

    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])

    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])

    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
    for i in range(2):
        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2


def compute_mean_std(averagers, unbiased=True):
    results = []
    for averager in averagers:
        with averager.get_tensors() as tensors:
            results.append([tensor.clone() for tensor in tensors])

    results_stacked_per_tensor = list(map(torch.stack, zip(*results)))
    means = [stack.mean(dim=0) for stack in results_stacked_per_tensor]
    stds = [stack.std(dim=0, unbiased=unbiased) for stack in results_stacked_per_tensor]
    return means, stds
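

# compute_mean_std stacks each tensor position across averagers and returns per-coordinate means
# and standard deviations; the grid test below uses it to verify that averaging preserves the
# global mean while shrinking the spread between peers.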


@pytest.mark.forked
def test_allreduce_grid():
    dht_instances = launch_dht_instances(8)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            prefix="mygroup",
            initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
            start=True,
        )
        for i, dht in enumerate(dht_instances)
    ]

    [means0], [stds0] = compute_mean_std(averagers)
    assert not torch.allclose(stds0, torch.zeros_like(stds0))

    prev_means, prev_stds = means0, stds0
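
    # Each round forms groups of two peers; the global mean must stay (numerically) unchanged,
    # while repeated pairwise averaging across different group keys should drive the spread
    # between peers towards zero within a few rounds.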
    for i in range(5):
        step_futures = [averager.step(wait=False) for averager in averagers]
        groups = [future.result() for future in step_futures]
        [means], [stds] = compute_mean_std(averagers)
        assert torch.allclose(means, prev_means, atol=1e-6, rtol=0)
        assert all(len(group) == 2 for group in groups)

        if i <= 2:
            assert torch.all(torch.le(stds, prev_stds))
        else:
            assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)

    for process in averagers + dht_instances:
        process.shutdown()


@pytest.mark.forked
def test_allgather(n_averagers=8, target_group_size=4):
    dht_instances = launch_dht_instances(n_averagers)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            [torch.ones(1)],
            dht=dht,
            target_group_size=target_group_size,
            averaging_expiration=15,
            prefix="mygroup",
            initial_group_bits="000",
            start=True,
        )
        for dht in dht_instances
    ]

    futures = []
    for i, averager in enumerate(averagers):
        futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))

    reference_metadata = {
        averager.peer_id: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
    }
    for future in futures:
        gathered = future.result()
        assert len(gathered) == target_group_size

        for peer_id in gathered:
            assert gathered[peer_id] == reference_metadata[peer_id]

    for process in averagers + dht_instances:
        process.shutdown()


def get_cost(vector_size, partitions, bandwidths):
    return max(
        (vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(bandwidths[i], 1e-9)
        for i in range(len(partitions))
    )
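

# The cost above roughly models the busiest peer in an all-reduce: peer i sends its copies of the
# other peers' partitions (vector_size - partitions[i] elements) and receives its own partition
# from each of the other len(partitions) - 1 peers, all divided by its bandwidth. For example,
# with vector_size=60 and four equal bandwidths of 0.25, the balanced split [15, 15, 15, 15]
# gives each peer (60 - 15) + 3 * 15 = 90 units of traffic, i.e. a cost of 90 / 0.25 = 360.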


def check_optimality(vector_size, bandwidths, ref_partitions):
    partitions = list(load_balance_peers(vector_size, bandwidths))
    assert get_cost(vector_size, partitions, bandwidths) <= get_cost(vector_size, ref_partitions, bandwidths)


@pytest.mark.forked
def test_load_balancing():
    check_optimality(60, np.array([0.25, 0.25, 0.25, 0.25]), [15, 15, 15, 15])
    check_optimality(1024, np.array([0.3, 0.5, 0.9]), [0, 255, 769])
    check_optimality(60, np.array([0.44, 0.33, 0.22]), [42, 18, 0])
    check_optimality(60, np.array([0.55, 0.44, 0.40]), [35, 16, 9])
    check_optimality(1024 * 1024, np.array([0.3, 0.5, 0.9, 0.6]), [0, 169327, 602629, 276620])
    check_optimality(1024 * 1024, np.array([0.0, 0.5, 0.0, 0.6]), [0, 428963, 0, 619613])
    assert load_balance_peers(60, np.array([0.55, 0.44, 0.40]), min_size=10) == (41, 19, 0)
    assert load_balance_peers(60, np.array([0.32, 0.55, 0.44]), min_size=10) == (0, 40, 20)
    assert load_balance_peers(2, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 1)
    assert load_balance_peers(1, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 0)

    assert load_balance_peers(100, (None, None)) == (50, 50)
    assert load_balance_peers(100, (None, None, None, None, None)) == (20, 20, 20, 20, 20)
    assert load_balance_peers(100, (0, 0, 0, None, None)) == (0, 0, 0, 50, 50)

    with pytest.raises(AssertionError):
        load_balance_peers(100, (0, 0, 0))

    for i in range(10):
        vector_size = np.random.randint(1, 1024 ** 3)
        num_peers = np.random.randint(1, 256)
        scale = 1e-9 + np.random.rand() * 1e5
        bandwidths = np.random.rand(num_peers) * scale + 1e-6
        min_size = np.random.choice([0, np.random.randint(0, vector_size // 10)])

        assignment = load_balance_peers(vector_size, bandwidths, min_size)
        assert np.sum(assignment) == vector_size
        assert np.min(assignment) >= 0


@pytest.mark.forked
def test_too_few_peers():
    dht_instances = launch_dht_instances(4)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            averaging_expiration=1,
            request_timeout=0.5,
            prefix="mygroup",
            initial_group_bits=bin(i)[2:].rjust(3, "0"),
            start=True,
        )
        for i, dht in enumerate(dht_instances)
    ]
    step_futures = [averager.step(wait=False, timeout=2) for averager in averagers]

    for future in step_futures:
        with pytest.raises(AllreduceException):
            future.result()

    for process in averagers + dht_instances:
        process.shutdown()


@pytest.mark.skip(
    reason="The current implementation of elasticity (multi-stage averaging when num_peers > ~3 * target_group_size) "
    "is incorrect (TODO @justheuristic)"
)
@pytest.mark.forked
def test_overcrowded(num_peers=16):
    dht_instances = launch_dht_instances(num_peers)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            averaging_expiration=1,
            request_timeout=0.5,
            prefix="mygroup",
            initial_group_bits="",
            start=True,
        )
        for dht in dht_instances
    ]
    for _ in range(5):
        step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
        assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1

    for process in averagers + dht_instances:
        process.shutdown()


@pytest.mark.forked
def test_load_state_from_peers():
    num_calls = 0
    super_metadata = dict(x=123)
    super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))

    class TestAverager(hivemind.averaging.DecentralizedAverager):
        def get_current_state(self):
            """
            Get current state and send it to a peer. Executed in the host process. Meant to be overridden.
            :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
            """
            nonlocal num_calls, super_metadata, super_tensors
            num_calls += 1
            return super_metadata, super_tensors

    dht_instances = launch_dht_instances(2)
    averager1 = TestAverager(
        [torch.randn(3), torch.rand(5)],
        dht=dht_instances[0],
        start=True,
        prefix="demo-run",
        target_group_size=2,
    )

    dht_instances[1].get("demo-run.all_averagers")
    averager2 = TestAverager(
        [torch.randn(3), torch.rand(5)],
        dht=dht_instances[1],
        start=True,
        prefix="demo-run",
        target_group_size=2,
    )

    assert num_calls == 0
    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 1
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))
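
    # Mutate the state on the serving side: the copies loaded earlier must stay unchanged,
    # and only a fresh load_state_from_peers() call should observe the updates.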
    super_metadata["y"] = 123
    super_tensors[1][2] = 9
    assert num_calls == 1
    assert got_metadata != super_metadata
    assert not all(map(torch.allclose, got_tensors, super_tensors))

    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 2
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))

    averager1.allow_state_sharing = False
    assert averager2.load_state_from_peers() is None
    averager1.allow_state_sharing = True

    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 3
    assert got_metadata == super_metadata

    for instance in [averager1, averager2] + dht_instances:
        instance.shutdown()


@pytest.mark.forked
def test_getset_bits():
    dht = hivemind.DHT(start=True)
    averager = hivemind.averaging.DecentralizedAverager(
        [torch.randn(3)],
        dht=dht,
        start=True,
        prefix="test_prefix",
        target_group_size=2,
    )
    averager.set_group_bits("00101011101010")
    assert averager.get_group_bits() == "00101011101010"


@pytest.mark.forked
def test_training_averager(n_steps: int = 10, n_dims: int = 16):
    torch.manual_seed(42)

    dht_instances = launch_dht_instances(2)
    common_kwargs = {
        "start": True,
        "prefix": "demo-run",
        "target_group_size": 2,
    }

    x1 = torch.randn(n_dims, requires_grad=True)
    opt1 = torch.optim.Adam([x1], lr=0.05)
    averager1 = hivemind.averaging.TrainingAverager(
        opt1,
        average_gradients=True,
        average_parameters=True,
        average_opt_statistics=["exp_avg_sq"],
        dht=dht_instances[0],
        **common_kwargs,
    )

    x2 = torch.randn(n_dims, requires_grad=True)
    opt2 = torch.optim.Adam([x2], lr=0.05)
    averager2 = hivemind.averaging.TrainingAverager(
        opt2,
        average_gradients=True,
        average_parameters=True,
        average_opt_statistics=["exp_avg_sq"],
        dht=dht_instances[1],
        **common_kwargs,
    )

    a = torch.ones(n_dims)
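
    # Each peer takes one local Adam step towards a = 1, then both average their parameters,
    # gradients, and Adam's exp_avg_sq statistics; afterwards both replicas should hold the
    # arithmetic means computed below.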
    for i in range(n_steps):
        opt1.zero_grad()
        opt2.zero_grad()
        (x1 - a).pow(2).sum().backward()
        (x2 - a).pow(2).sum().backward()
        opt1.step()
        opt2.step()

        with torch.no_grad():
            x_avg = 0.5 * (x1 + x2)
            grad_avg = 0.5 * (x1.grad + x2.grad)
            stats_avg = 0.5 * (opt1.state[x1]["exp_avg_sq"] + opt2.state[x2]["exp_avg_sq"])

        # use wait=False to prevent a deadlock where averager1 acquires its lock and waits for averager2
        f1 = averager1.step(wait=False)
        f2 = averager2.step(wait=False)
        f1.result()
        f2.result()

        assert torch.allclose(x1, x_avg)
        assert torch.allclose(x2, x_avg)
        assert torch.allclose(x1.grad, grad_avg)
        assert torch.allclose(x2.grad, grad_avg)
        assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
        assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)

    for instance in [averager1, averager2] + dht_instances:
        instance.shutdown()