test_compression.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. import multiprocessing as mp
  2. from ctypes import c_int32
  3. import pytest
  4. import torch
  5. import torch.nn as nn
  6. import hivemind
  7. from hivemind.compression import (
  8. CompressionBase,
  9. CompressionInfo,
  10. Float16Compression,
  11. NoCompression,
  12. PerTensorCompression,
  13. RoleAdaptiveCompression,
  14. SizeAdaptiveCompression,
  15. Uniform8BitQuantization,
  16. deserialize_torch_tensor,
  17. serialize_torch_tensor,
  18. )
  19. from hivemind.compression.adaptive import AdaptiveCompressionBase
  20. from hivemind.proto.runtime_pb2 import CompressionType
  21. from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
  22. from test_utils.dht_swarms import launch_dht_instances
  23. @pytest.mark.forked
  24. def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
  25. torch.manual_seed(0)
  26. X = torch.randn(*size)
  27. assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
  28. error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
  29. assert error.square().mean() < alpha
  30. error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
  31. assert error.square().mean() < alpha
  32. error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
  33. assert error.square().mean() < beta
  34. error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
  35. assert error.square().mean() < beta
  36. error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.BLOCKWISE_8BIT)) - X
  37. assert error.square().mean() < beta
  38. zeros = torch.zeros(5, 5)
  39. for compression_type in CompressionType.values():
  40. assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
  41. @pytest.mark.forked
  42. def test_serialize_tensor():
  43. def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
  44. serialized_tensor = serialize_torch_tensor(tensor, compression)
  45. chunks = list(split_for_streaming(serialized_tensor, chunk_size))
  46. assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
  47. restored = combine_from_streaming(chunks)
  48. assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
  49. tensor = torch.randn(512, 12288)
  50. for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
  51. _check(tensor, CompressionType.NONE, chunk_size=chunk_size)
  52. _check(tensor, CompressionType.FLOAT16, rtol=0.0, atol=1e-2)
  53. _check(torch.randint(0, 100, (512, 1, 1)), CompressionType.NONE)
  54. _check(torch.tensor(1.0), CompressionType.NONE)
  55. _check(torch.tensor(1.0), CompressionType.FLOAT16)
  56. @pytest.mark.forked
  57. def test_allreduce_compression():
  58. """this test ensures that compression works correctly when multiple tensors have different compression types"""
  59. tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
  60. tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
  61. results = {}
  62. FLOAT16, UINT8 = Float16Compression(), Uniform8BitQuantization()
  63. for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
  64. dht_instances = launch_dht_instances(2)
  65. averager1 = hivemind.averaging.DecentralizedAverager(
  66. [x.clone() for x in tensors1],
  67. dht=dht_instances[0],
  68. compression=PerTensorCompression(compression_type_pair),
  69. client_mode=True,
  70. target_group_size=2,
  71. prefix="mygroup",
  72. start=True,
  73. )
  74. averager2 = hivemind.averaging.DecentralizedAverager(
  75. [x.clone() for x in tensors2],
  76. dht=dht_instances[1],
  77. compression=PerTensorCompression(compression_type_pair),
  78. target_group_size=2,
  79. prefix="mygroup",
  80. start=True,
  81. )
  82. for future in averager1.step(wait=False), averager2.step(wait=False):
  83. future.result()
  84. with averager1.get_tensors() as averaged_tensors:
  85. results[compression_type_pair] = averaged_tensors
  86. for instance in [averager1, averager2] + dht_instances:
  87. instance.shutdown()
  88. assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
  89. assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
  90. assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
  91. assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
  92. assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
  93. assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
  94. assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
  95. assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])
  96. reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
  97. for i in range(2):
  98. assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
  99. assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
  100. class TrackedCompression(AdaptiveCompressionBase):
  101. def __init__(self, compression: CompressionBase):
  102. self.compression = compression
  103. self.mp_counter, self.mp_part_size = mp.Value(c_int32, 0), mp.Value(c_int32, 0)
  104. super().__init__()
  105. def choose_compression(self, info: CompressionInfo) -> CompressionBase:
  106. return self.compression
  107. def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False):
  108. self.mp_counter.value += 1
  109. if info.part_size is not None:
  110. self.mp_part_size.value = max(self.mp_part_size.value, info.part_size)
  111. return self.compression.compress(tensor, info=info, allow_inplace=allow_inplace)
  112. def make_params():
  113. return [
  114. nn.Parameter(x)
  115. for x in (
  116. torch.randn([]),
  117. torch.randn(1),
  118. torch.randn(100),
  119. torch.randn(1_000),
  120. torch.randn(5_000),
  121. torch.randn(10_000),
  122. )
  123. ]
  124. @pytest.mark.forked
  125. def test_adaptive_compression():
  126. UINT8 = TrackedCompression(Uniform8BitQuantization())
  127. FLOAT16 = TrackedCompression(Float16Compression())
  128. FLOAT32 = TrackedCompression(NoCompression())
  129. STATE_FP16 = TrackedCompression(Float16Compression())
  130. STATE_FP32 = TrackedCompression(NoCompression())
  131. averaging_compression_adaptive = RoleAdaptiveCompression(
  132. parameter=FLOAT16,
  133. gradient=SizeAdaptiveCompression(threshold=1_000, less=FLOAT16, greater_equal=UINT8),
  134. optimizer=FLOAT32,
  135. default=FLOAT32,
  136. )
  137. state_compression_adaptive = SizeAdaptiveCompression(
  138. threshold=500,
  139. less=STATE_FP32,
  140. greater_equal=STATE_FP16,
  141. )
  142. averager1 = hivemind.TrainingAverager(
  143. opt=torch.optim.Adam(make_params()),
  144. average_parameters=True,
  145. average_gradients=True,
  146. average_opt_statistics=("exp_avg",),
  147. compression=averaging_compression_adaptive,
  148. state_compression=state_compression_adaptive,
  149. prefix="test_avgr",
  150. target_group_size=2,
  151. part_size_bytes=5_000,
  152. start=True,
  153. dht=hivemind.DHT(start=True),
  154. )
  155. averager2 = hivemind.TrainingAverager(
  156. opt=torch.optim.Adam(make_params()),
  157. average_parameters=True,
  158. average_gradients=True,
  159. average_opt_statistics=("exp_avg",),
  160. compression=averaging_compression_adaptive,
  161. state_compression=state_compression_adaptive,
  162. prefix="test_avgr",
  163. target_group_size=2,
  164. part_size_bytes=5_000,
  165. start=True,
  166. dht=hivemind.DHT(initial_peers=averager1.dht.get_visible_maddrs(), start=True),
  167. )
  168. futures = [averager1.step(wait=False), averager2.step(wait=False)]
  169. for future in futures:
  170. future.result()
  171. assert UINT8.mp_counter.value == 4 # half gradients: 3 tensors, 1 is split
  172. assert UINT8.mp_part_size.value == 5_000 # single byte tensors
  173. assert FLOAT16.mp_counter.value == 13 # parameters and half gradients
  174. assert FLOAT16.mp_part_size.value == 2_500 # two-byte tensors
  175. assert FLOAT32.mp_counter.value == 16 # statistics
  176. assert FLOAT32.mp_part_size.value == 1250 # four-byte tensors
  177. averager1.load_state_from_peers()
  178. state_metadata, state_tensors, infos = averager1.get_current_state()
  179. assert STATE_FP16.mp_counter.value == len([tensor for tensor in state_tensors if tensor.numel() >= 500])
  180. assert STATE_FP32.mp_counter.value == len([tensor for tensor in state_tensors if tensor.numel() < 500])
  181. assert STATE_FP16.mp_part_size.value == STATE_FP32.mp_part_size.value == 0 # not partitioned