Browse Source

Tensor compression (part1) (#102)

* Implemented tensor compression

* Test compression argument passing

* Fixed naming error

* Fixed argument error

* Fixed gradient error

* Implemented tensor compression in RemoteExpert

* Fixed a typo

* Fixed TypeError: 'generator' object is not subscriptable

* Test torch error fix

* Implemented tensor compression in connection handler (server response)

* Fixed error

* CircleCI fix

* Fixed order

* Removed not implemented compression type

* missing \n

* TODO

* use iterators & change schema

* more reasonable schema names

* typo

* add compression to MoE.py

* Implemented more efficient vector compression

* Fixed errors in deserialize_tensors

* Created test for vector compression

* Fixed dtype error in deserialize_tensor

* Fixed error in serialize_tensor

* Fixed error in deserialize_tensor

* Removed a typo

* Fixed TypeError in reshape in deserialize_torch_tensor

* Fixed wrong shape in deserialize_torch_tensor

* Fixed error on changing tensor shape

* Experimentally found out the alpha for test_vector_compression

* Update hivemind/utils/grpc.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

* Update hivemind/utils/grpc.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

* Update hivemind/utils/grpc.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

* Update hivemind/utils/grpc.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

* Update hivemind/utils/grpc.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

* Changed dtype of compressed float32 tensor to compressed_float32

* Update hivemind/server/connection_handler.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

* Update hivemind/server/connection_handler.py

Co-authored-by: justheuristic <justheuristic@gmail.com>

* compressed_tensor -> serialized_tensor

* Incremented version

Co-authored-by: justheuristic <justheuristic@gmail.com>
Vsevolod-pl 4 years ago
parent
commit
d4d9da9d3e

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.server import *
 from hivemind.utils import *
 from hivemind.utils import *
 
 
-__version__ = '0.8.3'
+__version__ = '0.8.4'

+ 15 - 9
hivemind/client/expert.py

@@ -1,6 +1,6 @@
 import pickle
 import pickle
 from functools import lru_cache
 from functools import lru_cache
-from typing import Tuple, Optional, Any
+from typing import Tuple, Optional, Any, Dict
 
 
 import grpc
 import grpc
 import grpc.experimental.aio
 import grpc.experimental.aio
@@ -9,6 +9,7 @@ import torch.nn as nn
 from torch.autograd.function import once_differentiable
 from torch.autograd.function import once_differentiable
 
 
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
 from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
 from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor
 
 
@@ -61,7 +62,7 @@ class RemoteExpert(nn.Module):
         if not nested_compare(forward_inputs, self.info['forward_schema']):
         if not nested_compare(forward_inputs, self.info['forward_schema']):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
 
-        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, *nested_flatten(forward_inputs))
+        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
         return nested_pack(flat_outputs, structure=self.info['outputs_schema'])
         return nested_pack(flat_outputs, structure=self.info['outputs_schema'])
 
 
