浏览代码

Merge branch 'master' of github.com:learning-at-home/hivemind into power_ef_new

Artem Chumachenko 3 年之前
父节点
当前提交
08dd8572a2
共有 3 个文件被更改,包括 49 次插入和 44 次删除
  1. 1 44
      hivemind/compression/__init__.py
  2. 5 0
      hivemind/compression/base.py
  3. 43 0
      hivemind/compression/serialization.py

+ 1 - 44
hivemind/compression/__init__.py

@@ -2,51 +2,8 @@
 Compression strategies that reduce the network communication in .averaging, .optim and .moe
 """
 
-import warnings
-from typing import Dict, Optional
-
-import torch
-
 from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
 from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
-from hivemind.proto import runtime_pb2
-
-warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
-
-
-BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
-    NONE=NoCompression(),
-    FLOAT16=Float16Compression(),
-    MEANSTD_16BIT=ScaledFloat16Compression(),
-    QUANTILE_8BIT=Quantile8BitQuantization(),
-    UNIFORM_8BIT=Uniform8BitQuantization(),
-)
-
-for key in runtime_pb2.CompressionType.keys():
-    assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer."
-    actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
-    assert (
-        runtime_pb2.CompressionType.Name(actual_compression_type) == key
-    ), f"Compression strategy for {key} has inconsistent type"
-
-
-def serialize_torch_tensor(
-    tensor: torch.Tensor,
-    compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
-    info: Optional[CompressionInfo] = None,
-    allow_inplace: bool = False,
-    **kwargs,
-) -> runtime_pb2.Tensor:
-    """Serialize a given tensor into a protobuf message using the specified compression strategy"""
-    assert tensor.device == torch.device("cpu")
-    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
-    info = info or CompressionInfo.from_tensor(tensor, **kwargs)
-    return compression.compress(tensor, info, allow_inplace)
-
-
-def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-    """Restore a pytorch tensor from a protobuf message"""
-    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
-    return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
+from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor

+ 5 - 0
hivemind/compression/base.py

@@ -1,4 +1,5 @@
 import dataclasses
+import warnings
 from abc import ABC, abstractmethod
 from enum import Enum, auto
 from typing import Any, Optional
@@ -9,6 +10,10 @@ import torch
 from hivemind.proto import runtime_pb2
 from hivemind.utils.tensor_descr import TensorDescriptor
 
+# While converting read-only NumPy arrays into PyTorch tensors, we don't make extra copies for efficiency
+warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
+
+
 Key = Any
 
 

+ 43 - 0
hivemind/compression/serialization.py

@@ -0,0 +1,43 @@
+from typing import Dict, Optional
+
+import torch
+
+from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression
+from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
+from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.proto import runtime_pb2
+
+BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
+    NONE=NoCompression(),
+    FLOAT16=Float16Compression(),
+    MEANSTD_16BIT=ScaledFloat16Compression(),
+    QUANTILE_8BIT=Quantile8BitQuantization(),
+    UNIFORM_8BIT=Uniform8BitQuantization(),
+)
+
+for key in runtime_pb2.CompressionType.keys():
+    assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer"
+    actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
+    assert (
+        runtime_pb2.CompressionType.Name(actual_compression_type) == key
+    ), f"Compression strategy for {key} has inconsistent type"
+
+
+def serialize_torch_tensor(
+    tensor: torch.Tensor,
+    compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+    info: Optional[CompressionInfo] = None,
+    allow_inplace: bool = False,
+    **kwargs,
+) -> runtime_pb2.Tensor:
+    """Serialize a given tensor into a protobuf message using the specified compression strategy"""
+    assert tensor.device == torch.device("cpu")
+    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
+    info = info or CompressionInfo.from_tensor(tensor, **kwargs)
+    return compression.compress(tensor, info, allow_inplace)
+
+
+def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+    """Restore a pytorch tensor from a protobuf message"""
+    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
+    return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)