Implemented float16 compression (#106)

* Implemented float16 compression

* Update hivemind/utils/grpc.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

Vsevolod-pl 4 years ago
parent
commit
fde83bba02
3 changed files with 21 additions and 1 deletion
  1. hivemind/proto/runtime.proto  +1 -0
  2. hivemind/utils/grpc.py  +18 -0
  3. tests/test_util_modules.py  +2 -1

+ 1 - 0
hivemind/proto/runtime.proto

@@ -29,6 +29,7 @@ message ExpertResponse {
 enum CompressionType{
   NONE = 0;
   MEANSTD_LAST_AXIS_FLOAT16 = 1;
+  FLOAT16 = 2;
 }
 
 message Tensor {

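On the Python side the new value becomes available through the module generated from this proto. A minimal sketch, assuming the enum is importable from hivemind.proto.runtime_pb2 (the generated module that hivemind/utils/grpc.py already relies on):

    from hivemind.proto.runtime_pb2 import CompressionType

    # the wire format now distinguishes plain half precision (FLOAT16 = 2)
    # from the mean/std-normalized variant (MEANSTD_LAST_AXIS_FLOAT16 = 1)
    assert CompressionType.FLOAT16 == 2
    assert CompressionType.MEANSTD_LAST_AXIS_FLOAT16 == 1
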
+ 18 - 0
hivemind/utils/grpc.py

@@ -32,6 +32,20 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
             size=tensor.shape,
             dtype='compressed_float32',
             requires_grad=tensor.requires_grad)
+    elif compression_type == CompressionType.FLOAT16:
+        assert tensor.dtype == torch.float32
+
+        tensor = tensor if allow_inplace else tensor.clone()
+        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
+
+        data = tensor.numpy().tobytes()
+
+        proto = runtime_pb2.Tensor(
+            compression=compression_type,
+            buffer=data,
+            size=tensor.shape,
+            dtype='clamped_float32',
+            requires_grad=tensor.requires_grad)
     else:
         array = tensor.numpy()
         proto = runtime_pb2.Tensor(
@@ -58,6 +72,10 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
         stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32)).view(*stats_size)
         array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16)
         tensor = torch.as_tensor(array).to(torch.float32).view(*serialized_tensor.size).mul_(stds).add_(means)
+    elif serialized_tensor.compression == CompressionType.FLOAT16:
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16).copy()
+        tensor = torch.as_tensor(array).view(*serialized_tensor.size)\
+            .to(torch.float32).requires_grad_(serialized_tensor.requires_grad)
     else:
         raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
     return tensor

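Taken together, the two hunks above give a lossy but shape-preserving round trip: serialization clamps to the float16 range, casts, and packs raw bytes; deserialization reverses the cast and restores the original shape and requires_grad flag. A minimal usage sketch, assuming serialize_torch_tensor and deserialize_torch_tensor are imported from hivemind.utils.grpc and CompressionType from the generated hivemind.proto.runtime_pb2; the tolerances below only reflect ordinary float16 rounding error and are my own choice, not values from the repository:

    import torch
    from hivemind.proto.runtime_pb2 import CompressionType
    from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor

    x = torch.randn(128, 128)                                   # float32 source tensor
    proto = serialize_torch_tensor(x, CompressionType.FLOAT16)  # clamp to fp16 range, cast, pack bytes
    restored = deserialize_torch_tensor(proto)                  # bytes -> fp16 -> float32, original shape

    assert restored.shape == x.shape
    assert torch.allclose(restored, x, rtol=1e-3, atol=1e-6)    # only float16 rounding error remains
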
+ 2 - 1
tests/test_util_modules.py

@@ -128,5 +128,6 @@ def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
     assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_LAST_AXIS_FLOAT16))-X
     assert error.square().mean() < alpha
-    return error.square().mean()
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
+    assert error.square().mean() < alpha
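
The new assertion reuses the same alpha bound for plain FLOAT16 compression as for the mean/std variant. Assuming the project runs its test suite with pytest, this case can be exercised in isolation with:

    pytest tests/test_util_modules.py -k test_vector_compression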