@@ -81,14 +82,17 @@ class _RemoteModuleCall(torch.autograd.Function):
 
 
     @staticmethod
     @staticmethod
     def forward(ctx, dummy: torch.Tensor, uid: str, stub: runtime_grpc.ConnectionHandlerStub,
     def forward(ctx, dummy: torch.Tensor, uid: str, stub: runtime_grpc.ConnectionHandlerStub,
-                *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+                info: Dict[str, Any], *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
         # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
         inputs = tuple(map(torch.Tensor.detach, inputs))  # detach to avoid pickling the computation graph
         inputs = tuple(map(torch.Tensor.detach, inputs))  # detach to avoid pickling the computation graph
-        ctx.uid, ctx.stub = uid, stub
+        ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
         ctx.save_for_backward(*inputs)
 
 
+        serialized_tensors = [serialize_torch_tensor(inp, proto.compression)
+                              for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))]
+
         outputs = stub.forward(
         outputs = stub.forward(
-            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
+            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
 
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
 
 
@@ -97,10 +101,12 @@ class _RemoteModuleCall(torch.autograd.Function):
     @staticmethod
     @staticmethod
     @once_differentiable
     @once_differentiable
     def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
     def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
-        payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
+        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
+        backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
+        serialized_tensors = [serialize_torch_tensor(tensor, proto.compression)
+                              for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
 
 
-        grad_inputs = ctx.stub.backward(
-            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
+        grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
 
 
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
-        return (DUMMY, None, None, *deserialized_grad_inputs)
+        return (DUMMY, None, None, None, *deserialized_grad_inputs)

+ 28 - 23
hivemind/client/moe.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 
 import asyncio
 import asyncio
 import time
 import time
-from typing import Tuple, List, Optional, Awaitable, Set, Dict
+from typing import Tuple, List, Optional, Awaitable, Set, Dict, Any
 
 
 import grpc.experimental.aio
 import grpc.experimental.aio
 import torch
 import torch
@@ -58,7 +58,7 @@ class RemoteMixtureOfExperts(nn.Module):
         self.allow_broadcasting = allow_broadcasting
         self.allow_broadcasting = allow_broadcasting
 
 
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
-        self._outputs_schema = None  # expert['info'][outputs_schema] from one of experts in the grid
+        self._expert_info = None  # expert['info'] from one of experts in the grid
 
 
     def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
     def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
         """
@@ -88,8 +88,8 @@ class RemoteMixtureOfExperts(nn.Module):
         # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
         # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
 
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
-            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min,
-            self.forward_timeout, self.backward_timeout, self.loop, *nested_flatten(((input, *args), kwargs)))
+            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
+            self.backward_timeout, self.loop, self.info, *nested_flatten(((input, *args), kwargs)))
         # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
         # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
 
 
         expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
         expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
@@ -99,7 +99,7 @@ class RemoteMixtureOfExperts(nn.Module):
         averaged_outputs_flat = [
         averaged_outputs_flat = [
             (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
             (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
             for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
             for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
-        return nested_pack(averaged_outputs_flat, self.outputs_schema)
+        return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
 
 
     async def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[RemoteExpert]:
     async def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[RemoteExpert]:
         """
         """
@@ -139,9 +139,9 @@ class RemoteMixtureOfExperts(nn.Module):
             beam_scores = expanded_scores[tuple(zip(*map(candidate_to_indices.get, beam)))]
             beam_scores = expanded_scores[tuple(zip(*map(candidate_to_indices.get, beam)))]
             beam_experts = list(best_alive_prefixes.values())
             beam_experts = list(best_alive_prefixes.values())
 
 
-        if self._outputs_schema is None:
+        if self._expert_info is None:
             try:
             try:
-                self._outputs_schema = beam_experts[0].info['outputs_schema']
+                self._expert_info = beam_experts[0].info
             except grpc.RpcError as e:
             except grpc.RpcError as e:
                 logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
                 logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
 
 
@@ -182,15 +182,15 @@ class RemoteMixtureOfExperts(nn.Module):
         return scores
         return scores
 
 
     @property
     @property
-    def outputs_schema(self):
-        if self._outputs_schema is None:
+    def info(self):
+        if self._expert_info is None:
             # grab some expert to set ensemble output shape
             # grab some expert to set ensemble output shape
             proj_device = self.proj.weight.device
             proj_device = self.proj.weight.device
             dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
             dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
             dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.grid_size, dim=-1)
             dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.grid_size, dim=-1)
             dummy_experts = self.loop.run_until_complete(self.beam_search(dummy_scores, k_best=1))
             dummy_experts = self.loop.run_until_complete(self.beam_search(dummy_scores, k_best=1))
-            self._outputs_schema = dummy_experts[0].info['outputs_schema']
-        return self._outputs_schema
+            self._expert_info = dummy_experts[0].info
+        return self._expert_info
 
 
 
 
 class _RemoteCallMany(torch.autograd.Function):
 class _RemoteCallMany(torch.autograd.Function):
@@ -206,7 +206,7 @@ class _RemoteCallMany(torch.autograd.Function):
     @classmethod
     @classmethod
     def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
     def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
                 timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
                 timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
-                loop: asyncio.base_events.BaseEventLoop, *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+                loop: asyncio.base_events.BaseEventLoop, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
         assert not torch.is_grad_enabled()
         assert not torch.is_grad_enabled()
         num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
         num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
         flat_inputs_per_sample: List[Tuple[torch.Tensor, ...]] = list(zip(*(x.split(1, dim=0) for x in flat_inputs)))
         flat_inputs_per_sample: List[Tuple[torch.Tensor, ...]] = list(zip(*(x.split(1, dim=0) for x in flat_inputs)))
@@ -215,7 +215,7 @@ class _RemoteCallMany(torch.autograd.Function):
         async def _forward():
         async def _forward():
             # dispatch tasks to all remote experts, await responses
             # dispatch tasks to all remote experts, await responses
             pending_tasks = {
             pending_tasks = {
-                asyncio.create_task(cls._forward_one_expert((i, j), expert, flat_inputs_per_sample[i]))
+                asyncio.create_task(cls._forward_one_expert((i, j), expert, info, flat_inputs_per_sample[i]))
                 for i in range(num_samples) for j, expert in enumerate(experts_per_sample[i])
                 for i in range(num_samples) for j, expert in enumerate(experts_per_sample[i])
             }
             }
             alive_grid_indices, alive_flat_outputs = await cls._wait_for_responses(
             alive_grid_indices, alive_flat_outputs = await cls._wait_for_responses(
@@ -239,7 +239,8 @@ class _RemoteCallMany(torch.autograd.Function):
 
 
             # save individual outputs for backward pass
             # save individual outputs for backward pass
             ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs)
             ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs)
-            ctx._saved_non_tensors = loop, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample
+            ctx._saved_non_tensors = loop, info, backward_k_min, backward_timeout,\
+                                     timeout_after_k_min, experts_per_sample
             return (mask,) + tuple(outputs)
             return (mask,) + tuple(outputs)
 
 
         return loop.run_until_complete(_forward())
         return loop.run_until_complete(_forward())
@@ -248,7 +249,7 @@ class _RemoteCallMany(torch.autograd.Function):
     @once_differentiable
     @once_differentiable
     def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
     def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
         assert not torch.is_grad_enabled()
         assert not torch.is_grad_enabled()
-        loop, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
+        loop, info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
         alive_ii, alive_jj, *flat_inputs = ctx.saved_tensors
         alive_ii, alive_jj, *flat_inputs = ctx.saved_tensors
         dummy_grad_mask, *flat_grad_outputs = raw_grads
         dummy_grad_mask, *flat_grad_outputs = raw_grads
         num_samples, max_experts = dummy_grad_mask.shape
         num_samples, max_experts = dummy_grad_mask.shape
@@ -261,8 +262,8 @@ class _RemoteCallMany(torch.autograd.Function):
             pending_tasks = set()
             pending_tasks = set()
             for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
             for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
                                                         inputs_per_expert, grad_outputs_per_expert):
                                                         inputs_per_expert, grad_outputs_per_expert):
-                pending_tasks.add(asyncio.create_task(
-                    cls._backward_one_expert((i, j), expert_per_sample[i.item()][j.item()], inputs_ij, grad_outputs_ij)))
+                pending_tasks.add(asyncio.create_task(cls._backward_one_expert(
+                    (i, j), expert_per_sample[i.item()][j.item()], info, inputs_ij, grad_outputs_ij)))
 
 
             backward_survivor_indices, survivor_grad_inputs = await cls._wait_for_responses(
             backward_survivor_indices, survivor_grad_inputs = await cls._wait_for_responses(
                 pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
                 pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
@@ -281,28 +282,32 @@ class _RemoteCallMany(torch.autograd.Function):
 
 
                 grad_inputs.append(grad_input_per_expert.sum(dim=1))  # add up gradients from each expert
                 grad_inputs.append(grad_input_per_expert.sum(dim=1))  # add up gradients from each expert
 
 
-            return (DUMMY, None, None, None, None, None, None, None, *grad_inputs)
+            return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
 
 
         return loop.run_until_complete(_backward())
         return loop.run_until_complete(_backward())
 
 
     @staticmethod
     @staticmethod
-    async def _forward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, inputs: Tuple[torch.Tensor]):
+    async def _forward_one_expert(
+            grid_indices: Tuple[int, ...], expert: RemoteExpert, info: Dict[str, Any], inputs: Tuple[torch.Tensor]):
         stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
         stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
         try:
         try:
             outputs = await stub.forward(runtime_pb2.ExpertRequest(
             outputs = await stub.forward(runtime_pb2.ExpertRequest(
-                uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
+                uid=expert.uid, tensors=[serialize_torch_tensor(tensor, proto.compression) for tensor, proto in 
+                                         zip(inputs, nested_flatten(info['forward_schema']))]))
             return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in outputs.tensors)
             return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in outputs.tensors)
         except grpc.experimental.aio.AioRpcError as error:
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})")
             logger.warning(f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})")
 
 
     @staticmethod
     @staticmethod
-    async def _backward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert,
+    async def _backward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, info: Dict[str, Any],
                                    inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor]):
                                    inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor]):
         stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
         stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
-        payload = tuple(nested_flatten((inputs, grad_outputs)))
+        inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs)))
+        backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))
         try:
         try:
             grad_inputs = await stub.backward(runtime_pb2.ExpertRequest(
             grad_inputs = await stub.backward(runtime_pb2.ExpertRequest(
-                uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
+                uid=expert.uid, tensors=[serialize_torch_tensor(tensor, proto.compression)
+                                         for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]))
             return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors)
             return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors)
         except grpc.experimental.aio.AioRpcError as error:
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"RemoteExpert {expert} failed backward: {error.code()} ({inputs}, {grad_outputs})")
             logger.warning(f"RemoteExpert {expert} failed backward: {error.code()} ({inputs}, {grad_outputs})")

+ 6 - 0
hivemind/proto/runtime.proto

@@ -26,10 +26,16 @@ message ExpertResponse {
   repeated Tensor tensors = 2;
   repeated Tensor tensors = 2;
 }
 }
 
 
+enum CompressionType{
+  NONE = 0;
+  MEANSTD_LAST_AXIS_FLOAT16 = 1;
+}
+
 message Tensor {
 message Tensor {
   bytes buffer = 1;
   bytes buffer = 1;
   repeated uint32 size = 2;
   repeated uint32 size = 2;
   bool requires_grad = 3;
   bool requires_grad = 3;
   string dtype = 4;
   string dtype = 4;
+  CompressionType compression = 5;
 }
 }
 
 

+ 6 - 3
hivemind/server/connection_handler.py

@@ -10,7 +10,7 @@ import uvloop
 
 
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.server.expert_backend import ExpertBackend
 from hivemind.server.expert_backend import ExpertBackend
-from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint
+from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint, nested_flatten
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
@@ -60,11 +60,14 @@ class ConnectionHandler(mp.Process):
     async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
     async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
-        serialized_response = [serialize_torch_tensor(tensor) for tensor in await future]
+        serialized_response = [serialize_torch_tensor(tensor, proto.compression, allow_inplace=True) for tensor, proto
+                               in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))]
+
         return runtime_pb2.ExpertResponse(tensors=serialized_response)
         return runtime_pb2.ExpertResponse(tensors=serialized_response)
 
 
     async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
     async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
-        serialized_response = [serialize_torch_tensor(tensor) for tensor in await future]
+        serialized_response = [serialize_torch_tensor(tensor, proto.compression, allow_inplace=True) for tensor, proto
+                               in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))]
         return runtime_pb2.ExpertResponse(tensors=serialized_response)
         return runtime_pb2.ExpertResponse(tensors=serialized_response)

+ 5 - 3
hivemind/server/expert_backend.py

@@ -53,9 +53,11 @@ class ExpertBackend(nn.Module):
             dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
             dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
             outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)
             outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)
 
 
-        self.outputs_schema = outputs_schema
-        self.forward_schema = (self.args_schema, self.kwargs_schema)
-        self.backward_schema = (self.forward_schema, self.outputs_schema)  # original inputs and grad w.r.t. outputs
+        self.forward_schema = (self.args_schema, self.kwargs_schema)  # inputs for forward
+        self.outputs_schema = outputs_schema  # outputs from forward
+
+        self.backward_schema = (self.forward_schema, self.outputs_schema)  # inputs to backward
+        self.grad_inputs_schema = self.forward_schema  # outputs from backward
         self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
         self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
         self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
         self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
 
 

+ 49 - 10
hivemind/utils/grpc.py

@@ -6,19 +6,58 @@ import numpy as np
 import torch
 import torch
 
 
 from hivemind.proto import runtime_pb2
 from hivemind.proto import runtime_pb2
+from hivemind.proto.runtime_pb2 import CompressionType
 
 
+FP16_MAX = 65_504
+
+
+def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE, 
+                           allow_inplace=False) -> runtime_pb2.Tensor:
+    if compression_type == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
+        assert tensor.dtype == torch.float32
+
+        tensor = tensor if allow_inplace else tensor.clone()
+        means = torch.mean(tensor, dim=-1, keepdim=True)
+        tensor.sub_(means)
+
+        stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
+        tensor.div_(stds)
+        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
+
+        data = b''.join((tensor.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))
+
+        proto = runtime_pb2.Tensor(
+            compression=compression_type,
+            buffer=data,
+            size=tensor.shape,
+            dtype='compressed_float32',
+            requires_grad=tensor.requires_grad)
+    else:
+        array = tensor.numpy()
+        proto = runtime_pb2.Tensor(
+            compression=compression_type,
+            buffer=array.tobytes(),
+            size=array.shape,
+            dtype=array.dtype.name,
+            requires_grad=tensor.requires_grad)
 
 
-def serialize_torch_tensor(tensor: torch.Tensor) -> runtime_pb2.Tensor:
-    array = tensor.numpy()
-    proto = runtime_pb2.Tensor(
-        buffer=array.tobytes(),
-        size=array.shape,
-        dtype=array.dtype.name,
-        requires_grad=tensor.requires_grad)
     return proto
     return proto
 
 
 
 
-def deserialize_torch_tensor(tensor: runtime_pb2.Tensor) -> torch.Tensor:
+def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
     # TODO avoid copying the array (need to silence pytorch warning, because array is not writable)
     # TODO avoid copying the array (need to silence pytorch warning, because array is not writable)
-    array = np.frombuffer(tensor.buffer, dtype=np.dtype(tensor.dtype)).copy()
-    return torch.as_tensor(array).view(tuple(tensor.size)).requires_grad_(tensor.requires_grad)
+    if serialized_tensor.compression == CompressionType.NONE:
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype)).copy()
+        tensor = torch.as_tensor(array).view(*serialized_tensor.size).requires_grad_(serialized_tensor.requires_grad)
+    elif serialized_tensor.compression == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
+        stats_size = list(serialized_tensor.size)
+        stats_size[-1] = 1
+        stats_count = np.prod(stats_size)
+        means, stds = serialized_tensor.buffer[-8*stats_count:-4*stats_count], serialized_tensor.buffer[-4*stats_count:]
+        means = torch.as_tensor(np.frombuffer(means, dtype=np.float32)).view(*stats_size)
+        stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32)).view(*stats_size)
+        array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16)
+        tensor = torch.as_tensor(array).to(torch.float32).view(*serialized_tensor.size).mul_(stds).add_(means)
+    else:
+        raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
+    return tensor

+ 3 - 0
hivemind/utils/tensor_descr.py

@@ -2,6 +2,8 @@ from dataclasses import dataclass, asdict
 
 
 import torch
 import torch
 
 
+from hivemind.proto.runtime_pb2 import CompressionType
+
 DUMMY_BATCH_SIZE = 3  # used for dummy runs only
 DUMMY_BATCH_SIZE = 3  # used for dummy runs only
 
 
 
 
@@ -18,6 +20,7 @@ class TensorDescriptor(DescriptorBase):
     device: torch.device = None
     device: torch.device = None
     requires_grad: bool = False
     requires_grad: bool = False
     pin_memory: bool = False
     pin_memory: bool = False
+    compression: CompressionType = CompressionType.NONE
 
 
     @property
     @property
     def shape(self):
     def shape(self):

+ 1 - 1
tests/test_moe.py

@@ -45,7 +45,7 @@ def test_call_many():
         mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
         mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
             DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
             DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
             k_min, backward_k_min, timeout_after_k_min, forward_timeout, backward_timeout,
             k_min, backward_k_min, timeout_after_k_min, forward_timeout, backward_timeout,
-            asyncio.new_event_loop(), inputs
+            asyncio.new_event_loop(), e1.info, inputs
         )
         )
         assert mask.shape == (4, 3)
         assert mask.shape == (4, 3)
         assert expert_outputs.shape == (4, 3, 64)
         assert expert_outputs.shape == (4, 3, 64)

+ 16 - 0
tests/test_util_modules.py

@@ -1,4 +1,5 @@
 import asyncio
 import asyncio
+import torch
 
 
 import pytest
 import pytest
 import hivemind
 import hivemind
@@ -81,6 +82,7 @@ def test_await_mpfuture():
     async def _run():
     async def _run():
         # await result
         # await result
         f1, f2 = hivemind.MPFuture.make_pair()
         f1, f2 = hivemind.MPFuture.make_pair()
+
         async def wait_and_assign():
         async def wait_and_assign():
             assert f2.set_running_or_notify_cancel() is True
             assert f2.set_running_or_notify_cancel() is True
             await asyncio.sleep(0.1)
             await asyncio.sleep(0.1)
@@ -93,6 +95,7 @@ def test_await_mpfuture():
 
 
         # await cancel
         # await cancel
         f1, f2 = hivemind.MPFuture.make_pair()
         f1, f2 = hivemind.MPFuture.make_pair()
+
         async def wait_and_cancel():
         async def wait_and_cancel():
             await asyncio.sleep(0.1)
             await asyncio.sleep(0.1)
             f1.cancel()
             f1.cancel()
@@ -104,6 +107,7 @@ def test_await_mpfuture():
 
 
         # await exception
         # await exception
         f1, f2 = hivemind.MPFuture.make_pair()
         f1, f2 = hivemind.MPFuture.make_pair()
+
         async def wait_and_raise():
         async def wait_and_raise():
             await asyncio.sleep(0.1)
             await asyncio.sleep(0.1)
             f1.set_exception(SystemError())
             f1.set_exception(SystemError())
@@ -114,3 +118,15 @@ def test_await_mpfuture():
                 await future
                 await future
 
 
     asyncio.new_event_loop().run_until_complete(_run())
     asyncio.new_event_loop().run_until_complete(_run())
+
+
+def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
+    torch.manual_seed(0)
+    from hivemind.proto.runtime_pb2 import CompressionType
+    from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor
+    X = torch.randn(*size)
+    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_LAST_AXIS_FLOAT16))-X
+    assert error.square().mean() < alpha
+    return error.square().mean()
+