|
@@ -1,4 +1,5 @@
|
|
|
import dataclasses
|
|
|
+import os
|
|
|
import warnings
|
|
|
from abc import ABC, abstractmethod
|
|
|
from enum import Enum, auto
|
|
@@ -13,6 +14,7 @@ from hivemind.utils.tensor_descr import TensorDescriptor
|
|
|
# While converting read-only NumPy arrays into PyTorch tensors, we don't make extra copies for efficiency
|
|
|
warnings.filterwarnings("ignore", message="The given NumPy array is not writable", category=UserWarning)
|
|
|
|
|
|
+USE_LEGACY_BFLOAT16 = bool(int(os.environ.get("USE_LEGACY_BFLOAT16", 1)))
|
|
|
|
|
|
Key = Any
|
|
|
|
|
@@ -81,26 +83,39 @@ class NoCompression(CompressionBase):
|
|
|
|
|
|
def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
|
|
|
tensor = tensor.detach()
|
|
|
+ shape = tensor.shape
|
|
|
dtype_name = str(tensor.dtype).lstrip("torch.")
|
|
|
+ raw_data = tensor
|
|
|
if tensor.dtype == torch.bfloat16:
|
|
|
- tensor = tensor.to(torch.float32)
|
|
|
+ if USE_LEGACY_BFLOAT16:
|
|
|
+ raw_data = tensor.to(torch.float32)
|
|
|
+ else:
|
|
|
+ typed_storage = tensor.storage()
|
|
|
+ storage = typed_storage.untyped() if hasattr(typed_storage, "untyped") else typed_storage._untyped()
|
|
|
+ raw_data = torch.tensor(storage, dtype=torch.int8)
|
|
|
|
|
|
return runtime_pb2.Tensor(
|
|
|
compression=self.compression_type,
|
|
|
- buffer=tensor.numpy().tobytes(),
|
|
|
- size=tensor.shape,
|
|
|
+ buffer=raw_data.numpy().tobytes(),
|
|
|
+ size=shape,
|
|
|
dtype=dtype_name,
|
|
|
requires_grad=tensor.requires_grad,
|
|
|
)
|
|
|
|
|
|
def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
|
|
|
+ shape = torch.Size(serialized_tensor.size)
|
|
|
if serialized_tensor.dtype == "bfloat16":
|
|
|
- array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
|
|
|
- tensor = torch.as_tensor(array, dtype=torch.bfloat16)
|
|
|
+ if len(serialized_tensor.buffer) // shape.numel() == 4: # legacy mode: convert to fp32
|
|
|
+ array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
|
|
|
+ tensor = torch.as_tensor(array, dtype=torch.bfloat16)
|
|
|
+ else: # efficient mode: send bfloat16 data directly
|
|
|
+ storage_type = torch.TypedStorage if hasattr(torch, "TypedStorage") else torch._TypedStorage
|
|
|
+ storage = storage_type.from_buffer(serialized_tensor.buffer, byte_order="little", dtype=torch.bfloat16)
|
|
|
+ tensor = torch.as_tensor(storage, dtype=torch.bfloat16)
|
|
|
else:
|
|
|
array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
|
|
|
tensor = torch.as_tensor(array)
|
|
|
- return tensor.reshape(tuple(serialized_tensor.size))
|
|
|
+ return tensor.reshape(shape)
|
|
|
|
|
|
def estimate_compression_ratio(self, info: CompressionInfo) -> float:
|
|
|
return 1.0
|