فهرست منبع

Fix scalar deserialization (#190)

Fix scalar deserialization
Alexey Bukhtiyarov 4 سال پیش
والد
کامیت
bb7c4d5953
2 فایل تغییر یافته به همراه 21 افزوده شده و 4 حذف شده
  1. 12 4
      hivemind/utils/grpc.py
  2. 9 0
      tests/test_util_modules.py

+ 12 - 4
hivemind/utils/grpc.py

@@ -5,7 +5,7 @@ from __future__ import annotations
 
 import os
 import threading
-from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type, Iterator, Iterable
+from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type, Iterator, Iterable, Sequence
 
 import grpc
 import numpy as np
@@ -213,11 +213,19 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
     return proto
 
 
+def construct_torch_tensor(array: np.ndarray, size: Sequence, dtype: Optional[torch.dtype]=None):
+    """ Helper conversion function that handles edge case with scalar deserialization """
+    if size:
+        return torch.as_tensor(array, dtype=dtype).view(*size)
+    else:
+        return torch.as_tensor(array, dtype=dtype)
+
+
 def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
     # TODO avoid copying the array (need to silence pytorch warning, because array is not writable)
     if serialized_tensor.compression == CompressionType.NONE:
         array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype)).copy()
-        tensor = torch.as_tensor(array).view(*serialized_tensor.size)
+        tensor = construct_torch_tensor(array, serialized_tensor.size)
     elif serialized_tensor.compression == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
         stats_size = list(serialized_tensor.size)
         stats_size[-1] = 1
@@ -227,10 +235,10 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
         means = torch.as_tensor(np.frombuffer(means, dtype=np.float32).copy()).view(*stats_size)
         stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32).copy()).view(*stats_size)
         array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16).copy()
-        tensor = torch.as_tensor(array, dtype=torch.float32).view(*serialized_tensor.size).mul_(stds).add_(means)
+        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32).mul_(stds).add_(means)
     elif serialized_tensor.compression == CompressionType.FLOAT16:
         array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16).copy()
-        tensor = torch.as_tensor(array, dtype=torch.float32).view(*serialized_tensor.size)
+        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32)
     else:
         raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
 

+ 9 - 0
tests/test_util_modules.py

@@ -121,6 +121,7 @@ async def test_await_mpfuture():
             await future
 
 
+
 def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
     torch.manual_seed(0)
     from hivemind.proto.runtime_pb2 import CompressionType
@@ -194,6 +195,14 @@ def test_serialize_tensor():
     restored = hivemind.combine_from_streaming(chunks)
     assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)
 
+    scalar = torch.tensor(1.)
+    serialized_scalar = hivemind.serialize_torch_tensor(scalar, hivemind.CompressionType.NONE)
+    assert torch.allclose(hivemind.deserialize_torch_tensor(serialized_scalar), scalar)
+
+    serialized_scalar = hivemind.serialize_torch_tensor(scalar, hivemind.CompressionType.FLOAT16)
+    assert torch.allclose(hivemind.deserialize_torch_tensor(serialized_scalar), scalar)
+
+
 
 def test_serialize_tuple():
     test_pairs = (