Selaa lähdekoodia

Server-side gRPC keepalive (#131)

* Server-side gRPC keepalive
Max Ryabinin 4 vuotta sitten
vanhempi
commit
85a99730af

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.17'
+__version__ = '0.8.18'

+ 2 - 1
hivemind/client/averager.py

@@ -15,6 +15,7 @@ import grpc
 import hivemind
 from hivemind.dht import get_dht_time, DHTExpiration
 from hivemind.utils import get_logger, Endpoint, Port, MPFuture
+from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 from hivemind.client.allreduce import GroupAllReduce, GroupID
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc
 
@@ -90,7 +91,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         async def _run():
             if listen:
                 grpc.aio.init_grpc_aio()
-                server = grpc.aio.server(**server_kwargs)
+                server = grpc.aio.server(**server_kwargs, options=GRPC_KEEPALIVE_OPTIONS)
                 averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
                 found_port = server.add_insecure_port(listen_on)
                 assert found_port != 0, f"Failed to listen to {listen_on}"

+ 2 - 1
hivemind/dht/protocol.py

@@ -10,6 +10,7 @@ from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpirat
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
 from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer, ChannelCache, ValueWithExpiration
+from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 
 logger = get_logger(__name__)
 
@@ -50,7 +51,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
 
         if listen:  # set up server to process incoming rpc requests
             grpc.aio.init_grpc_aio()
-            self.server = grpc.aio.server(**kwargs)
+            self.server = grpc.aio.server(**kwargs, options=GRPC_KEEPALIVE_OPTIONS)
             dht_grpc.add_DHTServicer_to_server(self, self.server)
 
             found_port = self.server.add_insecure_port(listen_on)

+ 3 - 2
hivemind/server/connection_handler.py

@@ -11,6 +11,7 @@ import uvloop
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.server.expert_backend import ExpertBackend
 from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint, nested_flatten
+from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 
 logger = get_logger(__name__)
 
@@ -37,11 +38,11 @@ class ConnectionHandler(mp.Process):
         async def _run():
             grpc.aio.init_grpc_aio()
             logger.debug(f'Starting, pid {os.getpid()}')
-            server = grpc.aio.server(options=[
+            server = grpc.aio.server(options=GRPC_KEEPALIVE_OPTIONS + (
                 ('grpc.so_reuseport', 1),
                 ('grpc.max_send_message_length', -1),
                 ('grpc.max_receive_message_length', -1)
-            ])
+            ))
             runtime_grpc.add_ConnectionHandlerServicer_to_server(self, server)
 
             found_port = server.add_insecure_port(self.listen_on)

+ 10 - 8
hivemind/utils/grpc.py

@@ -20,6 +20,15 @@ logger = get_logger(__file__)
 
 Stub = TypeVar("Stub")
 
+GRPC_KEEPALIVE_OPTIONS = (
+    ('grpc.keepalive_time_ms', 60 * 1000),
+    ('grpc.keepalive_timeout_ms', 60 * 1000),
+    ('grpc.keepalive_permit_without_calls', True),
+    ('grpc.http2.max_pings_without_data', 0),
+    ('grpc.http2.min_time_between_pings_ms', 30 * 1000),
+    ('grpc.http2.min_ping_interval_without_data_ms', 10 * 1000),
+)
+
 
 class ChannelInfo(NamedTuple):
     target: Endpoint
@@ -109,14 +118,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
                         compression: Optional[grpc.Compression]) -> Union[grpc.Channel, grpc.aio.Channel]:
         namespace = grpc.aio if aio else grpc
 
-        options = extra_options + (
-            ('grpc.keepalive_time_ms', 60 * 1000),
-            ('grpc.keepalive_timeout_ms', 60 * 1000),
-            ('grpc.keepalive_permit_without_calls', True),
-            ('grpc.http2.max_pings_without_data', 0),
-            ('grpc.http2.min_time_between_pings_ms', 30 * 1000),
-            ('grpc.http2.min_ping_interval_without_data_ms', 10 * 1000),
-        )
+        options = extra_options + GRPC_KEEPALIVE_OPTIONS
 
         if channel_credentials is None:
             logger.debug(f"Creating insecure {namespace} channel with options '{options}' "