Преглед изворни кода

Improve bfloat16 serialization (backward compatible) (#553)

-    added bfloat16 serialization that sends 2 bytes per value (previously, we sent 4);
-    changed de-serialization code so it supports both modes of serialization.
-    the new mode can be enabled via export USE_LEGACY_BFLOAT16=0
-    tested in pytorch 1.12 and 1.13

---------

Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
justheuristic пре 2 година
родитељ
комит
7d1bb7d47c
2 измењених фајлова са 24 додато и 7 уклоњено
  1. 21 6
      hivemind/compression/base.py
  2. 3 1
      tests/test_compression.py

+ 21 - 6
hivemind/compression/base.py

@@ -1,4 +1,5 @@
 import dataclasses
+import os
 import warnings
 from abc import ABC, abstractmethod
 from enum import Enum, auto
@@ -13,6 +14,7 @@ from hivemind.utils.tensor_descr import TensorDescriptor
 # While converting read-only NumPy arrays into PyTorch tensors, we don't make extra copies for efficiency
 warnings.filterwarnings("ignore", message="The given NumPy array is not writable", category=UserWarning)
 
+USE_LEGACY_BFLOAT16 = bool(int(os.environ.get("USE_LEGACY_BFLOAT16", 1)))
 
 Key = Any
 
@@ -81,26 +83,39 @@ class NoCompression(CompressionBase):
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
         tensor = tensor.detach()
+        shape = tensor.shape
         dtype_name = str(tensor.dtype).lstrip("torch.")
+        raw_data = tensor
         if tensor.dtype == torch.bfloat16:
-            tensor = tensor.to(torch.float32)
+            if USE_LEGACY_BFLOAT16:
+                raw_data = tensor.to(torch.float32)
+            else:
+                typed_storage = tensor.storage()
+                storage = typed_storage.untyped() if hasattr(typed_storage, "untyped") else typed_storage._untyped()
+                raw_data = torch.tensor(storage, dtype=torch.int8)
 
         return runtime_pb2.Tensor(
             compression=self.compression_type,
-            buffer=tensor.numpy().tobytes(),
-            size=tensor.shape,
+            buffer=raw_data.numpy().tobytes(),
+            size=shape,
             dtype=dtype_name,
             requires_grad=tensor.requires_grad,
         )
 
     def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        shape = torch.Size(serialized_tensor.size)
         if serialized_tensor.dtype == "bfloat16":
-            array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
-            tensor = torch.as_tensor(array, dtype=torch.bfloat16)
+            if len(serialized_tensor.buffer) // shape.numel() == 4:  # legacy mode: convert to fp32
+                array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
+                tensor = torch.as_tensor(array, dtype=torch.bfloat16)
+            else:  # efficient mode: send bfloat16 data directly
+                storage_type = torch.TypedStorage if hasattr(torch, "TypedStorage") else torch._TypedStorage
+                storage = storage_type.from_buffer(serialized_tensor.buffer, byte_order="little", dtype=torch.bfloat16)
+                tensor = torch.as_tensor(storage, dtype=torch.bfloat16)
         else:
             array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
             tensor = torch.as_tensor(array)
-        return tensor.reshape(tuple(serialized_tensor.size))
+        return tensor.reshape(shape)
 
     def estimate_compression_ratio(self, info: CompressionInfo) -> float:
         return 1.0

+ 3 - 1
tests/test_compression.py

@@ -68,8 +68,10 @@ def test_serialize_tensor():
     _check(torch.tensor(1.0), CompressionType.FLOAT16)
 
 
+@pytest.mark.parametrize("use_legacy_bfloat16", [True, False])
 @pytest.mark.forked
-def test_serialize_bfloat16():
+def test_serialize_bfloat16(use_legacy_bfloat16: bool):
+    hivemind.compression.base.USE_LEGACY_BFLOAT16 = use_legacy_bfloat16
     tensor = torch.randn(4096, 16, dtype=torch.bfloat16)
     _check(tensor, CompressionType.NONE)
     _check(tensor, CompressionType.BLOCKWISE_8BIT, rtol=0.1, atol=0.01, chunk_size=1024)