Browse Source

Clean compression/__init__.py (#460)

This PR moves the functions defined in `compression/__init__.py` to `compression/serialization.py`, so the `__init__.py` contains only the imports.
Alexander Borzunov 3 years ago
parent
commit
395af50a33

+ 1 - 44
hivemind/compression/__init__.py

@@ -2,51 +2,8 @@
 Compression strategies that reduce the network communication in .averaging, .optim and .moe
 Compression strategies that reduce the network communication in .averaging, .optim and .moe
 """
 """
 
 
-import warnings
-from typing import Dict, Optional
-
-import torch
-
 from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
 from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
 from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
 from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
-from hivemind.proto import runtime_pb2
-
-warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
-
-
-BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
-    NONE=NoCompression(),
-    FLOAT16=Float16Compression(),
-    MEANSTD_16BIT=ScaledFloat16Compression(),
-    QUANTILE_8BIT=Quantile8BitQuantization(),
-    UNIFORM_8BIT=Uniform8BitQuantization(),
-)
-
-for key in runtime_pb2.CompressionType.keys():
-    assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer."
-    actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
-    assert (
-        runtime_pb2.CompressionType.Name(actual_compression_type) == key
-    ), f"Compression strategy for {key} has inconsistent type"
-
-
-def serialize_torch_tensor(
-    tensor: torch.Tensor,
-    compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
-    info: Optional[CompressionInfo] = None,
-    allow_inplace: bool = False,
-    **kwargs,
-) -> runtime_pb2.Tensor:
-    """Serialize a given tensor into a protobuf message using the specified compression strategy"""
-    assert tensor.device == torch.device("cpu")
-    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
-    info = info or CompressionInfo.from_tensor(tensor, **kwargs)
-    return compression.compress(tensor, info, allow_inplace)
-
-
-def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-    """Restore a pytorch tensor from a protobuf message"""
-    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
-    return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
+from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor

+ 5 - 0
hivemind/compression/base.py

@@ -1,4 +1,5 @@
 import dataclasses
 import dataclasses
+import warnings
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from enum import Enum, auto
 from enum import Enum, auto
 from typing import Any, Optional
 from typing import Any, Optional
@@ -9,6 +10,10 @@ import torch
 from hivemind.proto import runtime_pb2
 from hivemind.proto import runtime_pb2
 from hivemind.utils.tensor_descr import TensorDescriptor
 from hivemind.utils.tensor_descr import TensorDescriptor
 
 
+# While converting read-only NumPy arrays into PyTorch tensors, we don't make extra copies for efficiency
+warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
+
+
 Key = Any
 Key = Any
 
 
 
 

+ 43 - 0
hivemind/compression/serialization.py

@@ -0,0 +1,43 @@
+from typing import Dict, Optional
+
+import torch
+
+from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression
+from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
+from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.proto import runtime_pb2
+
# Registry of every compression strategy, keyed by the name of its
# runtime_pb2.CompressionType enum member.
BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = {
    "NONE": NoCompression(),
    "FLOAT16": Float16Compression(),
    "MEANSTD_16BIT": ScaledFloat16Compression(),
    "QUANTILE_8BIT": Quantile8BitQuantization(),
    "UNIFORM_8BIT": Uniform8BitQuantization(),
}

# Import-time sanity check: every protobuf compression type must have a
# registered strategy, and each strategy must report the exact enum value
# it is registered under (catches copy-paste mistakes in the table above).
for key in runtime_pb2.CompressionType.keys():
    assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer"
    actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
    assert (
        runtime_pb2.CompressionType.Name(actual_compression_type) == key
    ), f"Compression strategy for {key} has inconsistent type"
+
+
def serialize_torch_tensor(
    tensor: torch.Tensor,
    compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
    info: Optional[CompressionInfo] = None,
    allow_inplace: bool = False,
    **kwargs,
) -> runtime_pb2.Tensor:
    """
    Serialize a tensor into a protobuf message using the chosen compression strategy.

    :param tensor: a CPU tensor to serialize (tensors on other devices are rejected)
    :param compression_type: protobuf enum value selecting the compression strategy
    :param info: optional precomputed tensor metadata; built from the tensor (with **kwargs) when not given
    :param allow_inplace: if True, the strategy is permitted to modify the input tensor in place
    :returns: a runtime_pb2.Tensor protobuf message holding the compressed payload
    """
    assert tensor.device == torch.device("cpu")
    strategy_name = runtime_pb2.CompressionType.Name(compression_type)
    strategy = BASE_COMPRESSION_TYPES[strategy_name]
    # NOTE: a falsy-but-non-None info would also be replaced here, matching the original `or` idiom
    info = info or CompressionInfo.from_tensor(tensor, **kwargs)
    return strategy.compress(tensor, info, allow_inplace)
+
+
def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
    """
    Reconstruct a pytorch tensor from a protobuf message.

    :param serialized_tensor: a runtime_pb2.Tensor produced by :func:`serialize_torch_tensor`
    :returns: the restored tensor, with requires_grad copied from the message
    """
    strategy_name = runtime_pb2.CompressionType.Name(serialized_tensor.compression)
    restored = BASE_COMPRESSION_TYPES[strategy_name].extract(serialized_tensor)
    return restored.requires_grad_(serialized_tensor.requires_grad)