|
@@ -87,12 +87,11 @@ class NoCompression(CompressionBase):
|
|
|
dtype_name = str(tensor.dtype).lstrip("torch.")
|
|
|
raw_data = tensor
|
|
|
if tensor.dtype == torch.bfloat16:
|
|
|
- if USE_LEGACY_BFLOAT16:
|
|
|
+ if USE_LEGACY_BFLOAT16: # legacy mode: convert to fp32
|
|
|
raw_data = tensor.to(torch.float32)
|
|
|
- else:
|
|
|
- typed_storage = tensor.storage()
|
|
|
- storage = typed_storage.untyped() if hasattr(typed_storage, "untyped") else typed_storage._untyped()
|
|
|
- raw_data = torch.tensor(storage, dtype=torch.int8)
|
|
|
+ else: # efficient mode: send bfloat16 data directly
|
|
|
+ # reinterpret_cast to an arbitrary 2-byte type supported by numpy
|
|
|
+ raw_data = tensor.view(torch.int16)
|
|
|
|
|
|
return runtime_pb2.Tensor(
|
|
|
compression=self.compression_type,
|
|
@@ -106,13 +105,13 @@ class NoCompression(CompressionBase):
|
|
|
shape = torch.Size(serialized_tensor.size)
|
|
|
if serialized_tensor.dtype == "bfloat16":
|
|
|
numel = shape.numel()
|
|
|
- if numel > 0 and len(serialized_tensor.buffer) // numel == 4: # legacy mode: convert to fp32
|
|
|
+ if numel > 0 and len(serialized_tensor.buffer) // numel == 4:
|
|
|
array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
|
|
|
tensor = torch.as_tensor(array, dtype=torch.bfloat16)
|
|
|
- else: # efficient mode: send bfloat16 data directly
|
|
|
- storage_type = torch.TypedStorage if hasattr(torch, "TypedStorage") else torch._TypedStorage
|
|
|
- storage = storage_type.from_buffer(serialized_tensor.buffer, byte_order="little", dtype=torch.bfloat16)
|
|
|
- tensor = torch.as_tensor(storage, dtype=torch.bfloat16)
|
|
|
+ else:
|
|
|
+ array = np.frombuffer(serialized_tensor.buffer, dtype=np.int16)
|
|
|
+ # reinterpret_cast from an arbitrary 2-byte type supported by numpy
|
|
|
+ tensor = torch.as_tensor(array).view(torch.bfloat16)
|
|
|
else:
|
|
|
array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
|
|
|
tensor = torch.as_tensor(array)
|