瀏覽代碼

gRPC connection keepalive (#129)

* Add gRPC keepalive options to create_channel

* Bump version

* Style fixes
Max Ryabinin 4 年之前
父節點
當前提交
88122eed1b
共有 5 個文件被更改,包括 31 次插入13 次删除
  1. 1 1
      hivemind/__init__.py
  2. 2 2
      hivemind/dht/protocol.py
  3. 2 2
      hivemind/dht/traverse.py
  4. 24 6
      hivemind/utils/grpc.py
  5. 2 2
      tests/test_averaging.py

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.15'
+__version__ = '0.8.16'

+ 2 - 2
hivemind/dht/protocol.py

@@ -17,7 +17,7 @@ logger = get_logger(__name__)
 class DHTProtocol(dht_grpc.DHTServicer):
     # fmt:off
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
-    channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.aio.Server
+    channel_options: Sequence[Tuple[str, Any]]; server: grpc.aio.Server
     storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
     # fmt:on
 
@@ -28,7 +28,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
     async def create(
             cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
             parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None, listen=True, listen_on='0.0.0.0:*',
-            channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs) -> DHTProtocol:
+            channel_options: Sequence[Tuple[str, Any]] = (), **kwargs) -> DHTProtocol:
         """
         A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
         As a side-effect, DHTProtocol also maintains a routing table as described in

+ 2 - 2
hivemind/dht/traverse.py

@@ -31,7 +31,7 @@ async def simple_traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID],
     visited_nodes = set(visited_nodes)  # note: copy visited_nodes because we will add more nodes to this collection.
     initial_nodes = [node_id for node_id in initial_nodes if node_id not in visited_nodes]
     if not initial_nodes:
-        return [], visited_nodes
+        return (), visited_nodes
 
     unvisited_nodes = [(distance, uid) for uid, distance in zip(initial_nodes, query_id.xor_distance(initial_nodes))]
     heapq.heapify(unvisited_nodes)  # nearest-first heap of candidates, unlimited size
@@ -59,7 +59,7 @@ async def simple_traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID],
                 heapq_add_or_replace(nearest_nodes, (-distance, neighbor_id))
                 upper_bound = -nearest_nodes[0][0]  # distance to beam_size-th nearest element found so far
 
-    return [node_id for _, node_id in heapq.nlargest(beam_size, nearest_nodes)], visited_nodes
+    return tuple(node_id for _, node_id in heapq.nlargest(beam_size, nearest_nodes)), visited_nodes
 
 
 async def traverse_dht(

+ 24 - 6
hivemind/utils/grpc.py

@@ -4,7 +4,7 @@ Utilities for running GRPC services: compile protobuf, patch legacy versions, et
 from __future__ import annotations
 import os
 import threading
-from typing import NamedTuple, Sequence, Tuple, Optional, Union, Any, Dict, TypeVar, Type
+from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type
 
 import grpc
 import numpy as np
@@ -12,7 +12,7 @@ import torch
 
 from hivemind.proto import runtime_pb2
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.timed_storage import TimedStorage, get_dht_time, DHTExpiration, ValueWithExpiration
+from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithExpiration
 from hivemind.utils.networking import Endpoint
 from hivemind.utils.logging import get_logger
 
@@ -64,7 +64,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
             return cls._singleton
 
     @classmethod
-    def get_stub(cls, target: Endpoint, stub_type: Type[Stub], *, aio: bool, options: Sequence[Tuple[str, Any]] = (),
+    def get_stub(cls, target: Endpoint, stub_type: Type[Stub], *, aio: bool, options: Tuple[Tuple[str, Any]] = (),
                  channel_credentials: Optional[grpc.ChannelCredentials] = None,
                  compression: Optional[grpc.Compression] = None) -> Stub:
         """
@@ -79,9 +79,17 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
         """
         cache = cls.get_singleton()
         with cls._lock:
-            key = ChannelInfo(target, aio, tuple(options or ()), channel_credentials, compression)
+            key = ChannelInfo(target, aio, tuple(options), channel_credentials, compression)
             entry: ValueWithExpiration = super(cls, cache).get(key)
-            channel, stubs = entry.value if entry is not None else (cls._create_channel(*key), {})
+
+            if entry is not None:
+                channel, stubs = entry.value
+            else:
+                channel = cls._create_channel(*key)
+                stubs = {}
+
+            channel._channel.check_connectivity_state(True)
+
             if stub_type not in stubs:
                 stubs[stub_type] = stub_type(channel)
 
@@ -96,10 +104,20 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
             return stubs[stub_type]
 
     @classmethod
-    def _create_channel(cls, target: Endpoint, aio: bool, options: Sequence[Tuple[str, Any], ...],
+    def _create_channel(cls, target: Endpoint, aio: bool, extra_options: Tuple[Tuple[str, Any], ...],
                         channel_credentials: Optional[grpc.ChannelCredentials],
                         compression: Optional[grpc.Compression]) -> Union[grpc.Channel, grpc.aio.Channel]:
         namespace = grpc.aio if aio else grpc
+
+        options = extra_options + (
+            ('grpc.keepalive_time_ms', 60 * 1000),
+            ('grpc.keepalive_timeout_ms', 60 * 1000),
+            ('grpc.keepalive_permit_without_calls', True),
+            ('grpc.http2.max_pings_without_data', 0),
+            ('grpc.http2.min_time_between_pings_ms', 30 * 1000),
+            ('grpc.http2.min_ping_interval_without_data_ms', 10 * 1000),
+        )
+
         if channel_credentials is None:
             logger.debug(f"Creating insecure {namespace} channel with options '{options}' "
                          f"and compression '{compression}'")

+ 2 - 2
tests/test_averaging.py

@@ -97,9 +97,9 @@ async def test_allreduce_protocol():
 
 @pytest.mark.forked
 def test_chunks():
-    for i in range(100):
+    for _ in range(100):
         tensors = []
-        for i in range(random.randint(1, 5)):
+        for _ in range(random.randint(1, 5)):
             ndim = random.randint(0, 4)
             shape = torch.Size([random.randint(0, 16) for _ in range(ndim)])
             make_tensor = random.choice([torch.rand, torch.randn, torch.zeros, torch.ones])