Browse code

Add compression parameter to server (#112)

* Added parameter compression to Server.create

* Added parameter compression to run_server

* Added better compression parameter parsing
Vsevolod-pl 4 years ago
parent
commit
f55005cd86
3 changed files with 22 additions and 7 deletions
  1. 6 4
      hivemind/server/__init__.py
  2. 2 2
      hivemind/utils/tensor_descr.py
  3. 14 1
      scripts/run_server.py

+ 6 - 4
hivemind/server/__init__.py

@@ -20,6 +20,7 @@ from hivemind.server.layers import name_to_block, name_to_input
 from hivemind.server.runtime import Runtime
 from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
 from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
+from hivemind.proto.runtime_pb2 import CompressionType
 
 logger = get_logger(__name__)
 
@@ -69,7 +70,7 @@ class Server(threading.Thread):
     def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
                expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
                device=None, no_dht=False, initial_peers=(), dht_port=None, verbose=True,
-               *, start: bool, **kwargs) -> Server:
+               compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -118,9 +119,9 @@ class Server(threading.Thread):
 
         sample_input = name_to_input[expert_cls](4, hidden_dim)
         if isinstance(sample_input, tuple):
-            args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg) for arg in sample_input)
+            args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
         else:
-            args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input),)
+            args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)
 
         # initialize experts
 
@@ -129,7 +130,8 @@ class Server(threading.Thread):
             expert = name_to_block[expert_cls](hidden_dim)
             experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert,
                                                          args_schema=args_schema,
-                                                         outputs_schema=hivemind.BatchTensorDescriptor(hidden_dim),
+                                                         outputs_schema=hivemind.BatchTensorDescriptor(
+                                                             hidden_dim, compression=compression),
                                                          opt=optim_cls(expert.parameters()),
                                                          max_batch_size=max_batch_size,
                                                          )

+ 2 - 2
hivemind/utils/tensor_descr.py

@@ -46,10 +46,10 @@ class BatchTensorDescriptor(TensorDescriptor):
         super().__init__((None, *instance_size), **kwargs)
 
     @classmethod
-    def from_tensor(cls, tensor: torch.Tensor):
+    def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE):
         return cls(*tensor.shape[1:], dtype=tensor.dtype, layout=tensor.layout,
                    device=tensor.device, requires_grad=tensor.requires_grad,
-                   pin_memory=torch.cuda.is_available() and tensor.is_pinned())
+                   pin_memory=torch.cuda.is_available() and tensor.is_pinned(), compression=compression)
 
     def make_empty(self, batch_size, **kwargs):
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"

+ 14 - 1
scripts/run_server.py

@@ -7,6 +7,8 @@ import torch
 
 from hivemind.server import Server
 from hivemind.utils.threading import increase_file_limit
+from hivemind.proto.runtime_pb2 import CompressionType
+
 
 
 def main():
@@ -39,6 +41,8 @@ def main():
     parser.add_argument('--increase_file_limit', action='store_true',
                         help='On *nix, this will increase the max number of processes '
                              'a server can spawn before hitting "Too many open files"; Use at your own risk.')
+    parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression '
+                        'parameter for grpc. Can be NONE, MEANSTD or FLOAT16')
     # fmt:on
     args = vars(parser.parse_args())
     args.pop('config', None)
@@ -55,8 +59,17 @@ def main():
     if args.pop('increase_file_limit'):
         increase_file_limit()
 
+    compression_name = args.pop("compression")
+    compression = CompressionType.NONE
+    if compression_name == "MEANSTD":
+        compression = CompressionType.MEANSTD_LAST_AXIS_FLOAT16
+    elif compression_name == "FLOAT16":
+        compression = CompressionType.FLOAT16
+    else:
+        compression = getattr(CompressionType, compression_name)
+
     try:
-        server = Server.create(**args, optim_cls=optim_cls, start=True, verbose=True)
+        server = Server.create(**args, optim_cls=optim_cls, start=True, verbose=True, compression=compression)
         server.join()
     finally:
         server.shutdown()