
Tensor compression (part1) (#102)

* Implemented tensor compression

* Test compression argument passing

* Fixed naming error

* Fixed argument error

* Fixed gradient error

* Implemented tensor compression in RemoteExpert

* Fixed typo

* Fixed TypeError: 'generator' object is not subscriptable

* Test torch error fix

* Implemented tensor compression in connection handler (server response)

* Fixed error

* CircleCI fix

* Fixed order

* Removed unimplemented compression type

* missing \n

* TODO

* use iterators & change schema

* more reasonable schema names

* typo

* add compression to MoE.py

* Implemented more efficient vector compression

* Fixed errors in deserialize_tensors

* Created test for vector compression

* Fixed dtype error in deserialize_tensor

* Fixed error in serialize_tensor

* Fixed error in deserialize_tensor

* Removed typo

* Fixed TypeError in reshape in deserialize_torch_tensor

* Fixed wrong shape in deserialize_torch_tensor

* Fixed error on changing tensor shape

* Experimentally determined alpha for test_vector_compression

* Update hivemind/utils/grpc.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

* Update hivemind/utils/grpc.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

* Update hivemind/utils/grpc.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

* Update hivemind/utils/grpc.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

* Update hivemind/utils/grpc.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

* Changed dtype of compressed float32 tensor to compressed_float32

* Update hivemind/server/connection_handler.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

* Update hivemind/server/connection_handler.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

* compressed_tensor -> serialized_tensor

* Incremented version

Co-authored-by: justheuristic <justheuristic@gmail.com>
Vsevolod-pl 4 years ago
Parent
Commit
d4d9da9d3e

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.3'
+__version__ = '0.8.4'

+ 15 - 9
hivemind/client/expert.py

@@ -1,6 +1,6 @@
 import pickle
 from functools import lru_cache
-from typing import Tuple, Optional, Any
+from typing import Tuple, Optional, Any, Dict
 
 import grpc
 import grpc.experimental.aio
@@ -9,6 +9,7 @@ import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
 from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor
 
@@ -61,7 +62,7 @@ class RemoteExpert(nn.Module):
         if not nested_compare(forward_inputs, self.info['forward_schema']):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
-        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, *nested_flatten(forward_inputs))
+        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
         return nested_pack(flat_outputs, structure=self.info['outputs_schema'])
 
@@ -81,14 +82,17 @@ class _RemoteModuleCall(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, dummy: torch.Tensor, uid: str, stub: runtime_grpc.ConnectionHandlerStub,
-                *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+                info: Dict[str, Any], *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
         inputs = tuple(map(torch.Tensor.detach, inputs))  # detach to avoid pickling the computation graph
-        ctx.uid, ctx.stub = uid, stub
+        ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
 
+        serialized_tensors = [serialize_torch_tensor(inp, proto.compression)
+                              for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))]
+
         outputs = stub.forward(
-            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
+            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
 
@@ -97,10 +101,12 @@ class _RemoteModuleCall(torch.autograd.Function):
     @staticmethod
     @once_differentiable
     def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
-        payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
+        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
+        backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
+        serialized_tensors = [serialize_torch_tensor(tensor, proto.compression)
+                              for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
 
-        grad_inputs = ctx.stub.backward(
-            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
+        grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
 
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
-        return (DUMMY, None, None, *deserialized_grad_inputs)
+        return (DUMMY, None, None, None, *deserialized_grad_inputs)
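
A minimal client-side sketch of the new contract (assuming a connected RemoteExpert instance called expert; the input shape is illustrative): each input tensor is serialized with the compression type declared for its slot in the expert's forward schema, and since info is now an extra argument to the autograd function, backward returns one additional None.

    import torch
    from hivemind.utils import nested_flatten
    from hivemind.utils.grpc import serialize_torch_tensor

    inputs = (torch.randn(16, 1024),)  # illustrative batch; must match the expert's forward schema
    schema = nested_flatten(expert.info['forward_schema'])
    serialized = [serialize_torch_tensor(tensor, proto.compression)
                  for tensor, proto in zip(inputs, schema)]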

+ 28 - 23
hivemind/client/moe.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import time
-from typing import Tuple, List, Optional, Awaitable, Set, Dict
+from typing import Tuple, List, Optional, Awaitable, Set, Dict, Any
 
 import grpc.experimental.aio
 import torch
@@ -58,7 +58,7 @@ class RemoteMixtureOfExperts(nn.Module):
         self.allow_broadcasting = allow_broadcasting
 
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
-        self._outputs_schema = None  # expert['info'][outputs_schema] from one of experts in the grid
+        self._expert_info = None  # expert['info'] from one of experts in the grid
 
     def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
@@ -88,8 +88,8 @@ class RemoteMixtureOfExperts(nn.Module):
         # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
-            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min,
-            self.forward_timeout, self.backward_timeout, self.loop, *nested_flatten(((input, *args), kwargs)))
+            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
+            self.backward_timeout, self.loop, self.info, *nested_flatten(((input, *args), kwargs)))
         # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
 
         expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
@@ -99,7 +99,7 @@ class RemoteMixtureOfExperts(nn.Module):
         averaged_outputs_flat = [
             (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
             for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
-        return nested_pack(averaged_outputs_flat, self.outputs_schema)
+        return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
 
     async def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[RemoteExpert]:
         """
@@ -139,9 +139,9 @@ class RemoteMixtureOfExperts(nn.Module):
             beam_scores = expanded_scores[tuple(zip(*map(candidate_to_indices.get, beam)))]
             beam_experts = list(best_alive_prefixes.values())
 
-        if self._outputs_schema is None:
+        if self._expert_info is None:
             try:
-                self._outputs_schema = beam_experts[0].info['outputs_schema']
+                self._expert_info = beam_experts[0].info
             except grpc.RpcError as e:
                 logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
 
@@ -182,15 +182,15 @@ class RemoteMixtureOfExperts(nn.Module):
         return scores
 
     @property
-    def outputs_schema(self):
-        if self._outputs_schema is None:
+    def info(self):
+        if self._expert_info is None:
             # grab some expert to set ensemble output shape
             proj_device = self.proj.weight.device
             dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
             dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.grid_size, dim=-1)
             dummy_experts = self.loop.run_until_complete(self.beam_search(dummy_scores, k_best=1))
-            self._outputs_schema = dummy_experts[0].info['outputs_schema']
-        return self._outputs_schema
+            self._expert_info = dummy_experts[0].info
+        return self._expert_info
 
 
 class _RemoteCallMany(torch.autograd.Function):
@@ -206,7 +206,7 @@ class _RemoteCallMany(torch.autograd.Function):
     @classmethod
     def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
                 timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
-                loop: asyncio.base_events.BaseEventLoop, *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+                loop: asyncio.base_events.BaseEventLoop, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
         assert not torch.is_grad_enabled()
         num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
         flat_inputs_per_sample: List[Tuple[torch.Tensor, ...]] = list(zip(*(x.split(1, dim=0) for x in flat_inputs)))
@@ -215,7 +215,7 @@ class _RemoteCallMany(torch.autograd.Function):
         async def _forward():
             # dispatch tasks to all remote experts, await responses
             pending_tasks = {
-                asyncio.create_task(cls._forward_one_expert((i, j), expert, flat_inputs_per_sample[i]))
+                asyncio.create_task(cls._forward_one_expert((i, j), expert, info, flat_inputs_per_sample[i]))
                 for i in range(num_samples) for j, expert in enumerate(experts_per_sample[i])
             }
             alive_grid_indices, alive_flat_outputs = await cls._wait_for_responses(
@@ -239,7 +239,8 @@ class _RemoteCallMany(torch.autograd.Function):
 
             # save individual outputs for backward pass
             ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs)
-            ctx._saved_non_tensors = loop, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample
+            ctx._saved_non_tensors = loop, info, backward_k_min, backward_timeout,\
+                                     timeout_after_k_min, experts_per_sample
             return (mask,) + tuple(outputs)
 
         return loop.run_until_complete(_forward())
@@ -248,7 +249,7 @@ class _RemoteCallMany(torch.autograd.Function):
     @once_differentiable
     def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
         assert not torch.is_grad_enabled()
-        loop, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
+        loop, info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
         alive_ii, alive_jj, *flat_inputs = ctx.saved_tensors
         dummy_grad_mask, *flat_grad_outputs = raw_grads
         num_samples, max_experts = dummy_grad_mask.shape
@@ -261,8 +262,8 @@ class _RemoteCallMany(torch.autograd.Function):
             pending_tasks = set()
             for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
                                                         inputs_per_expert, grad_outputs_per_expert):
-                pending_tasks.add(asyncio.create_task(
-                    cls._backward_one_expert((i, j), expert_per_sample[i.item()][j.item()], inputs_ij, grad_outputs_ij)))
+                pending_tasks.add(asyncio.create_task(cls._backward_one_expert(
+                    (i, j), expert_per_sample[i.item()][j.item()], info, inputs_ij, grad_outputs_ij)))
 
             backward_survivor_indices, survivor_grad_inputs = await cls._wait_for_responses(
                 pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
@@ -281,28 +282,32 @@ class _RemoteCallMany(torch.autograd.Function):
 
                 grad_inputs.append(grad_input_per_expert.sum(dim=1))  # add up gradients from each expert
 
-            return (DUMMY, None, None, None, None, None, None, None, *grad_inputs)
+            return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
 
         return loop.run_until_complete(_backward())
 
     @staticmethod
-    async def _forward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, inputs: Tuple[torch.Tensor]):
+    async def _forward_one_expert(
+            grid_indices: Tuple[int, ...], expert: RemoteExpert, info: Dict[str, Any], inputs: Tuple[torch.Tensor]):
         stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
         try:
             outputs = await stub.forward(runtime_pb2.ExpertRequest(
-                uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
+                uid=expert.uid, tensors=[serialize_torch_tensor(tensor, proto.compression) for tensor, proto in 
+                                         zip(inputs, nested_flatten(info['forward_schema']))]))
             return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in outputs.tensors)
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})")
 
     @staticmethod
-    async def _backward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert,
+    async def _backward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, info: Dict[str, Any],
                                    inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor]):
         stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
-        payload = tuple(nested_flatten((inputs, grad_outputs)))
+        inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs)))
+        backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))
         try:
             grad_inputs = await stub.backward(runtime_pb2.ExpertRequest(
-                uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
+                uid=expert.uid, tensors=[serialize_torch_tensor(tensor, proto.compression)
+                                         for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]))
             return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors)
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"RemoteExpert {expert} failed backward: {error.code()} ({inputs}, {grad_outputs})")

+ 6 - 0
hivemind/proto/runtime.proto

@@ -26,10 +26,16 @@ message ExpertResponse {
   repeated Tensor tensors = 2;
 }
 
+enum CompressionType{
+  NONE = 0;
+  MEANSTD_LAST_AXIS_FLOAT16 = 1;
+}
+
 message Tensor {
   bytes buffer = 1;
   repeated uint32 size = 2;
   bool requires_grad = 3;
   string dtype = 4;
+  CompressionType compression = 5;
 }
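
After regenerating the protobuf bindings, the new enum is available from the generated Python module; a short sketch of what this means on the wire:

    from hivemind.proto.runtime_pb2 import CompressionType

    assert CompressionType.NONE == 0
    assert CompressionType.MEANSTD_LAST_AXIS_FLOAT16 == 1
    # proto3 enums default to the zero value, so Tensor messages from senders
    # that never set `compression` are still interpreted as uncompressed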
 

+ 6 - 3
hivemind/server/connection_handler.py

@@ -10,7 +10,7 @@ import uvloop
 
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.server.expert_backend import ExpertBackend
-from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint
+from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint, nested_flatten
 
 logger = get_logger(__name__)
 
@@ -60,11 +60,14 @@ class ConnectionHandler(mp.Process):
     async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
-        serialized_response = [serialize_torch_tensor(tensor) for tensor in await future]
+        serialized_response = [serialize_torch_tensor(tensor, proto.compression, allow_inplace=True) for tensor, proto
+                               in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))]
+
         return runtime_pb2.ExpertResponse(tensors=serialized_response)
 
     async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
-        serialized_response = [serialize_torch_tensor(tensor) for tensor in await future]
+        serialized_response = [serialize_torch_tensor(tensor, proto.compression, allow_inplace=True) for tensor, proto
+                               in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))]
         return runtime_pb2.ExpertResponse(tensors=serialized_response)
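
Note that the server serializes its responses with allow_inplace=True: with the implementation in hivemind/utils/grpc.py below, the mean-std path then normalizes the output tensor in place instead of cloning it, which is presumably acceptable here because the response tensors are not used again after serialization. A small sketch of the difference (shapes and values are illustrative):

    import torch
    from hivemind.proto.runtime_pb2 import CompressionType
    from hivemind.utils.grpc import serialize_torch_tensor

    x = torch.randn(4, 16)
    serialize_torch_tensor(x, CompressionType.MEANSTD_LAST_AXIS_FLOAT16, allow_inplace=False)
    # x is untouched: the compressor worked on a clone
    serialize_torch_tensor(x, CompressionType.MEANSTD_LAST_AXIS_FLOAT16, allow_inplace=True)
    # x has now been mean-subtracted and divided by its per-row std in place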

+ 5 - 3
hivemind/server/expert_backend.py

@@ -53,9 +53,11 @@ class ExpertBackend(nn.Module):
             dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
             outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)
 
-        self.outputs_schema = outputs_schema
-        self.forward_schema = (self.args_schema, self.kwargs_schema)
-        self.backward_schema = (self.forward_schema, self.outputs_schema)  # original inputs and grad w.r.t. outputs
+        self.forward_schema = (self.args_schema, self.kwargs_schema)  # inputs for forward
+        self.outputs_schema = outputs_schema  # outputs from forward
+
+        self.backward_schema = (self.forward_schema, self.outputs_schema)  # inputs to backward
+        self.grad_inputs_schema = self.forward_schema  # outputs from backward
         self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
         self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
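
For a hypothetical expert with one positional input tensor and one output tensor, the four schemas relate as follows (a sketch; the descriptor sizes are made up, and in ExpertBackend the descriptors actually come from the user-supplied args/kwargs schemas and a dummy forward pass):

    from hivemind.utils.tensor_descr import BatchTensorDescriptor

    x_descr = BatchTensorDescriptor(1024)  # hypothetical input descriptor
    y_descr = BatchTensorDescriptor(512)   # hypothetical output descriptor

    forward_schema = ((x_descr,), {})                    # inputs to forward: (args, kwargs)
    outputs_schema = y_descr                             # outputs from forward
    backward_schema = (forward_schema, outputs_schema)   # inputs to backward
    grad_inputs_schema = forward_schema                  # outputs from backward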
 

+ 49 - 10
hivemind/utils/grpc.py

@@ -6,19 +6,58 @@ import numpy as np
 import torch
 
 from hivemind.proto import runtime_pb2
+from hivemind.proto.runtime_pb2 import CompressionType
 
+FP16_MAX = 65_504
+
+
+def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE, 
+                           allow_inplace=False) -> runtime_pb2.Tensor:
+    if compression_type == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
+        assert tensor.dtype == torch.float32
+
+        tensor = tensor if allow_inplace else tensor.clone()
+        means = torch.mean(tensor, dim=-1, keepdim=True)
+        tensor.sub_(means)
+
+        stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
+        tensor.div_(stds)
+        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
+
+        data = b''.join((tensor.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))
+
+        proto = runtime_pb2.Tensor(
+            compression=compression_type,
+            buffer=data,
+            size=tensor.shape,
+            dtype='compressed_float32',
+            requires_grad=tensor.requires_grad)
+    else:
+        array = tensor.numpy()
+        proto = runtime_pb2.Tensor(
+            compression=compression_type,
+            buffer=array.tobytes(),
+            size=array.shape,
+            dtype=array.dtype.name,
+            requires_grad=tensor.requires_grad)
 
-def serialize_torch_tensor(tensor: torch.Tensor) -> runtime_pb2.Tensor:
-    array = tensor.numpy()
-    proto = runtime_pb2.Tensor(
-        buffer=array.tobytes(),
-        size=array.shape,
-        dtype=array.dtype.name,
-        requires_grad=tensor.requires_grad)
     return proto
 
 
-def deserialize_torch_tensor(tensor: runtime_pb2.Tensor) -> torch.Tensor:
+def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
     # TODO avoid copying the array (need to silence pytorch warning, because array is not writable)
-    array = np.frombuffer(tensor.buffer, dtype=np.dtype(tensor.dtype)).copy()
-    return torch.as_tensor(array).view(tuple(tensor.size)).requires_grad_(tensor.requires_grad)
+    if serialized_tensor.compression == CompressionType.NONE:
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype)).copy()
+        tensor = torch.as_tensor(array).view(*serialized_tensor.size).requires_grad_(serialized_tensor.requires_grad)
+    elif serialized_tensor.compression == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
+        stats_size = list(serialized_tensor.size)
+        stats_size[-1] = 1
+        stats_count = np.prod(stats_size)
+        means, stds = serialized_tensor.buffer[-8*stats_count:-4*stats_count], serialized_tensor.buffer[-4*stats_count:]
+        means = torch.as_tensor(np.frombuffer(means, dtype=np.float32)).view(*stats_size)
+        stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32)).view(*stats_size)
+        array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16)
+        tensor = torch.as_tensor(array).to(torch.float32).view(*serialized_tensor.size).mul_(stds).add_(means)
+    else:
+        raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
+    return tensor
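
The compressed buffer thus stores the float16 payload first, followed by the float32 per-row means and then the float32 per-row stds taken over the last axis. A standalone sketch of the same round trip in plain numpy/torch (clamping to the float16 range is omitted for brevity):

    import numpy as np
    import torch

    x = torch.randn(4, 1024)
    means = x.mean(dim=-1, keepdim=True)
    stds = (x - means).pow(2).mean(dim=-1, keepdim=True).sqrt()
    payload = ((x - means) / stds).to(torch.float16)
    buffer = b''.join((payload.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))

    # decode: the last 8*n bytes hold the float32 means followed by the float32 stds (n = number of rows)
    n = means.numel()
    means2 = np.frombuffer(buffer[-8 * n:-4 * n], dtype=np.float32).reshape(4, 1)
    stds2 = np.frombuffer(buffer[-4 * n:], dtype=np.float32).reshape(4, 1)
    values = np.frombuffer(buffer[:-8 * n], dtype=np.float16).astype(np.float32).reshape(4, 1024)
    restored = torch.as_tensor(values * stds2 + means2)
    assert torch.allclose(restored, x, rtol=1e-2, atol=1e-2)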

+ 3 - 0
hivemind/utils/tensor_descr.py

@@ -2,6 +2,8 @@ from dataclasses import dataclass, asdict
 
 import torch
 
+from hivemind.proto.runtime_pb2 import CompressionType
+
 DUMMY_BATCH_SIZE = 3  # used for dummy runs only
 
 
@@ -18,6 +20,7 @@ class TensorDescriptor(DescriptorBase):
     device: torch.device = None
     requires_grad: bool = False
     pin_memory: bool = False
+    compression: CompressionType = CompressionType.NONE
 
     @property
     def shape(self):
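
The new field is how individual tensors opt into compression: every serialization site above reads descriptor.compression for its slot of the flattened schema. By default nothing changes, as in this sketch:

    import torch
    from hivemind.proto.runtime_pb2 import CompressionType
    from hivemind.utils.tensor_descr import BatchTensorDescriptor

    descr = BatchTensorDescriptor.from_tensor(torch.randn(3, 1024))
    assert descr.compression == CompressionType.NONE  # default: this tensor is sent uncompressed
    # descriptors created with compression=MEANSTD_LAST_AXIS_FLOAT16 mark the matching
    # tensor for mean-std float16 compression in the code paths above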

+ 1 - 1
tests/test_moe.py

@@ -45,7 +45,7 @@ def test_call_many():
         mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
             DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
             k_min, backward_k_min, timeout_after_k_min, forward_timeout, backward_timeout,
-            asyncio.new_event_loop(), inputs
+            asyncio.new_event_loop(), e1.info, inputs
         )
         assert mask.shape == (4, 3)
         assert expert_outputs.shape == (4, 3, 64)

+ 16 - 0
tests/test_util_modules.py

@@ -1,4 +1,5 @@
 import asyncio
+import torch
 
 import pytest
 import hivemind
@@ -81,6 +82,7 @@ def test_await_mpfuture():
     async def _run():
         # await result
         f1, f2 = hivemind.MPFuture.make_pair()
+
         async def wait_and_assign():
             assert f2.set_running_or_notify_cancel() is True
             await asyncio.sleep(0.1)
@@ -93,6 +95,7 @@ def test_await_mpfuture():
 
         # await cancel
         f1, f2 = hivemind.MPFuture.make_pair()
+
         async def wait_and_cancel():
             await asyncio.sleep(0.1)
             f1.cancel()
@@ -104,6 +107,7 @@ def test_await_mpfuture():
 
         # await exception
         f1, f2 = hivemind.MPFuture.make_pair()
+
         async def wait_and_raise():
             await asyncio.sleep(0.1)
             f1.set_exception(SystemError())
@@ -114,3 +118,15 @@ def test_await_mpfuture():
                 await future
 
     asyncio.new_event_loop().run_until_complete(_run())
+
+
+def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
+    torch.manual_seed(0)
+    from hivemind.proto.runtime_pb2 import CompressionType
+    from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor
+    X = torch.randn(*size)
+    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_LAST_AXIS_FLOAT16))-X
+    assert error.square().mean() < alpha
+    return error.square().mean()
+