
Fixed NaN when compressing a tensor of zeros (#266)

* Fixed the occurrence of NaN when compressing a tensor of zeros

* Implemented a simple test for compressing a tensor of zeros

* Improved the simple test for compressing a tensor of zeros

* Fixed the simple test for compressing a tensor of zeros
Vsevolod-pl committed 4 years ago
commit 8673071bc0
2 changed files with 6 additions and 1 deletion:
  1. hivemind/utils/compression.py (+2, -1)
  2. tests/test_util_modules.py (+4, -0)
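
Why the NaN appears: the mean-std float16 compression path divides each row by its own standard deviation, and for an all-zero tensor that standard deviation is zero, so the division yields 0/0 = NaN. Below is a minimal reproduction of that failure mode in plain PyTorch, mirroring the std computation from the diff that follows; it is an illustrative sketch, not hivemind's API.

import torch

# All-zero input: every row's mean and standard deviation are zero.
tensor = torch.zeros(5, 5)
means = tensor.mean(dim=-1, keepdim=True)
centered = tensor - means
stds = torch.square(centered).sum(dim=-1, keepdim=True).div(centered.shape[-1]).sqrt()
normalized = centered / stds            # 0 / 0 -> nan
print(normalized.isfinite().all())      # tensor(False)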

+ 2 - 1
hivemind/utils/compression.py

@@ -8,7 +8,7 @@ from hivemind.proto import runtime_pb2
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.threading import run_in_background
 
-FP16_MAX = 65_504
+FP32_EPS = 1e-06
 NUM_BYTES_FLOAT32 = 4
 NUM_BYTES_FLOAT16 = 2
 NUM_BITS_QUANTILE_COMPRESSION = 8
@@ -86,6 +86,7 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
         tensor.sub_(means)
 
         stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
+        stds.clamp_min_(FP32_EPS)
         tensor.div_(stds)
         tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
 

+ 4 - 0
tests/test_util_modules.py

@@ -138,6 +138,10 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
     assert error.square().mean() < beta
 
+    zeros = torch.zeros(5,5)
+    for compression_type in CompressionType.values():
+        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
+
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_channel_cache():
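
The added check round-trips an all-zero tensor through every registered compression type and asserts that the restored values are finite. To exercise the same path outside pytest, something along these lines should work, assuming hivemind is installed and that serialize_torch_tensor / deserialize_torch_tensor are importable from hivemind.utils.compression (the module changed above):

import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor

zeros = torch.zeros(5, 5)
for compression_type in CompressionType.values():
    restored = deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type))
    # Report the offending compression scheme by name if anything is non-finite.
    assert restored.isfinite().all(), CompressionType.Name(compression_type)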