test_compression.py
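
"""Tests for hivemind's tensor (de)serialization helpers, streaming chunks, and adaptive compression schemes."""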

import multiprocessing as mp
from ctypes import c_int32

import pytest
import torch
import torch.nn as nn

import hivemind
from hivemind.compression import (
    CompressionBase,
    CompressionInfo,
    Float16Compression,
    NoCompression,
    PerTensorCompression,
    RoleAdaptiveCompression,
    SizeAdaptiveCompression,
    Uniform8BitQuantization,
    deserialize_torch_tensor,
    serialize_torch_tensor,
)
from hivemind.compression.adaptive import AdaptiveCompressionBase
from hivemind.proto.runtime_pb2 import CompressionType

from test_utils.dht_swarms import launch_dht_instances


@pytest.mark.forked
def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
    """Check that each lossy compression scheme stays within its expected mean squared error budget."""
    torch.manual_seed(0)
    X = torch.randn(*size)

    # no compression: the round-trip must be lossless
    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)

    # 16-bit schemes: tight error bound (alpha)
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha

    # 8-bit schemes: looser error bound (beta)
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
    assert error.square().mean() < beta
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
    assert error.square().mean() < beta
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.BLOCKWISE_8BIT)) - X
    assert error.square().mean() < beta

    # all-zero tensors must not produce NaN/inf under any scheme
    zeros = torch.zeros(5, 5)
    for compression_type in CompressionType.values():
        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
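

# Illustrative sketch, not collected by pytest (the helper name `_example_compression_ratio`
# is hypothetical, not part of the suite): compare serialized payload sizes across
# compression types using only the helpers imported above. `buffer` is the same payload
# field that test_serialize_tensor below splits into streaming chunks.
def _example_compression_ratio(tensor: torch.Tensor) -> None:
    baseline = len(serialize_torch_tensor(tensor, CompressionType.NONE).buffer)
    for compression_type in (CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT):
        compressed = len(serialize_torch_tensor(tensor, compression_type).buffer)
        print(CompressionType.Name(compression_type), round(baseline / compressed, 2))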


@pytest.mark.forked
def test_serialize_tensor():
    """Check that serialized tensors survive being split into streaming chunks and reassembled."""

    def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
        serialized_tensor = serialize_torch_tensor(tensor, compression)
        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
        # expect ceil(len(buffer) / chunk_size) chunks
        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
        restored = hivemind.combine_from_streaming(chunks)
        assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)

    tensor = torch.randn(512, 12288)
    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
        _check(tensor, CompressionType.NONE, chunk_size=chunk_size)

    _check(tensor, CompressionType.FLOAT16, rtol=0.0, atol=1e-2)
    _check(torch.randint(0, 100, (512, 1, 1)), CompressionType.NONE)
    _check(torch.tensor(1.0), CompressionType.NONE)
    _check(torch.tensor(1.0), CompressionType.FLOAT16)


@pytest.mark.forked
def test_allreduce_compression():
    """Ensure that averaging remains correct when tensors in the same group use different compression types."""
    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
    results = {}

    FLOAT16, UINT8 = Float16Compression(), Uniform8BitQuantization()

    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
        dht_instances = launch_dht_instances(2)
        averager1 = hivemind.averaging.DecentralizedAverager(
            [x.clone() for x in tensors1],
            dht=dht_instances[0],
            compression=PerTensorCompression(compression_type_pair),
            client_mode=True,
            target_group_size=2,
            prefix="mygroup",
            start=True,
        )
        averager2 = hivemind.averaging.DecentralizedAverager(
            [x.clone() for x in tensors2],
            dht=dht_instances[1],
            compression=PerTensorCompression(compression_type_pair),
            target_group_size=2,
            prefix="mygroup",
            start=True,
        )

        for future in averager1.step(wait=False), averager2.step(wait=False):
            future.result()

        with averager1.get_tensors() as averaged_tensors:
            # clone: the originals belong to the averager, which is shut down below
            results[compression_type_pair] = [tensor.clone() for tensor in averaged_tensors]

        for instance in [averager1, averager2] + dht_instances:
            instance.shutdown()

    # tensors compressed with the same scheme must average to identical results...
    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
    # ...while different schemes must leave different rounding errors
    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])

    # both schemes should approximate the exact average, fp16 markedly closer than uint8
    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
    for i in range(2):
        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
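

# TrackedCompression wraps a real compression scheme and records how many tensors it
# compressed and the largest part size it was given. The counters live in shared
# memory so that increments made inside forked averager processes remain visible
# to the test process.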
class TrackedCompression(AdaptiveCompressionBase):
    def __init__(self, compression: CompressionBase):
        self.compression = compression
        self.mp_counter, self.mp_part_size = mp.Value(c_int32, 0), mp.Value(c_int32, 0)
        super().__init__()

    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
        return self.compression

    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False):
        self.mp_counter.value += 1
        if info.part_size is not None:
            self.mp_part_size.value = max(self.mp_part_size.value, info.part_size)
        return self.compression.compress(tensor, info=info, allow_inplace=allow_inplace)


def make_params():
    # parameters of assorted sizes, chosen to straddle the adaptive thresholds used below
    return [
        nn.Parameter(x)
        for x in (
            torch.randn([]),
            torch.randn(1),
            torch.randn(100),
            torch.randn(1_000),
            torch.randn(5_000),
            torch.randn(10_000),
        )
    ]
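

# Minimal composition sketch (a hypothetical helper, not used by the test below,
# which wraps the same policy in TrackedCompression to count calls): parameters
# travel as fp16, gradients switch to 8-bit quantization once they reach 1000
# elements, and everything else stays uncompressed.
def _example_adaptive_policy() -> RoleAdaptiveCompression:
    return RoleAdaptiveCompression(
        parameter=Float16Compression(),
        gradient=SizeAdaptiveCompression(
            threshold=1_000, less=Float16Compression(), greater_equal=Uniform8BitQuantization()
        ),
        optimizer=NoCompression(),
        default=NoCompression(),
    )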


@pytest.mark.forked
def test_adaptive_compression():
    UINT8 = TrackedCompression(Uniform8BitQuantization())
    FLOAT16 = TrackedCompression(Float16Compression())
    FLOAT32 = TrackedCompression(NoCompression())
    STATE_FP16 = TrackedCompression(Float16Compression())
    STATE_FP32 = TrackedCompression(NoCompression())

    averaging_compression_adaptive = RoleAdaptiveCompression(
        parameter=FLOAT16,
        gradient=SizeAdaptiveCompression(threshold=1_000, less=FLOAT16, greater_equal=UINT8),
        optimizer=FLOAT32,
        default=FLOAT32,
    )

    state_compression_adaptive = SizeAdaptiveCompression(
        threshold=500,
        less=STATE_FP32,
        greater_equal=STATE_FP16,
    )

    averager1 = hivemind.TrainingAverager(
        opt=torch.optim.Adam(make_params()),
        average_parameters=True,
        average_gradients=True,
        average_opt_statistics=("exp_avg",),
        compression=averaging_compression_adaptive,
        state_compression=state_compression_adaptive,
        prefix="test_avgr",
        target_group_size=2,
        part_size_bytes=5_000,
        start=True,
        dht=hivemind.DHT(start=True),
    )

    averager2 = hivemind.TrainingAverager(
        opt=torch.optim.Adam(make_params()),
        average_parameters=True,
        average_gradients=True,
        average_opt_statistics=("exp_avg",),
        compression=averaging_compression_adaptive,
        state_compression=state_compression_adaptive,
        prefix="test_avgr",
        target_group_size=2,
        part_size_bytes=5_000,
        start=True,
        dht=hivemind.DHT(initial_peers=averager1.dht.get_visible_maddrs(), start=True),
    )

    futures = [averager1.step(wait=False), averager2.step(wait=False)]
    for future in futures:
        future.result()

    assert UINT8.mp_counter.value == 4  # half of the gradients: 3 tensors, one of which is split in two parts
    assert UINT8.mp_part_size.value == 5_000  # one-byte elements: 5000 bytes => 5000 elements per part
    assert FLOAT16.mp_counter.value == 13  # parameters and the other half of the gradients
    assert FLOAT16.mp_part_size.value == 2_500  # two-byte elements: 5000 bytes => 2500 elements per part
    assert FLOAT32.mp_counter.value == 16  # optimizer statistics
    assert FLOAT32.mp_part_size.value == 1_250  # four-byte elements: 5000 bytes => 1250 elements per part

    averager1.load_state_from_peers()
    assert STATE_FP16.mp_counter.value == STATE_FP32.mp_counter.value == 9
    assert STATE_FP16.mp_part_size.value == STATE_FP32.mp_part_size.value == 0  # state tensors are not partitioned