test_averaging.py

import random

import numpy as np
import pytest
import torch

import hivemind
import hivemind.averaging.averager
from hivemind.averaging.allreduce import AveragingMode
from hivemind.averaging.key_manager import GroupKeyManager
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.proto.runtime_pb2 import CompressionType
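
# Sanity check for GroupKeyManager: declaring averagers under a group key and querying them back,
# with and without the only_active filter.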
@pytest.mark.forked
@pytest.mark.asyncio
async def test_key_manager():
    key_manager = GroupKeyManager(
        hivemind.DHT(start=True),
        endpoint="localhvost",
        prefix="test_averaging",
        initial_group_bits="10110",
        target_group_size=2,
    )

    t = hivemind.get_dht_time()
    key = key_manager.current_key
    await key_manager.declare_averager(key, "localhvost", expiration_time=t + 60)
    await key_manager.declare_averager(key, "localhvost2", expiration_time=t + 61)

    q1 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, "localhvost", expiration_time=t + 66)
    q2 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, "localhvost2", expiration_time=t + 61, looking_for_group=False)
    q3 = await key_manager.get_averagers(key, only_active=True)
    q4 = await key_manager.get_averagers(key, only_active=False)

    q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False)

    assert len(q1) == 2 and ("localhvost", t + 60) in q1 and ("localhvost2", t + 61) in q1
    assert len(q2) == 2 and ("localhvost", t + 66) in q2 and ("localhvost2", t + 61) in q2
    assert len(q3) == 1 and ("localhvost", t + 66) in q3
    assert len(q4) == 2 and ("localhvost", t + 66) in q4 and ("localhvost2", t + 61) in q4
    assert len(q5) == 0
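
# Shared body for the all-reduce tests below: one averaging step over 4 peers with a mix of regular,
# client-mode, and auxiliary peers; the result should match a locally computed mean in which
# auxiliary peers contribute no tensors of their own.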
def _test_allreduce_once(n_clients, n_aux):
    dht = hivemind.DHT(start=True)

    n_peers = 4
    modes = (
        [AveragingMode.CLIENT] * n_clients
        + [AveragingMode.AUX] * n_aux
        + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
    )
    random.shuffle(modes)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
    peer_tensors = [tensors1, tensors2, tensors3, tensors4]

    reference = [
        sum(tensors[i] for tensors, mode in zip(peer_tensors, modes) if mode != AveragingMode.AUX)
        / max(1, n_peers - n_aux)
        for i in range(len(tensors1))
    ]

    averagers = [
        hivemind.averaging.DecentralizedAverager(
            tensors,
            dht=dht,
            target_group_size=4,
            averaging_expiration=15,
            prefix="mygroup",
            client_mode=mode == AveragingMode.CLIENT,
            listen_on="127.0.0.1:*",
            auxiliary=mode == AveragingMode.AUX,
            start=True,
        )
        for tensors, mode in zip(peer_tensors, modes)
    ]

    futures = []
    for averager in averagers:
        futures.append(averager.step(wait=False))
    for future in futures:
        result = future.result()
        for averager in averagers:
            assert averager.endpoint in result

    for averager in averagers:
        if averager.mode != AveragingMode.AUX:
            with averager.get_tensors() as averaged_tensors:
                for ref, our in zip(reference, averaged_tensors):
                    assert torch.allclose(ref, our, atol=1e-6)

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()
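
# Parametrized entry points: typical client/aux mixes, plus edge cases where few or no peers hold data.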
@pytest.mark.forked
@pytest.mark.parametrize("n_clients", [0, 1, 2])
@pytest.mark.parametrize("n_aux", [0, 1, 2])
def test_allreduce_once(n_clients, n_aux):
    _test_allreduce_once(n_clients, n_aux)


@pytest.mark.forked
@pytest.mark.parametrize("n_clients, n_aux", [(0, 4), (1, 3), (0, 3)])
def test_allreduce_once_edge_cases(n_clients, n_aux):
    _test_allreduce_once(n_clients, n_aux)
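
# Weighted averaging: each peer passes its own weight to step(), and the averaged tensors should
# match the weighted mean of the original tensors.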
@pytest.mark.forked
def test_allreduce_weighted(n_client_mode_peers: int = 2):
    dht = hivemind.DHT(start=True)
    n_peers = 4
    client_modes = [True] * n_client_mode_peers + [False] * (n_peers - n_client_mode_peers)
    random.shuffle(client_modes)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            tensors,
            dht=dht,
            target_group_size=4,
            averaging_expiration=15,
            prefix="mygroup",
            client_mode=client_mode,
            listen_on="127.0.0.1:*",
            start=True,
        )
        for tensors, client_mode in zip([tensors1, tensors2, tensors3, tensors4], client_modes)
    ]

    weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
    reference = [
        (tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2] + tensors4[i] * weights[3])
        / sum(weights)
        for i in range(len(tensors1))
    ]

    futures = []
    for averager, weight in zip(averagers, weights):
        futures.append(averager.step(weight=weight, wait=False))
    for future in futures:
        future.result()

    for future, averager in zip(futures, averagers):
        with averager.get_tensors() as averaged_tensors:
            for ref, our in zip(reference, averaged_tensors):
                assert torch.allclose(ref, our, atol=1e-6)

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()
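
# The compression_type pair assigns one compression scheme per tensor. The assertions below check
# that each averaged tensor depends only on the compression used for that tensor (not on the other
# tensor's compression), and the final checks bound the quantization error against an exact reference.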
@pytest.mark.forked
def test_allreduce_compression():
    """this test ensures that compression works correctly when multiple tensors have different compression types"""
    dht = hivemind.DHT(start=True)

    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
    results = {}

    FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT

    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
        averager1 = hivemind.averaging.DecentralizedAverager(
            [x.clone() for x in tensors1],
            dht=dht,
            compression_type=compression_type_pair,
            client_mode=True,
            target_group_size=2,
            prefix="mygroup",
            start=True,
        )
        averager2 = hivemind.averaging.DecentralizedAverager(
            [x.clone() for x in tensors2],
            dht=dht,
            compression_type=compression_type_pair,
            target_group_size=2,
            prefix="mygroup",
            listen_on="127.0.0.1:*",
            start=True,
        )

        for future in averager1.step(wait=False), averager2.step(wait=False):
            future.result()

        with averager1.get_tensors() as averaged_tensors:
            results[compression_type_pair] = averaged_tensors

    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])

    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])

    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
    for i in range(2):
        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
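
# Helper for the grid test: snapshots every averager's tensors and returns the per-tensor mean and
# standard deviation across peers.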
def compute_mean_std(averagers, unbiased=True):
    results = []
    for averager in averagers:
        with averager.get_tensors() as tensors:
            results.append([tensor.clone() for tensor in tensors])

    results_stacked_per_tensor = list(map(torch.stack, zip(*results)))
    means = [stack.mean(dim=0) for stack in results_stacked_per_tensor]
    stds = [stack.std(dim=0, unbiased=unbiased) for stack in results_stacked_per_tensor]
    return means, stds
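
# 8 peers placed on a 2-bit group grid with groups of 2: repeated averaging rounds should keep the
# global mean constant while the spread across peers shrinks to (near) zero.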
@pytest.mark.forked
def test_allreduce_grid():
    dht = hivemind.DHT(start=True)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            prefix="mygroup",
            initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
            listen_on="127.0.0.1:*",
            start=True,
        )
        for i in range(8)
    ]

    [means0], [stds0] = compute_mean_std(averagers)
    assert not torch.allclose(stds0, torch.zeros_like(stds0))

    prev_means, prev_stds = means0, stds0

    for i in range(5):
        step_futures = [averager.step(wait=False) for averager in averagers]
        groups = [future.result() for future in step_futures]
        [means], [stds] = compute_mean_std(averagers)
        assert torch.allclose(means, prev_means, atol=1e-6, rtol=0)
        assert all(len(group) == 2 for group in groups)

        if i <= 2:
            assert torch.all(torch.le(stds, prev_stds))
        else:
            assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()
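
# The gather= argument lets each peer attach metadata to its averaging step; every group member
# should receive the metadata of all 4 peers in its group.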
@pytest.mark.forked
def test_allgather():
    dht = hivemind.DHT(start=True)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            [torch.ones(1)],
            dht=dht,
            target_group_size=4,
            averaging_expiration=15,
            prefix="mygroup",
            initial_group_bits="000",
            listen_on="127.0.0.1:*",
            start=True,
        )
        for _ in range(8)
    ]

    futures = []
    for i, averager in enumerate(averagers):
        futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))

    assert len(set(repr(sorted(future.result())) for future in futures)) == 2

    reference_metadata = {
        averager.endpoint: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
    }
    for future in futures:
        gathered = future.result()
        assert len(gathered) == 4

        for endpoint in gathered:
            assert gathered[endpoint] == reference_metadata[endpoint]

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()
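
# Rough communication-cost model for the load balancer: for each peer, the data it sends out
# (everyone else's partitions) plus the data it receives for its own partition, divided by its
# bandwidth; the overall cost is the slowest peer's cost.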
def get_cost(vector_size, partitions, bandwidths):
    return max(
        (vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(bandwidths[i], 1e-9)
        for i in range(len(partitions))
    )


def check_optimality(vector_size, bandwidths, ref_partitions):
    partitions = list(load_balance_peers(vector_size, bandwidths))
    assert get_cost(vector_size, partitions, bandwidths) <= get_cost(vector_size, ref_partitions, bandwidths)
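
# load_balance_peers should do no worse than the hand-picked reference partitions, respect min_size
# by assigning zero to peers that would get too small a share, and always split the full vector into
# non-negative chunks.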
@pytest.mark.forked
def test_load_balancing():
    check_optimality(60, np.array([0.25, 0.25, 0.25, 0.25]), [15, 15, 15, 15])
    check_optimality(1024, np.array([0.3, 0.5, 0.9]), [0, 255, 769])
    check_optimality(60, np.array([0.44, 0.33, 0.22]), [42, 18, 0])
    check_optimality(60, np.array([0.55, 0.44, 0.40]), [35, 16, 9])
    check_optimality(1024 * 1024, np.array([0.3, 0.5, 0.9, 0.6]), [0, 169327, 602629, 276620])
    check_optimality(1024 * 1024, np.array([0.0, 0.5, 0.0, 0.6]), [0, 428963, 0, 619613])
    assert load_balance_peers(60, np.array([0.55, 0.44, 0.40]), min_size=10) == (41, 19, 0)
    assert load_balance_peers(60, np.array([0.32, 0.55, 0.44]), min_size=10) == (0, 40, 20)
    assert load_balance_peers(2, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 1)
    assert load_balance_peers(1, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 0)

    assert load_balance_peers(100, (None, None)) == (50, 50)
    assert load_balance_peers(100, (None, None, None, None, None)) == (20, 20, 20, 20, 20)
    assert load_balance_peers(100, (0, 0, 0, None, None)) == (0, 0, 0, 50, 50)

    with pytest.raises(AssertionError):
        load_balance_peers(100, (0, 0, 0))

    for i in range(10):
        vector_size = np.random.randint(1, 1024 ** 3)
        num_peers = np.random.randint(1, 256)
        scale = 1e-9 + np.random.rand() * 1e5
        bandwidths = np.random.rand(num_peers) * scale + 1e-6
        min_size = np.random.choice([0, np.random.randint(0, vector_size // 10)])

        assignment = load_balance_peers(vector_size, bandwidths, min_size)
        assert np.sum(assignment) == vector_size
        assert np.min(assignment) >= 0
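
# Each of the 4 peers starts in its own 3-bit group bucket, so no full group can form right away;
# matchmaking should still end up pairing everyone into groups of 2.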
@pytest.mark.forked
def test_too_few_peers():
    dht = hivemind.DHT(start=True)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            averaging_expiration=1,
            request_timeout=0.5,
            prefix="mygroup",
            initial_group_bits=bin(i)[2:].rjust(3, "0"),
            listen_on="127.0.0.1:*",
            start=True,
        )
        for i in range(4)
    ]
    step_futures = [averager.step(wait=False) for averager in averagers]
    for future in step_futures:
        assert len(future.result()) == 2

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()
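
# 16 peers compete for groups of 2 under the same empty prefix with short timeouts; in every round,
# at most one peer may be left without a partner.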
@pytest.mark.forked
def test_overcrowded(num_peers=16):
    dht = hivemind.DHT(start=True)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            averaged_tensors=[torch.randn(3)],
            dht=dht,
            target_group_size=2,
            averaging_expiration=1,
            request_timeout=0.5,
            prefix="mygroup",
            initial_group_bits="",
            listen_on="127.0.0.1:*",
            start=True,
        )
        for _ in range(num_peers)
    ]
    for t in range(5):
        step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
        assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()
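
# load_state_from_peers should call the donor's get_current_state exactly once per request, pick up
# subsequent changes on the next call, and return None while state sharing is disabled.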
@pytest.mark.forked
def test_load_state_from_peers():
    num_calls = 0
    super_metadata = dict(x=123)
    super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))

    class TestAverager(hivemind.averaging.DecentralizedAverager):
        def get_current_state(self):
            """
            Get current state and send it to a peer. Executed in the host process. Meant to be overridden.
            :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
            """
            nonlocal num_calls, super_metadata, super_tensors
            num_calls += 1
            return super_metadata, super_tensors

    dht_root = hivemind.DHT(start=True)
    initial_peers = dht_root.get_visible_maddrs()
    dht1 = hivemind.DHT(initial_peers=initial_peers, start=True)
    averager1 = TestAverager(
        [torch.randn(3), torch.rand(5)],
        dht=dht1,
        start=True,
        prefix="demo-run",
        target_group_size=2,
        listen_on="127.0.0.1:*",
    )

    dht2 = hivemind.DHT(initial_peers=initial_peers, start=True)
    dht2.get("demo-run.all_averagers")
    averager2 = TestAverager(
        [torch.randn(3), torch.rand(5)],
        dht=dht2,
        start=True,
        prefix="demo-run",
        target_group_size=2,
        listen_on="127.0.0.1:*",
    )

    assert num_calls == 0
    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 1
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))

    super_metadata["y"] = 123
    super_tensors[1][2] = 9
    assert num_calls == 1
    assert got_metadata != super_metadata
    assert not all(map(torch.allclose, got_tensors, super_tensors))

    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 2
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))

    averager1.allow_state_sharing = False
    assert averager2.load_state_from_peers() is None
    averager1.allow_state_sharing = True
    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 3
    assert got_metadata == super_metadata
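
# Group bits set through the averager handle should be readable back unchanged.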
@pytest.mark.forked
def test_getset_bits():
    dht = hivemind.DHT(start=True)
    averager = hivemind.averaging.DecentralizedAverager(
        [torch.randn(3)], dht=dht, start=True, prefix="test_prefix", target_group_size=2, listen_on="127.0.0.1:*"
    )
    averager.set_group_bits("00101011101010")
    assert averager.get_group_bits() == "00101011101010"
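
# Two TrainingAverager instances optimizing separate parameters towards the same target: after each
# averaging step, parameters, gradients, and the tracked Adam statistic (exp_avg_sq) should equal the
# pairwise means computed locally.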
@pytest.mark.forked
def test_training_averager(n_steps: int = 10, n_dims: int = 16):
    torch.manual_seed(42)

    dht = hivemind.DHT(start=True)
    common_kwargs = {
        "dht": dht,
        "start": True,
        "listen_on": "127.0.0.1:*",
        "prefix": "demo-run",
        "target_group_size": 2,
    }

    x1 = torch.randn(n_dims, requires_grad=True)
    opt1 = torch.optim.Adam([x1], lr=0.05)
    averager1 = hivemind.averaging.TrainingAverager(
        opt1, average_gradients=True, average_parameters=True, average_opt_statistics=["exp_avg_sq"], **common_kwargs
    )

    x2 = torch.randn(n_dims, requires_grad=True)
    opt2 = torch.optim.Adam([x2], lr=0.05)
    averager2 = hivemind.averaging.TrainingAverager(
        opt2, average_gradients=True, average_parameters=True, average_opt_statistics=["exp_avg_sq"], **common_kwargs
    )
    a = torch.ones(n_dims)

    for i in range(n_steps):
        opt1.zero_grad()
        opt2.zero_grad()
        (x1 - a).pow(2).sum().backward()
        (x2 - a).pow(2).sum().backward()
        opt1.step()
        opt2.step()

        with torch.no_grad():
            x_avg = 0.5 * (x1 + x2)
            grad_avg = 0.5 * (x1.grad + x2.grad)
            stats_avg = 0.5 * (opt1.state[x1]["exp_avg_sq"] + opt2.state[x2]["exp_avg_sq"])

        # we use wait=False to avoid a deadlock where averager1 holds its lock while waiting for averager2
        f1 = averager1.step(wait=False)
        f2 = averager2.step(wait=False)
        f1.result()
        f2.result()

        assert torch.allclose(x1, x_avg)
        assert torch.allclose(x2, x_avg)
        assert torch.allclose(x1.grad, grad_avg)
        assert torch.allclose(x2.grad, grad_avg)
        assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
        assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)

    averager1.shutdown()
    averager2.shutdown()
    dht.shutdown()