
Compile protobuf in setup.py (#85)

* Compile protobuf in setup.py

* Update circleci pipelines

* Update RTD pipeline

* Refactor custom build_ext into install and develop
Max Ryabinin 5 years ago
parent
commit
e7840e337b

+ 2 - 2
.circleci/config.yml

@@ -9,11 +9,11 @@ jobs:
     steps:
       - checkout
       - python/load-cache
-      - run: sudo pip install codecov pytest tqdm scikit-learn
+      - run: pip install codecov pytest tqdm scikit-learn
       - python/install-deps
       - python/save-cache
       - run:
-          command: sudo python setup.py develop
+          command: pip install -e .
           name: setup
       - run:
           command: pytest ./tests
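
Since `pip install -e .` now also generates the protobuf stubs (see the setup.py changes below), the CI environment can be sanity-checked with something like the following — a hedged sketch, not part of this commit, assuming the editable install completed successfully:

    # after `pip install -e .`, the generated modules live under hivemind/proto
    from hivemind.proto import dht_pb2, dht_pb2_grpc, runtime_pb2, runtime_pb2_grpc
    print(runtime_pb2.Tensor, dht_pb2.DESCRIPTOR.name)  # message class and 'dht.proto'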

+ 7 - 2
.readthedocs.yml

@@ -3,5 +3,10 @@ version: 2
 sphinx:
   fail_on_warning: true
 
-conda:
-  environment: docs/environment.yaml
+python:
+  version: 3.7
+  install:
+    - requirements: requirements.txt
+    - requirements: docs/requirements.txt
+    - method: pip
+      path: .

+ 0 - 1
docs/conf.py

@@ -22,7 +22,6 @@ from recommonmark.parser import CommonMarkParser
 
 
 # -- Project information -----------------------------------------------------
-sys.path.insert(0, '..')
 src_path = '../hivemind'
 project = 'hivemind'
 copyright = '2020, Learning@home & contributors'

+ 0 - 19
docs/environment.yaml

@@ -1,19 +0,0 @@
-channels:
-  - defaults
-  - anaconda
-  - pytorch
-  - conda-forge
-dependencies:
-  - grpcio
-  - grpcio-tools
-  - numpy>=1.14
-  - pytorch>=1.3.0
-  - joblib>=0.13
-  - pip
-  - pip:
-    - recommonmark
-    - sphinx_rtd_theme
-    - prefetch_generator>=1.0.1
-    - uvloop>=0.14.0
-    - umsgpack
-

+ 2 - 0
docs/requirements.txt

@@ -0,0 +1,2 @@
+recommonmark
+sphinx_rtd_theme

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.7.1'
+__version__ = '0.8.0'

+ 2 - 1
hivemind/client/expert.py

@@ -8,8 +8,9 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
+from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
-from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor, runtime_pb2, runtime_grpc
+from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 

+ 5 - 4
hivemind/client/moe.py

@@ -1,17 +1,18 @@
 from __future__ import annotations
-import time
+
 import asyncio
+import time
 from typing import Tuple, List, Optional, Awaitable, Set, Dict
 
+import grpc.experimental.aio
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
-import grpc.experimental.aio
 
 import hivemind
 from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
-from hivemind.utils import nested_map, nested_pack, nested_flatten, runtime_grpc, runtime_pb2, \
-    serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.utils import nested_pack, nested_flatten, serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)

+ 2 - 5
hivemind/dht/protocol.py

@@ -3,7 +3,6 @@ from __future__ import annotations
 
 import asyncio
 import heapq
-import os
 from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union, Collection
 from warnings import warn
 
@@ -11,13 +10,11 @@ import grpc
 import grpc.experimental.aio
 
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
-from hivemind.utils import Endpoint, compile_grpc, get_logger, replace_port, get_port
+from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
+from hivemind.utils import Endpoint, get_logger, replace_port
 
 logger = get_logger(__name__)
 
-with open(os.path.join(os.path.dirname(__file__), 'dht.proto'), 'r') as f_proto:
-    dht_pb2, dht_grpc = compile_grpc(f_proto.read())
-
 
 class DHTProtocol(dht_grpc.DHTServicer):
     # fmt:off

+ 0 - 0
hivemind/dht/dht.proto → hivemind/proto/dht.proto


+ 0 - 0
hivemind/server/connection_handler.proto → hivemind/proto/runtime.proto


+ 2 - 1
hivemind/server/connection_handler.py

@@ -8,8 +8,9 @@ import grpc.experimental.aio
 import torch
 import uvloop
 
+from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.server.expert_backend import ExpertBackend
-from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint, runtime_pb2, runtime_grpc
+from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint
 
 logger = get_logger(__name__)
 

+ 1 - 45
hivemind/utils/grpc.py

@@ -1,55 +1,11 @@
 """
 Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
 """
-import functools
-import os
-import sys
-import tempfile
-from argparse import Namespace
-from typing import Tuple
 
-import grpc_tools.protoc
 import numpy as np
 import torch
 
-
-@functools.lru_cache(maxsize=None)
-def compile_grpc(proto: str, *args: str) -> Tuple[Namespace, Namespace]:
-    """
-    Compiles and loads grpc protocol defined by protobuf string
-
-    :param proto: protocol buffer code as a string, as in open('file.proto').read()
-    :param args: extra cli args for grpc_tools.protoc compiler, e.g. '-Imyincludepath'
-    :returns: messages, services protobuf
-    """
-    base_include = grpc_tools.protoc.pkg_resources.resource_filename('grpc_tools', '_proto')
-
-    with tempfile.TemporaryDirectory(prefix='compile_grpc_') as build_dir:
-        proto_path = tempfile.mktemp(prefix='grpc_', suffix='.proto', dir=build_dir)
-        with open(proto_path, 'w') as fproto:
-            fproto.write(proto)
-
-        cli_args = (
-            grpc_tools.protoc.__file__, f"-I{base_include}",
-            f"--python_out={build_dir}", f"--grpc_python_out={build_dir}",
-            f"-I{build_dir}", *args, os.path.basename(proto_path))
-        code = grpc_tools.protoc._protoc_compiler.run_main([arg.encode() for arg in cli_args])
-        if code:  # hint: if you get this error in jupyter, run in console for richer error message
-            raise ValueError(f"{' '.join(cli_args)} finished with exit code {code}")
-
-        try:
-            sys.path.append(build_dir)
-            pb2_fname = os.path.basename(proto_path)[:-len('.proto')] + '_pb2'
-            messages, services = __import__(pb2_fname, fromlist=['*']), __import__(pb2_fname + '_grpc')
-            return messages, services
-        finally:
-            if sys.path.pop() != build_dir:
-                raise ImportError("Something changed sys.path while compile_grpc was in progress.")
-
-
-with open(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
-                       'server', 'connection_handler.proto')) as f_proto:
-    runtime_pb2, runtime_grpc = compile_grpc(f_proto.read())
+from hivemind.proto import runtime_pb2
 
 
 def serialize_torch_tensor(tensor: torch.Tensor) -> runtime_pb2.Tensor:
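
With compile_grpc removed, callers import the pregenerated modules from hivemind.proto instead of compiling .proto files at import time. A minimal usage sketch (assuming the package was installed via setup.py so that runtime_pb2 exists under hivemind/proto):

    import torch
    from hivemind.proto import runtime_pb2
    from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor

    # the serialization helpers now return/accept messages from the pregenerated runtime_pb2
    proto_tensor = serialize_torch_tensor(torch.randn(3, 3))
    assert isinstance(proto_tensor, runtime_pb2.Tensor)
    restored = deserialize_torch_tensor(proto_tensor)  # back to a torch.Tensor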

+ 42 - 4
setup.py

@@ -1,8 +1,43 @@
-from pkg_resources import parse_requirements
-from setuptools import setup
 import codecs
-import re
+import glob
 import os
+import re
+
+import grpc_tools.protoc
+from pkg_resources import parse_requirements
+from setuptools import setup, find_packages
+from setuptools.command.develop import develop
+from setuptools.command.install import install
+
+
+def proto_compile(output_path):
+    cli_args = ['grpc_tools.protoc',
+                '--proto_path=hivemind/proto', f'--python_out={output_path}',
+                f'--grpc_python_out={output_path}'] + glob.glob('hivemind/proto/*.proto')
+
+    code = grpc_tools.protoc.main(cli_args)
+    if code:  # hint: if you get this error in jupyter, run in console for richer error message
+        raise ValueError(f"{' '.join(cli_args)} finished with exit code {code}")
+    # Make pb2 imports in generated scripts relative
+    for script in glob.iglob(f'{output_path}/*.py'):
+        with open(script, 'r+') as file:
+            code = file.read()
+            file.seek(0)
+            file.write(re.sub(r'\n(import .+_pb2.*)', 'from . \\1', code))
+            file.truncate()
+
+
+class ProtoCompileInstall(install):
+    def run(self):
+        proto_compile(os.path.join(self.build_lib, 'hivemind', 'proto'))
+        super().run()
+
+
+class ProtoCompileDevelop(develop):
+    def run(self):
+        proto_compile(os.path.join('hivemind', 'proto'))
+        super().run()
+
 
 here = os.path.abspath(os.path.dirname(__file__))
 
@@ -17,12 +52,15 @@ with codecs.open(os.path.join(here, 'hivemind/__init__.py'), encoding='utf-8') a
 setup(
     name='hivemind',
     version=version_string,
+    cmdclass={'install': ProtoCompileInstall, 'develop': ProtoCompileDevelop},
     description='',
     long_description='',
     author='Learning@home authors',
     author_email='mryabinin@hse.ru',
     url="https://github.com/learning-at-home/hivemind",
-    packages=['hivemind'],
+    packages=find_packages(exclude=['tests']),
+    package_data={'hivemind': ['proto/*']},
+    include_package_data=True,
     license='MIT',
     install_requires=install_requires,
     classifiers=[
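
As a reference for how the new commands fit together: `pip install .` runs ProtoCompileInstall and writes the generated modules into the build directory, while `pip install -e .` (as used in the CircleCI config above) runs ProtoCompileDevelop and generates them in place under hivemind/proto; package_data keeps the .proto sources inside the installed package. The re.sub pass makes the generated cross-module imports package-relative — a small runnable sketch of that rewrite (the generated line below is illustrative of grpc_tools.protoc output and may vary between versions):

    import re

    # protoc-generated *_pb2_grpc modules import their *_pb2 counterparts with
    # absolute imports; proto_compile rewrites them to be package-relative
    generated = "import grpc\n\nimport runtime_pb2 as runtime__pb2\n"
    rewritten = re.sub(r'\n(import .+_pb2.*)', 'from . \\1', generated)
    print(rewritten)
    # import grpc
    # from . import runtime_pb2 as runtime__pb2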