test_averaging.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. import random
  2. import numpy as np
  3. import pytest
  4. import torch
  5. import hivemind
  6. import hivemind.averaging.averager
  7. from hivemind.averaging.allreduce import AveragingMode
  8. from hivemind.averaging.key_manager import GroupKeyManager
  9. from hivemind.averaging.load_balancing import load_balance_peers
  10. from hivemind.p2p import PeerID
  11. from hivemind.proto.runtime_pb2 import CompressionType
  12. @pytest.mark.forked
  13. @pytest.mark.asyncio
  14. async def test_key_manager():
  15. localhvost = PeerID(b"localhvost")
  16. localhvost2 = PeerID(b"localhvost2")
  17. key_manager = GroupKeyManager(
  18. hivemind.DHT(start=True),
  19. endpoint=localhvost,
  20. prefix="test_averaging",
  21. initial_group_bits="10110",
  22. target_group_size=2,
  23. )
  24. t = hivemind.get_dht_time()
  25. key = key_manager.current_key
  26. await key_manager.declare_averager(key, localhvost, expiration_time=t + 60)
  27. await key_manager.declare_averager(key, localhvost2, expiration_time=t + 61)
  28. q1 = await key_manager.get_averagers(key, only_active=True)
  29. await key_manager.declare_averager(key, localhvost, expiration_time=t + 66)
  30. q2 = await key_manager.get_averagers(key, only_active=True)
  31. await key_manager.declare_averager(key, localhvost2, expiration_time=t + 61, looking_for_group=False)
  32. q3 = await key_manager.get_averagers(key, only_active=True)
  33. q4 = await key_manager.get_averagers(key, only_active=False)
  34. q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False)
  35. assert len(q1) == 2 and (localhvost, t + 60) in q1 and (localhvost2, t + 61) in q1
  36. assert len(q2) == 2 and (localhvost, t + 66) in q2 and (localhvost2, t + 61) in q2
  37. assert len(q3) == 1 and (localhvost, t + 66) in q3
  38. assert len(q4) == 2 and (localhvost, t + 66) in q4 and (localhvost2, t + 61) in q2
  39. assert len(q5) == 0
  40. def _test_allreduce_once(n_clients, n_aux):
  41. n_peers = 4
  42. modes = (
  43. [AveragingMode.CLIENT] * n_clients
  44. + [AveragingMode.AUX] * n_aux
  45. + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
  46. )
  47. random.shuffle(modes)
  48. tensors1 = [torch.randn(123), torch.zeros(3)]
  49. tensors2 = [torch.rand(123), torch.ones(3)]
  50. tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
  51. tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
  52. peer_tensors = [tensors1, tensors2, tensors3, tensors4]
  53. reference = [
  54. sum(tensors[i] for tensors, mode in zip(peer_tensors, modes) if mode != AveragingMode.AUX)
  55. / max(1, n_peers - n_aux)
  56. for i in range(len(tensors1))
  57. ]
  58. dht_root = hivemind.DHT(start=True)
  59. initial_peers = dht_root.get_visible_maddrs()
  60. averagers = []
  61. dhts = []
  62. for tensors, mode in zip(peer_tensors, modes):
  63. dht_instance = hivemind.DHT(start=True, initial_peers=initial_peers)
  64. dhts.append(dht_instance)
  65. averagers.append(hivemind.averaging.DecentralizedAverager(
  66. tensors,
  67. dht=dht_instance,
  68. target_group_size=4,
  69. averaging_expiration=15,
  70. prefix="mygroup",
  71. client_mode=mode == AveragingMode.CLIENT,
  72. auxiliary=mode == AveragingMode.AUX,
  73. start=True,
  74. ))
  75. futures = []
  76. for averager in averagers:
  77. futures.append(averager.step(wait=False))
  78. for future in futures:
  79. result = future.result()
  80. for averager in averagers:
  81. assert averager.endpoint in result
  82. for averager in averagers:
  83. if averager.mode != AveragingMode.AUX:
  84. with averager.get_tensors() as averaged_tensors:
  85. for ref, our in zip(reference, averaged_tensors):
  86. assert torch.allclose(ref, our, atol=1e-6)
  87. for averager in averagers:
  88. averager.shutdown()
  89. for instance in dhts:
  90. instance.shutdown()
  91. dht_root.shutdown()
  92. @pytest.mark.forked
  93. @pytest.mark.parametrize("n_clients", [0, 1, 2])
  94. @pytest.mark.parametrize("n_aux", [0, 1, 2])
  95. def test_allreduce_once(n_clients, n_aux):
  96. _test_allreduce_once(n_clients, n_aux)
  97. @pytest.mark.forked
  98. @pytest.mark.parametrize("n_clients, n_aux", [(0, 4), (1, 3), (0, 3)])
  99. def test_allreduce_once_edge_cases(n_clients, n_aux):
  100. _test_allreduce_once(n_clients, n_aux)
  101. @pytest.mark.forked
  102. def test_allreduce_weighted(n_client_mode_peers: int = 2):
  103. dht = hivemind.DHT(start=True)
  104. n_peers = 4
  105. client_modes = [True] * n_client_mode_peers + [False] * (n_peers - n_client_mode_peers)
  106. random.shuffle(client_modes)
  107. tensors1 = [torch.randn(123), torch.zeros(3)]
  108. tensors2 = [torch.rand(123), torch.ones(3)]
  109. tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
  110. tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
  111. averagers = [
  112. hivemind.averaging.DecentralizedAverager(
  113. tensors,
  114. dht=dht,
  115. target_group_size=4,
  116. averaging_expiration=15,
  117. prefix="mygroup",
  118. client_mode=client_mode,
  119. start=True,
  120. )
  121. for tensors, client_mode in zip([tensors1, tensors2, tensors3, tensors4], client_modes)
  122. ]
  123. weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
  124. reference = [
  125. (tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2] + tensors4[i] * weights[3])
  126. / sum(weights)
  127. for i in range(len(tensors1))
  128. ]
  129. futures = []
  130. for averager, weight in zip(averagers, weights):
  131. futures.append(averager.step(weight=weight, wait=False))
  132. for future in futures:
  133. future.result()
  134. for future, averager in zip(futures, averagers):
  135. with averager.get_tensors() as averaged_tensors:
  136. for ref, our in zip(reference, averaged_tensors):
  137. assert torch.allclose(ref, our, atol=1e-6)
  138. for averager in averagers:
  139. averager.shutdown()
  140. dht.shutdown()
  141. @pytest.mark.forked
  142. def test_allreduce_compression():
  143. """this test ensures that compression works correctly when multiple tensors have different compression types"""
  144. dht = hivemind.DHT(start=True)
  145. tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
  146. tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
  147. results = {}
  148. FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
  149. for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
  150. averager1 = hivemind.averaging.DecentralizedAverager(
  151. [x.clone() for x in tensors1],
  152. dht=dht,
  153. compression_type=compression_type_pair,
  154. client_mode=True,
  155. target_group_size=2,
  156. prefix="mygroup",
  157. start=True,
  158. )
  159. averager2 = hivemind.averaging.DecentralizedAverager(
  160. [x.clone() for x in tensors2],
  161. dht=dht,
  162. compression_type=compression_type_pair,
  163. target_group_size=2,
  164. prefix="mygroup",
  165. start=True,
  166. )
  167. for future in averager1.step(wait=False), averager2.step(wait=False):
  168. future.result()
  169. with averager1.get_tensors() as averaged_tensors:
  170. results[compression_type_pair] = averaged_tensors
  171. assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
  172. assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
  173. assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
  174. assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
  175. assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
  176. assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
  177. assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
  178. assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])
  179. reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
  180. for i in range(2):
  181. assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
  182. assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
  183. def compute_mean_std(averagers, unbiased=True):
  184. results = []
  185. for averager in averagers:
  186. with averager.get_tensors() as tensors:
  187. results.append([tensor.clone() for tensor in tensors])
  188. results_stacked_per_tensor = list(map(torch.stack, zip(*results)))
  189. means = [stack.mean(dim=0) for stack in results_stacked_per_tensor]
  190. stds = [stack.std(dim=0, unbiased=unbiased) for stack in results_stacked_per_tensor]
  191. return means, stds
  192. @pytest.mark.forked
  193. def test_allreduce_grid():
  194. dht = hivemind.DHT(start=True)
  195. averagers = [
  196. hivemind.averaging.DecentralizedAverager(
  197. averaged_tensors=[torch.randn(3)],
  198. dht=dht,
  199. target_group_size=2,
  200. prefix="mygroup",
  201. initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
  202. start=True,
  203. )
  204. for i in range(8)
  205. ]
  206. [means0], [stds0] = compute_mean_std(averagers)
  207. assert not torch.allclose(stds0, torch.zeros_like(stds0))
  208. prev_means, prev_stds = means0, stds0
  209. for i in range(5):
  210. step_futures = [averager.step(wait=False) for averager in averagers]
  211. groups = [future.result() for future in step_futures]
  212. [means], [stds] = compute_mean_std(averagers)
  213. assert torch.allclose(means, prev_means, atol=1e-6, rtol=0)
  214. assert all(len(group) == 2 for group in groups)
  215. if i <= 2:
  216. assert torch.all(torch.le(stds, prev_stds))
  217. else:
  218. assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)
  219. for averager in averagers:
  220. averager.shutdown()
  221. dht.shutdown()
  222. @pytest.mark.forked
  223. def test_allgather():
  224. dht = hivemind.DHT(start=True)
  225. averagers = [
  226. hivemind.averaging.DecentralizedAverager(
  227. [torch.ones(1)],
  228. dht=dht,
  229. target_group_size=4,
  230. averaging_expiration=15,
  231. prefix="mygroup",
  232. initial_group_bits="000",
  233. start=True,
  234. )
  235. for _ in range(8)
  236. ]
  237. futures = []
  238. for i, averager in enumerate(averagers):
  239. futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))
  240. assert len(set(repr(sorted(future.result())) for future in futures)) == 2
  241. reference_metadata = {
  242. averager.endpoint: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
  243. }
  244. for future in futures:
  245. gathered = future.result()
  246. assert len(gathered) == 4
  247. for endpoint in gathered:
  248. assert gathered[endpoint] == reference_metadata[endpoint]
  249. for averager in averagers:
  250. averager.shutdown()
  251. dht.shutdown()
  252. def get_cost(vector_size, partitions, bandwidths):
  253. return max(
  254. (vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(bandwidths[i], 1e-9)
  255. for i in range(len(partitions))
  256. )
  257. def check_optimality(vector_size, bandwidths, ref_partitions):
  258. partitions = list(load_balance_peers(vector_size, bandwidths))
  259. assert get_cost(vector_size, partitions, bandwidths) <= get_cost(vector_size, ref_partitions, bandwidths)
  260. @pytest.mark.forked
  261. def test_load_balancing():
  262. check_optimality(60, np.array([0.25, 0.25, 0.25, 0.25]), [15, 15, 15, 15])
  263. check_optimality(1024, np.array([0.3, 0.5, 0.9]), [0, 255, 769])
  264. check_optimality(60, np.array([0.44, 0.33, 0.22]), [42, 18, 0])
  265. check_optimality(60, np.array([0.55, 0.44, 0.40]), [35, 16, 9])
  266. check_optimality(1024 * 1024, np.array([0.3, 0.5, 0.9, 0.6]), [0, 169327, 602629, 276620])
  267. check_optimality(1024 * 1024, np.array([0.0, 0.5, 0.0, 0.6]), [0, 428963, 0, 619613])
  268. assert load_balance_peers(60, np.array([0.55, 0.44, 0.40]), min_size=10) == (41, 19, 0)
  269. assert load_balance_peers(60, np.array([0.32, 0.55, 0.44]), min_size=10) == (0, 40, 20)
  270. assert load_balance_peers(2, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 1)
  271. assert load_balance_peers(1, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 0)
  272. assert load_balance_peers(100, (None, None)) == (50, 50)
  273. assert load_balance_peers(100, (None, None, None, None, None)) == (20, 20, 20, 20, 20)
  274. assert load_balance_peers(100, (0, 0, 0, None, None)) == (0, 0, 0, 50, 50)
  275. with pytest.raises(AssertionError):
  276. load_balance_peers(100, (0, 0, 0))
  277. for i in range(10):
  278. vector_size = np.random.randint(1, 1024 ** 3)
  279. num_peers = np.random.randint(1, 256)
  280. scale = 1e-9 + np.random.rand() * 1e5
  281. bandwidths = np.random.rand(num_peers) * scale + 1e-6
  282. min_size = np.random.choice([0, np.random.randint(0, vector_size // 10)])
  283. assignment = load_balance_peers(vector_size, bandwidths, min_size)
  284. assert np.sum(assignment) == vector_size
  285. assert np.min(assignment) >= 0
  286. @pytest.mark.forked
  287. def test_too_few_peers():
  288. dht = hivemind.DHT(start=True)
  289. averagers = [
  290. hivemind.averaging.DecentralizedAverager(
  291. averaged_tensors=[torch.randn(3)],
  292. dht=dht,
  293. target_group_size=2,
  294. averaging_expiration=1,
  295. request_timeout=0.5,
  296. prefix="mygroup",
  297. initial_group_bits=bin(i)[2:].rjust(3, "0"),
  298. start=True,
  299. )
  300. for i in range(4)
  301. ]
  302. step_futures = [averager.step(wait=False) for averager in averagers]
  303. for future in step_futures:
  304. assert len(future.result()) == 2
  305. for averager in averagers:
  306. averager.shutdown()
  307. dht.shutdown()
  308. @pytest.mark.forked
  309. def test_overcrowded(num_peers=16):
  310. dht = hivemind.DHT(start=True)
  311. averagers = [
  312. hivemind.averaging.DecentralizedAverager(
  313. averaged_tensors=[torch.randn(3)],
  314. dht=dht,
  315. target_group_size=2,
  316. averaging_expiration=1,
  317. request_timeout=0.5,
  318. prefix="mygroup",
  319. initial_group_bits="",
  320. start=True,
  321. )
  322. for _ in range(num_peers)
  323. ]
  324. for t in range(5):
  325. step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
  326. assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1
  327. for averager in averagers:
  328. averager.shutdown()
  329. dht.shutdown()
  330. @pytest.mark.forked
  331. def test_load_state_from_peers():
  332. num_calls = 0
  333. super_metadata = dict(x=123)
  334. super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))
  335. class TestAverager(hivemind.averaging.DecentralizedAverager):
  336. def get_current_state(self):
  337. """
  338. Get current state and send it to a peer. executed in the host process. Meant to be overriden.
  339. :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
  340. """
  341. nonlocal num_calls, super_metadata, super_tensors
  342. num_calls += 1
  343. return super_metadata, super_tensors
  344. dht_root = hivemind.DHT(start=True)
  345. initial_peers = dht_root.get_visible_maddrs()
  346. dht1 = hivemind.DHT(initial_peers=initial_peers, start=True)
  347. averager1 = TestAverager(
  348. [torch.randn(3), torch.rand(5)],
  349. dht=dht1,
  350. start=True,
  351. prefix="demo-run",
  352. target_group_size=2,
  353. )
  354. dht2 = hivemind.DHT(initial_peers=initial_peers, start=True)
  355. dht2.get("demo-run.all_averagers")
  356. averager2 = TestAverager(
  357. [torch.randn(3), torch.rand(5)],
  358. dht=dht2,
  359. start=True,
  360. prefix="demo-run",
  361. target_group_size=2,
  362. )
  363. assert num_calls == 0
  364. got_metadata, got_tensors = averager2.load_state_from_peers()
  365. assert num_calls == 1
  366. assert got_metadata == super_metadata
  367. assert all(map(torch.allclose, got_tensors, super_tensors))
  368. super_metadata["y"] = 123
  369. super_tensors[1][2] = 9
  370. assert num_calls == 1
  371. assert got_metadata != super_metadata
  372. assert not all(map(torch.allclose, got_tensors, super_tensors))
  373. got_metadata, got_tensors = averager2.load_state_from_peers()
  374. assert num_calls == 2
  375. assert got_metadata == super_metadata
  376. assert all(map(torch.allclose, got_tensors, super_tensors))
  377. averager1.allow_state_sharing = False
  378. assert averager2.load_state_from_peers() is None
  379. averager1.allow_state_sharing = True
  380. got_metadata, got_tensors = averager2.load_state_from_peers()
  381. assert num_calls == 3
  382. assert got_metadata == super_metadata
  383. @pytest.mark.forked
  384. def test_getset_bits():
  385. dht = hivemind.DHT(start=True)
  386. averager = hivemind.averaging.DecentralizedAverager(
  387. [torch.randn(3)],
  388. dht=dht,
  389. start=True,
  390. prefix="test_prefix",
  391. target_group_size=2,
  392. )
  393. averager.set_group_bits("00101011101010")
  394. assert averager.get_group_bits() == "00101011101010"
  395. @pytest.mark.forked
  396. def test_training_averager(n_steps: int = 10, n_dims: int = 16):
  397. torch.manual_seed(42)
  398. dht = hivemind.DHT(start=True)
  399. common_kwargs = {
  400. "dht": dht,
  401. "start": True,
  402. "prefix": "demo-run",
  403. "target_group_size": 2,
  404. }
  405. x1 = torch.randn(n_dims, requires_grad=True)
  406. opt1 = torch.optim.Adam([x1], lr=0.05)
  407. averager1 = hivemind.averaging.TrainingAverager(
  408. opt1, average_gradients=True, average_parameters=True, average_opt_statistics=["exp_avg_sq"], **common_kwargs
  409. )
  410. x2 = torch.randn(n_dims, requires_grad=True)
  411. opt2 = torch.optim.Adam([x2], lr=0.05)
  412. averager2 = hivemind.averaging.TrainingAverager(
  413. opt2, average_gradients=True, average_parameters=True, average_opt_statistics=["exp_avg_sq"], **common_kwargs
  414. )
  415. a = torch.ones(n_dims)
  416. for i in range(n_steps):
  417. opt1.zero_grad()
  418. opt2.zero_grad()
  419. (x1 - a).pow(2).sum().backward()
  420. (x2 - a).pow(2).sum().backward()
  421. opt1.step()
  422. opt2.step()
  423. with torch.no_grad():
  424. x_avg = 0.5 * (x1 + x2)
  425. grad_avg = 0.5 * (x1.grad + x2.grad)
  426. stats_avg = 0.5 * (opt1.state[x1]["exp_avg_sq"] + opt2.state[x2]["exp_avg_sq"])
  427. # we set wait=False in order to prevent deadlock, when averager1 locks and waits for averager2
  428. f1 = averager1.step(wait=False)
  429. f2 = averager2.step(wait=False)
  430. f1.result()
  431. f2.result()
  432. assert torch.allclose(x1, x_avg)
  433. assert torch.allclose(x2, x_avg)
  434. assert torch.allclose(x1.grad, grad_avg)
  435. assert torch.allclose(x2.grad, grad_avg)
  436. assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
  437. assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
  438. averager1.shutdown()
  439. averager2.shutdown()
  440. dht.shutdown()