@@ -32,6 +32,20 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
             size=tensor.shape,
             dtype='compressed_float32',
             requires_grad=tensor.requires_grad)
+    elif compression_type == CompressionType.FLOAT16:
+        assert tensor.dtype == torch.float32
+
+        tensor = tensor if allow_inplace else tensor.clone()
+        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
+
+        data = tensor.numpy().tobytes()
+
+        proto = runtime_pb2.Tensor(
+            compression=compression_type,
+            buffer=data,
+            size=tensor.shape,
+            dtype='clamped_float32',
+            requires_grad=tensor.requires_grad)
     else:
         array = tensor.numpy()
         proto = runtime_pb2.Tensor(
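For illustration, here is a standalone sketch of the sender-side path this hunk adds. It does not call the serialization API itself, just the same torch steps; the FP16_MAX value below is an assumption, taken as torch.finfo(torch.float16).max (65504.0).

import torch

FP16_MAX = torch.finfo(torch.float16).max  # assumed to match the module-level constant

tensor = torch.randn(4, 8) * 10  # float32 input, as the assert above requires
tensor = tensor.clone().clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
buffer = tensor.numpy().tobytes()

print(len(buffer))  # 4 * 8 * 2 = 64 bytes, half the size of the raw float32 payload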
@@ -58,6 +72,10 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
         stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32)).view(*stats_size)
         array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16)
         tensor = torch.as_tensor(array).to(torch.float32).view(*serialized_tensor.size).mul_(stds).add_(means)
+    elif serialized_tensor.compression == CompressionType.FLOAT16:
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16).copy()
+        tensor = torch.as_tensor(array).view(*serialized_tensor.size)\
+            .to(torch.float32).requires_grad_(serialized_tensor.requires_grad)
     else:
         raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
     return tensor
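And a round-trip sketch covering both new branches, again written as plain torch/numpy rather than the serialize/deserialize functions themselves; the tensor shape and FP16_MAX value are illustrative assumptions.

import numpy as np
import torch

FP16_MAX = torch.finfo(torch.float16).max  # assumed to match the module-level constant

original = torch.randn(4, 8) * 10
# sender side: clamp to the representable float16 range, cast down, ship raw bytes
buffer = original.clone().clamp_(-FP16_MAX, FP16_MAX).to(torch.float16).numpy().tobytes()

# receiver side: np.frombuffer returns a read-only view, hence the .copy() as in the diff,
# which avoids PyTorch's warning about wrapping non-writable numpy arrays
array = np.frombuffer(buffer, dtype=np.float16).copy()
restored = torch.as_tensor(array).view(*original.shape).to(torch.float32)

print(torch.max(torch.abs(restored - original)))  # small but nonzero: float16 is lossy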