grpc.py

  1. """
  2. Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
  3. """
  4. from __future__ import annotations
  5. import os
  6. import threading
  7. from typing import Any, Dict, Iterable, Iterator, NamedTuple, Optional, Tuple, Type, TypeVar, Union
  8. import grpc
  9. from hivemind.proto import runtime_pb2
  10. from hivemind.utils.logging import get_logger
  11. from hivemind.utils.networking import Endpoint
  12. from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration, get_dht_time
  13. logger = get_logger(__name__)
  14. Stub = TypeVar("Stub")

GRPC_KEEPALIVE_OPTIONS = (
    ("grpc.keepalive_time_ms", 60 * 1000),
    ("grpc.keepalive_timeout_ms", 60 * 1000),
    ("grpc.keepalive_permit_without_calls", True),
    ("grpc.http2.max_pings_without_data", 0),
    ("grpc.http2.min_time_between_pings_ms", 30 * 1000),
    ("grpc.http2.min_ping_interval_without_data_ms", 10 * 1000),
)


class ChannelInfo(NamedTuple):
    target: Endpoint
    aio: bool
    options: Tuple[Tuple[str, Any], ...]
    credentials: Optional[grpc.ChannelCredentials]
    compression: Optional[grpc.Compression]


class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.Channel], Dict]]):
    """
    A process-wide cache of gRPC channels that supports both regular and aio channels, secure/insecure connections, etc.
    Based on the grpcio internal channel cache by Richard Belleville and Lidi Zheng (thanks!)

    Unlike TimedStorage, ChannelCache actively evicts stale channels even if the cache is not accessed.
    Unlike grpc._simple_stubs.ChannelCache, this implementation supports aio and does not forcibly close active channels.
    """

    MAXIMUM_CHANNELS = int(os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM", 4096))
    EVICTION_PERIOD_SECONDS = float(os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS", 10 * 60))
    logger.debug(f"Eviction period = {EVICTION_PERIOD_SECONDS}s, max channels = {MAXIMUM_CHANNELS}")

    _singleton: Optional[ChannelCache] = None
    _singleton_pid: int = os.getpid()
    _lock: threading.RLock = threading.RLock()
    _update_eviction_evt: threading.Event = threading.Event()

    def __init__(self, _created_as_singleton=False):
        assert _created_as_singleton, f"Please use {self.__class__.__name__}.get_singleton()"
        super().__init__(maxsize=self.MAXIMUM_CHANNELS)
        self._is_active = True
        self._nearest_expiration_time = float("inf")
        self._eviction_thread = threading.Thread(target=self._evict_stale_channels_in_background, daemon=True)
        self._eviction_thread.start()

    @classmethod
    def get_singleton(cls):
        """Get or create the channel cache for the current process"""
        with cls._lock:
            # the pid check recreates the singleton in forked children: channels must not be shared across processes
            if cls._singleton is None or cls._singleton_pid != os.getpid():
                if cls._singleton is not None:
                    cls._singleton._stop_background_thread()
                cls._singleton, cls._singleton_pid = cls(_created_as_singleton=True), os.getpid()
            return cls._singleton

    @classmethod
    def get_stub(
        cls,
        target: Endpoint,
        stub_type: Type[Stub],
        *,
        aio: bool,
        options: Tuple[Tuple[str, Any], ...] = (),
        channel_credentials: Optional[grpc.ChannelCredentials] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Stub:
        """
        Create a grpc channel with the given options or reuse a pre-existing one

        :param target: the recipient's address and port
        :param stub_type: a gRPC stub (client) to be instantiated
        :param aio: if True, the stub is created over grpc.aio.Channel; otherwise, over grpc.Channel
        :param options: see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
        :param channel_credentials: if specified, create a secure channel using these credentials (default = insecure)
        :param compression: see https://github.com/grpc/grpc/tree/master/examples/python/compression
        """
        cache = cls.get_singleton()
        with cls._lock:
            key = ChannelInfo(target, aio, tuple(options), channel_credentials, compression)
            # super(cls, cache).get/store bypass the overridden get/store below, which only raise
            entry: ValueWithExpiration = super(cls, cache).get(key)

            if entry is not None:
                channel, stubs = entry.value
            else:
                channel = cls._create_channel(*key)
                stubs = {}

            # nudge the channel to (re)connect if it is currently idle (True = try_to_connect)
            channel._channel.check_connectivity_state(True)

            if stub_type not in stubs:
                stubs[stub_type] = stub_type(channel)

            # either cache the new channel or update the expiration of an existing one
            expiration_time = get_dht_time() + cls.EVICTION_PERIOD_SECONDS
            super(cls, cache).store(key, (channel, stubs), expiration_time)

            if expiration_time < cache._nearest_expiration_time:
                cache._nearest_expiration_time = expiration_time
                cls._update_eviction_evt.set()

            return stubs[stub_type]

    @classmethod
    def _create_channel(
        cls,
        target: Endpoint,
        aio: bool,
        extra_options: Tuple[Tuple[str, Any], ...],
        channel_credentials: Optional[grpc.ChannelCredentials],
        compression: Optional[grpc.Compression],
    ) -> Union[grpc.Channel, grpc.aio.Channel]:
        namespace = grpc.aio if aio else grpc
        options = extra_options + GRPC_KEEPALIVE_OPTIONS

        if channel_credentials is None:
            logger.debug(f"Creating insecure {namespace} channel with options '{options}' and compression '{compression}'")
            return namespace.insecure_channel(target, options=options, compression=compression)
        else:
            logger.debug(
                f"Creating secure {namespace} channel with credentials '{channel_credentials}', "
                f"options '{options}' and compression '{compression}'"
            )
            return namespace.secure_channel(
                target, credentials=channel_credentials, options=options, compression=compression
            )
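
    # Caller-supplied options are combined with the keepalive defaults above. A hypothetical
    # example of raising the receive size limit (the value here is an assumption, not a
    # project default; "grpc.max_receive_message_length" is a standard gRPC channel arg):
    #
    #   ChannelCache.get_stub(target, StubType, aio=True,
    #                         options=(("grpc.max_receive_message_length", 64 * 1024 * 1024),))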

    def _evict_stale_channels_in_background(self):
        while self._is_active:
            now = get_dht_time()
            time_to_wait = max(0.0, self._nearest_expiration_time - now)
            # sleep until the nearest expiration, or until get_stub registers an earlier one
            interrupted_early = self._update_eviction_evt.wait(time_to_wait if time_to_wait != float("inf") else None)
            if interrupted_early:
                self._update_eviction_evt.clear()
                continue

            with self._lock:
                self._remove_outdated()
                _, entry = super().top()
                self._nearest_expiration_time = entry.expiration_time if entry is not None else float("inf")

    def _stop_background_thread(self):
        with self._lock:
            self._is_active = False
            self._update_eviction_evt.set()

    def store(self, *args, **kwargs) -> NoReturn:
        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")

    def get(self, *args, **kwargs) -> NoReturn:
        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")

    def top(self) -> NoReturn:
        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
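

# A minimal usage sketch. The stub class and RPC below are illustrative assumptions
# (any client class generated by grpcio-tools works); channels are created lazily,
# shared process-wide, and evicted EVICTION_PERIOD_SECONDS after their last use:
#
#   from hivemind.proto import runtime_pb2_grpc  # assumed generated module
#
#   stub = ChannelCache.get_stub(
#       "127.0.0.1:1337",
#       runtime_pb2_grpc.ConnectionHandlerStub,
#       aio=False,
#   )
#   response = stub.info(runtime_pb2.ExpertUID(uid="expert.0"))  # reuses the cached channel next time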


STREAMING_CHUNK_SIZE_BYTES = 2 ** 16


def split_for_streaming(
    serialized_tensor: runtime_pb2.Tensor,
    chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
) -> Iterator[runtime_pb2.Tensor]:
    """Split serialized_tensor into multiple chunks for gRPC streaming"""
    buffer = memoryview(serialized_tensor.buffer)
    num_chunks = len(range(0, len(buffer), chunk_size_bytes))
    # the first chunk carries the metadata (dtype, size, compression, total chunk count) and the first slice of bytes
    yield runtime_pb2.Tensor(
        compression=serialized_tensor.compression,
        buffer=buffer[:chunk_size_bytes].tobytes(),
        chunks=num_chunks,
        size=serialized_tensor.size,
        dtype=serialized_tensor.dtype,
        requires_grad=serialized_tensor.requires_grad,
    )
    # the remaining chunks carry only raw buffer bytes
    for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes):
        yield runtime_pb2.Tensor(buffer=buffer[chunk_start : chunk_start + chunk_size_bytes].tobytes())


def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
    """Restore the result of split_for_streaming into a single serialized tensor"""
    stream = iter(stream)
    first_chunk = next(stream)
    serialized_tensor = runtime_pb2.Tensor()
    serialized_tensor.CopyFrom(first_chunk)  # the first chunk carries all metadata fields
    buffer_chunks = [first_chunk.buffer]
    for tensor_part in stream:
        buffer_chunks.append(tensor_part.buffer)
    serialized_tensor.buffer = b"".join(buffer_chunks)
    return serialized_tensor
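

# Round-trip sketch (field values are assumptions for illustration; in hivemind these
# protos typically come from tensor serialization). Splitting then combining preserves
# the payload byte-for-byte:
#
#   tensor = runtime_pb2.Tensor(buffer=b"\x00" * 200_000, dtype="float32", size=[50_000])
#   chunks = list(split_for_streaming(tensor, chunk_size_bytes=2 ** 16))
#   assert chunks[0].chunks == len(chunks) == 4  # ceil(200_000 / 65_536) == 4
#   restored = combine_from_streaming(chunks)
#   assert restored.buffer == tensor.buffer and restored.dtype == tensor.dtype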