|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|
|
|
|
|
import os
|
|
|
import threading
|
|
|
-from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type, Iterator, Iterable
|
|
|
+from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type, Iterator, Iterable, Sequence
|
|
|
|
|
|
import grpc
|
|
|
import numpy as np
|
|
@@ -213,11 +213,19 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
|
|
|
return proto
|
|
|
|
|
|
|
|
|
def construct_torch_tensor(array: np.ndarray, size: Sequence, dtype: Optional[torch.dtype] = None):
    """
    Convert a numpy array into a torch tensor of the requested shape, including scalars.

    :param array: numpy array holding the tensor elements
    :param size: target shape; an empty sequence denotes a zero-dimensional (scalar) tensor
    :param dtype: optional torch dtype forwarded to ``torch.as_tensor``; ``None`` keeps the
        dtype inferred from ``array``
    :returns: the tensor viewed with shape ``tuple(size)``
    :raises RuntimeError: if the number of elements in ``array`` does not match ``size``
    """
    # The shape is passed as one tuple instead of being unpacked: ``view()`` with no
    # arguments is an error, while ``view(())`` yields a proper 0-dim tensor, so the
    # scalar case keeps its original shape rather than coming back as shape (1,).
    return torch.as_tensor(array, dtype=dtype).view(tuple(size))
|
|
|
+
|
|
|
+
|
|
|
def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
|
|
|
# TODO avoid copying the array (need to silence pytorch warning, because array is not writable)
|
|
|
if serialized_tensor.compression == CompressionType.NONE:
|
|
|
array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype)).copy()
|
|
|
- tensor = torch.as_tensor(array).view(*serialized_tensor.size)
|
|
|
+ tensor = construct_torch_tensor(array, serialized_tensor.size)
|
|
|
elif serialized_tensor.compression == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
|
|
|
stats_size = list(serialized_tensor.size)
|
|
|
stats_size[-1] = 1
|
|
@@ -227,10 +235,10 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
|
|
|
means = torch.as_tensor(np.frombuffer(means, dtype=np.float32).copy()).view(*stats_size)
|
|
|
stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32).copy()).view(*stats_size)
|
|
|
array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16).copy()
|
|
|
- tensor = torch.as_tensor(array, dtype=torch.float32).view(*serialized_tensor.size).mul_(stds).add_(means)
|
|
|
+ tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32).mul_(stds).add_(means)
|
|
|
elif serialized_tensor.compression == CompressionType.FLOAT16:
|
|
|
array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16).copy()
|
|
|
- tensor = torch.as_tensor(array, dtype=torch.float32).view(*serialized_tensor.size)
|
|
|
+ tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32)
|
|
|
else:
|
|
|
raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
|
|
|
|