|
@@ -4,7 +4,7 @@ Utilities for running GRPC services: compile protobuf, patch legacy versions, et
|
|
|
from __future__ import annotations
|
|
|
import os
|
|
|
import threading
|
|
|
-from typing import NamedTuple, Sequence, Tuple, Optional, Union, Any, Dict, TypeVar, Type
|
|
|
+from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type
|
|
|
|
|
|
import grpc
|
|
|
import numpy as np
|
|
@@ -12,7 +12,7 @@ import torch
|
|
|
|
|
|
from hivemind.proto import runtime_pb2
|
|
|
from hivemind.proto.runtime_pb2 import CompressionType
|
|
|
-from hivemind.utils.timed_storage import TimedStorage, get_dht_time, DHTExpiration, ValueWithExpiration
|
|
|
+from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithExpiration
|
|
|
from hivemind.utils.networking import Endpoint
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
|
@@ -64,7 +64,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
|
|
|
return cls._singleton
|
|
|
|
|
|
@classmethod
|
|
|
- def get_stub(cls, target: Endpoint, stub_type: Type[Stub], *, aio: bool, options: Sequence[Tuple[str, Any]] = (),
|
|
|
+ def get_stub(cls, target: Endpoint, stub_type: Type[Stub], *, aio: bool, options: Tuple[Tuple[str, Any]] = (),
|
|
|
channel_credentials: Optional[grpc.ChannelCredentials] = None,
|
|
|
compression: Optional[grpc.Compression] = None) -> Stub:
|
|
|
"""
|
|
@@ -79,9 +79,17 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
|
|
|
"""
|
|
|
cache = cls.get_singleton()
|
|
|
with cls._lock:
|
|
|
- key = ChannelInfo(target, aio, tuple(options or ()), channel_credentials, compression)
|
|
|
+ key = ChannelInfo(target, aio, tuple(options), channel_credentials, compression)
|
|
|
entry: ValueWithExpiration = super(cls, cache).get(key)
|
|
|
- channel, stubs = entry.value if entry is not None else (cls._create_channel(*key), {})
|
|
|
+
|
|
|
+ if entry is not None:
|
|
|
+ channel, stubs = entry.value
|
|
|
+ else:
|
|
|
+ channel = cls._create_channel(*key)
|
|
|
+ stubs = {}
|
|
|
+
|
|
|
+ channel._channel.check_connectivity_state(True)
|
|
|
+
|
|
|
if stub_type not in stubs:
|
|
|
stubs[stub_type] = stub_type(channel)
|
|
|
|
|
@@ -96,10 +104,20 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
|
|
|
return stubs[stub_type]
|
|
|
|
|
|
@classmethod
|
|
|
- def _create_channel(cls, target: Endpoint, aio: bool, options: Sequence[Tuple[str, Any], ...],
|
|
|
+ def _create_channel(cls, target: Endpoint, aio: bool, extra_options: Tuple[Tuple[str, Any], ...],
|
|
|
channel_credentials: Optional[grpc.ChannelCredentials],
|
|
|
compression: Optional[grpc.Compression]) -> Union[grpc.Channel, grpc.aio.Channel]:
|
|
|
namespace = grpc.aio if aio else grpc
|
|
|
+
|
|
|
+ options = extra_options + (
|
|
|
+ ('grpc.keepalive_time_ms', 60 * 1000),
|
|
|
+ ('grpc.keepalive_timeout_ms', 60 * 1000),
|
|
|
+ ('grpc.keepalive_permit_without_calls', True),
|
|
|
+ ('grpc.http2.max_pings_without_data', 0),
|
|
|
+ ('grpc.http2.min_time_between_pings_ms', 30 * 1000),
|
|
|
+ ('grpc.http2.min_ping_interval_without_data_ms', 10 * 1000),
|
|
|
+ )
|
|
|
+
|
|
|
if channel_credentials is None:
|
|
|
logger.debug(f"Creating insecure {namespace} channel with options '{options}' "
|
|
|
f"and compression '{compression}'")
|