소스 검색

Compile protobuf in setup.py (#85)

* Compile protobuf in setup.py

* Update circleci pipelines

* Update RTD pipeline

* Refactor custom build_ext into install and develop
Max Ryabinin 5 년 전
부모
커밋
e7840e337b

+ 2 - 2
.circleci/config.yml

@@ -9,11 +9,11 @@ jobs:
     steps:
     steps:
       - checkout
       - checkout
       - python/load-cache
       - python/load-cache
-      - run: sudo pip install codecov pytest tqdm scikit-learn
+      - run: pip install codecov pytest tqdm scikit-learn
       - python/install-deps
       - python/install-deps
       - python/save-cache
       - python/save-cache
       - run:
       - run:
-          command: sudo python setup.py develop
+          command: pip install -e .
           name: setup
           name: setup
       - run:
       - run:
           command: pytest ./tests
           command: pytest ./tests

+ 7 - 2
.readthedocs.yml

@@ -3,5 +3,10 @@ version: 2
 sphinx:
 sphinx:
   fail_on_warning: true
   fail_on_warning: true
 
 
-conda:
-  environment: docs/environment.yaml
+python:
+  version: 3.7
+  install:
+    - requirements: requirements.txt
+    - requirements: docs/requirements.txt
+    - method: pip
+      path: .

+ 0 - 1
docs/conf.py

@@ -22,7 +22,6 @@ from recommonmark.parser import CommonMarkParser
 
 
 
 
 # -- Project information -----------------------------------------------------
 # -- Project information -----------------------------------------------------
-sys.path.insert(0, '..')
 src_path = '../hivemind'
 src_path = '../hivemind'
 project = 'hivemind'
 project = 'hivemind'
 copyright = '2020, Learning@home & contributors'
 copyright = '2020, Learning@home & contributors'

+ 0 - 19
docs/environment.yaml

@@ -1,19 +0,0 @@
-channels:
-  - defaults
-  - anaconda
-  - pytorch
-  - conda-forge
-dependencies:
-  - grpcio
-  - grpcio-tools
-  - numpy>=1.14
-  - pytorch>=1.3.0
-  - joblib>=0.13
-  - pip
-  - pip:
-    - recommonmark
-    - sphinx_rtd_theme
-    - prefetch_generator>=1.0.1
-    - uvloop>=0.14.0
-    - umsgpack
-

+ 2 - 0
docs/requirements.txt

@@ -0,0 +1,2 @@
+recommonmark
+sphinx_rtd_theme

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.server import *
 from hivemind.utils import *
 from hivemind.utils import *
 
 
-__version__ = '0.7.1'
+__version__ = '0.8.0'

+ 2 - 1
hivemind/client/expert.py

@@ -8,8 +8,9 @@ import torch
 import torch.nn as nn
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 from torch.autograd.function import once_differentiable
 
 
+from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
 from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
-from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor, runtime_pb2, runtime_grpc
+from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor
 
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
 

+ 5 - 4
hivemind/client/moe.py

@@ -1,17 +1,18 @@
 from __future__ import annotations
 from __future__ import annotations
-import time
+
 import asyncio
 import asyncio
+import time
 from typing import Tuple, List, Optional, Awaitable, Set, Dict
 from typing import Tuple, List, Optional, Awaitable, Set, Dict
 
 
+import grpc.experimental.aio
 import torch
 import torch
 import torch.nn as nn
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 from torch.autograd.function import once_differentiable
-import grpc.experimental.aio
 
 
 import hivemind
 import hivemind
 from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
 from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
-from hivemind.utils import nested_map, nested_pack, nested_flatten, runtime_grpc, runtime_pb2, \
-    serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.utils import nested_pack, nested_flatten, serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)

+ 2 - 5
hivemind/dht/protocol.py

@@ -3,7 +3,6 @@ from __future__ import annotations
 
 
 import asyncio
 import asyncio
 import heapq
 import heapq
-import os
 from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union, Collection
 from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union, Collection
 from warnings import warn
 from warnings import warn
 
 
@@ -11,13 +10,11 @@ import grpc
 import grpc.experimental.aio
 import grpc.experimental.aio
 
 
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
-from hivemind.utils import Endpoint, compile_grpc, get_logger, replace_port, get_port
+from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
+from hivemind.utils import Endpoint, get_logger, replace_port
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
-with open(os.path.join(os.path.dirname(__file__), 'dht.proto'), 'r') as f_proto:
-    dht_pb2, dht_grpc = compile_grpc(f_proto.read())
-
 
 
 class DHTProtocol(dht_grpc.DHTServicer):
 class DHTProtocol(dht_grpc.DHTServicer):
     # fmt:off
     # fmt:off

+ 0 - 0
hivemind/dht/dht.proto → hivemind/proto/dht.proto


+ 0 - 0
hivemind/server/connection_handler.proto → hivemind/proto/runtime.proto


+ 2 - 1
hivemind/server/connection_handler.py

@@ -8,8 +8,9 @@ import grpc.experimental.aio
 import torch
 import torch
 import uvloop
 import uvloop
 
 
+from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.server.expert_backend import ExpertBackend
 from hivemind.server.expert_backend import ExpertBackend
-from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint, runtime_pb2, runtime_grpc
+from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 

+ 1 - 45
hivemind/utils/grpc.py

@@ -1,55 +1,11 @@
 """
 """
 Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
 Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
 """
 """
-import functools
-import os
-import sys
-import tempfile
-from argparse import Namespace
-from typing import Tuple
 
 
-import grpc_tools.protoc
 import numpy as np
 import numpy as np
 import torch
 import torch
 
 
-
-@functools.lru_cache(maxsize=None)
-def compile_grpc(proto: str, *args: str) -> Tuple[Namespace, Namespace]:
-    """
-    Compiles and loads grpc protocol defined by protobuf string
-
-    :param proto: protocol buffer code as a string, as in open('file.proto').read()
-    :param args: extra cli args for grpc_tools.protoc compiler, e.g. '-Imyincludepath'
-    :returns: messages, services protobuf
-    """
-    base_include = grpc_tools.protoc.pkg_resources.resource_filename('grpc_tools', '_proto')
-
-    with tempfile.TemporaryDirectory(prefix='compile_grpc_') as build_dir:
-        proto_path = tempfile.mktemp(prefix='grpc_', suffix='.proto', dir=build_dir)
-        with open(proto_path, 'w') as fproto:
-            fproto.write(proto)
-
-        cli_args = (
-            grpc_tools.protoc.__file__, f"-I{base_include}",
-            f"--python_out={build_dir}", f"--grpc_python_out={build_dir}",
-            f"-I{build_dir}", *args, os.path.basename(proto_path))
-        code = grpc_tools.protoc._protoc_compiler.run_main([arg.encode() for arg in cli_args])
-        if code:  # hint: if you get this error in jupyter, run in console for richer error message
-            raise ValueError(f"{' '.join(cli_args)} finished with exit code {code}")
-
-        try:
-            sys.path.append(build_dir)
-            pb2_fname = os.path.basename(proto_path)[:-len('.proto')] + '_pb2'
-            messages, services = __import__(pb2_fname, fromlist=['*']), __import__(pb2_fname + '_grpc')
-            return messages, services
-        finally:
-            if sys.path.pop() != build_dir:
-                raise ImportError("Something changed sys.path while compile_grpc was in progress.")
-
-
-with open(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
-                       'server', 'connection_handler.proto')) as f_proto:
-    runtime_pb2, runtime_grpc = compile_grpc(f_proto.read())
+from hivemind.proto import runtime_pb2
 
 
 
 
 def serialize_torch_tensor(tensor: torch.Tensor) -> runtime_pb2.Tensor:
 def serialize_torch_tensor(tensor: torch.Tensor) -> runtime_pb2.Tensor:

+ 42 - 4
setup.py

@@ -1,8 +1,43 @@
-from pkg_resources import parse_requirements
-from setuptools import setup
 import codecs
 import codecs
-import re
+import glob
 import os
 import os
+import re
+
+import grpc_tools.protoc
+from pkg_resources import parse_requirements
+from setuptools import setup, find_packages
+from setuptools.command.develop import develop
+from setuptools.command.install import install
+
+
+def proto_compile(output_path):
+    cli_args = ['grpc_tools.protoc',
+                '--proto_path=hivemind/proto', f'--python_out={output_path}',
+                f'--grpc_python_out={output_path}'] + glob.glob('hivemind/proto/*.proto')
+
+    code = grpc_tools.protoc.main(cli_args)
+    if code:  # hint: if you get this error in jupyter, run in console for richer error message
+        raise ValueError(f"{' '.join(cli_args)} finished with exit code {code}")
+    # Make pb2 imports in generated scripts relative
+    for script in glob.iglob(f'{output_path}/*.py'):
+        with open(script, 'r+') as file:
+            code = file.read()
+            file.seek(0)
+            file.write(re.sub(r'\n(import .+_pb2.*)', 'from . \\1', code))
+            file.truncate()
+
+
+class ProtoCompileInstall(install):
+    def run(self):
+        proto_compile(os.path.join(self.build_lib, 'hivemind', 'proto'))
+        super().run()
+
+
+class ProtoCompileDevelop(develop):
+    def run(self):
+        proto_compile(os.path.join('hivemind', 'proto'))
+        super().run()
+
 
 
 here = os.path.abspath(os.path.dirname(__file__))
 here = os.path.abspath(os.path.dirname(__file__))
 
 
@@ -17,12 +52,15 @@ with codecs.open(os.path.join(here, 'hivemind/__init__.py'), encoding='utf-8') a
 setup(
 setup(
     name='hivemind',
     name='hivemind',
     version=version_string,
     version=version_string,
+    cmdclass={'install': ProtoCompileInstall, 'develop': ProtoCompileDevelop},
     description='',
     description='',
     long_description='',
     long_description='',
     author='Learning@home authors',
     author='Learning@home authors',
     author_email='mryabinin@hse.ru',
     author_email='mryabinin@hse.ru',
     url="https://github.com/learning-at-home/hivemind",
     url="https://github.com/learning-at-home/hivemind",
-    packages=['hivemind'],
+    packages=find_packages(exclude=['tests']),
+    package_data={'hivemind': ['proto/*']},
+    include_package_data=True,
     license='MIT',
     license='MIT',
     install_requires=install_requires,
     install_requires=install_requires,
     classifiers=[
     classifiers=[