grpc.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. """
  2. Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
  3. """
  4. from __future__ import annotations
  5. import os
  6. import threading
  7. from typing import (
  8. Any,
  9. AsyncIterator,
  10. Callable,
  11. Dict,
  12. Iterable,
  13. Iterator,
  14. List,
  15. NamedTuple,
  16. Optional,
  17. Tuple,
  18. Type,
  19. TypeVar,
  20. Union,
  21. )
  22. import grpc
  23. import torch
  24. from hivemind.proto import runtime_pb2
  25. from hivemind.utils.logging import get_logger
  26. from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration, get_dht_time
  27. logger = get_logger(__name__)
  28. Stub = TypeVar("Stub")
  29. GRPC_KEEPALIVE_OPTIONS = (
  30. ("grpc.keepalive_time_ms", 60 * 1000),
  31. ("grpc.keepalive_timeout_ms", 60 * 1000),
  32. ("grpc.keepalive_permit_without_calls", True),
  33. ("grpc.http2.max_pings_without_data", 0),
  34. ("grpc.http2.min_time_between_pings_ms", 30 * 1000),
  35. ("grpc.http2.min_ping_interval_without_data_ms", 10 * 1000),
  36. )
  37. class ChannelInfo(NamedTuple):
  38. target: str
  39. aio: bool
  40. options: Tuple[Tuple[str, str], ...]
  41. credentials: Optional[grpc.ChannelCredentials]
  42. compression: Optional[grpc.Compression]
  43. class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.Channel], Dict]]):
  44. """
  45. A process-wide cache of gRPC channels, supports both normal and aio channels, secure/insecure channels, etc
  46. Based on grpcio internal channel cache by Richard Belleville and Lidi Zheng (thanks!)
  47. Unlike TimedStorage, ChannelCache actively evicts stale channels even if the cache is not accessed
  48. Unlike grpc._simple_stubs.ChannelCache, this implementation supports aio and does not forcibly close active channels
  49. """
  50. MAXIMUM_CHANNELS = int(os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM", 4096))
  51. EVICTION_PERIOD_SECONDS = float(os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS", 10 * 60))
  52. logger.debug(f"Eviction period = {EVICTION_PERIOD_SECONDS}s, max channels = {MAXIMUM_CHANNELS}")
  53. _singleton: Optional[ChannelCache] = None
  54. _singleton_pid: int = os.getpid()
  55. _lock: threading.RLock = threading.RLock()
  56. _update_eviction_evt: threading.Event = threading.Event()
  57. def __init__(self, _created_as_singleton=False):
  58. assert _created_as_singleton, f"Please use {self.__class__.__name__}.get_singleton()"
  59. super().__init__(maxsize=self.MAXIMUM_CHANNELS)
  60. self._is_active = True
  61. self._nearest_expiration_time = float("inf")
  62. self._eviction_thread = threading.Thread(target=self._evict_stale_channels_in_background, daemon=True)
  63. self._eviction_thread.start()
  64. @classmethod
  65. def get_singleton(cls):
  66. """Get or create the channel cache for the current process"""
  67. with cls._lock:
  68. if cls._singleton is None or cls._singleton_pid != os.getpid():
  69. if cls._singleton is not None:
  70. cls._singleton._stop_background_thread()
  71. cls._singleton, cls._singleton_pid = cls(_created_as_singleton=True), os.getpid()
  72. return cls._singleton
  73. @classmethod
  74. def get_stub(
  75. cls,
  76. target: str,
  77. stub_type: Type[Stub],
  78. *,
  79. aio: bool,
  80. options: Tuple[Tuple[str, Any]] = (),
  81. channel_credentials: Optional[grpc.ChannelCredentials] = None,
  82. compression: Optional[grpc.Compression] = None,
  83. ) -> Stub:
  84. """
  85. Create a grpc channel with given options or reuse pre-existing one
  86. :param target: the recipient's address and port
  87. :param stub_type: a gRPC stub (client) to be instantiated
  88. :param aio: if True, returns grpc.Channel, otherwise returns grpc.aio.Channel
  89. :param options: see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
  90. :param channel_credentials: if specified, create a secure channel usin these credentials (default = insecure)
  91. :param compression: see https://github.com/grpc/grpc/tree/master/examples/python/compression
  92. """
  93. cache = cls.get_singleton()
  94. with cls._lock:
  95. key = ChannelInfo(target, aio, tuple(options), channel_credentials, compression)
  96. entry: ValueWithExpiration = super(cls, cache).get(key)
  97. if entry is not None:
  98. channel, stubs = entry.value
  99. else:
  100. channel = cls._create_channel(*key)
  101. stubs = {}
  102. channel._channel.check_connectivity_state(True)
  103. if stub_type not in stubs:
  104. stubs[stub_type] = stub_type(channel)
  105. # either cache channel or update expiration of an existing channel
  106. expiration_time = get_dht_time() + cls.EVICTION_PERIOD_SECONDS
  107. super(cls, cache).store(key, (channel, stubs), expiration_time)
  108. if expiration_time < cache._nearest_expiration_time:
  109. cache._nearest_expiration_time = expiration_time
  110. cls._update_eviction_evt.set()
  111. return stubs[stub_type]
  112. @classmethod
  113. def _create_channel(
  114. cls,
  115. target: str,
  116. aio: bool,
  117. extra_options: Tuple[Tuple[str, Any], ...],
  118. channel_credentials: Optional[grpc.ChannelCredentials],
  119. compression: Optional[grpc.Compression],
  120. ) -> Union[grpc.Channel, grpc.aio.Channel]:
  121. namespace = grpc.aio if aio else grpc
  122. options = extra_options + GRPC_KEEPALIVE_OPTIONS
  123. if channel_credentials is None:
  124. logger.debug(
  125. f"Creating insecure {namespace} channel with options '{options}' " f"and compression '{compression}'"
  126. )
  127. return namespace.insecure_channel(target, options=options, compression=compression)
  128. else:
  129. logger.debug(
  130. f"Creating secure {namespace} channel with credentials '{channel_credentials}', "
  131. f"options '{options}' and compression '{compression}'"
  132. )
  133. return namespace.secure_channel(
  134. target, credentials=channel_credentials, options=options, compression=compression
  135. )
  136. def _evict_stale_channels_in_background(self):
  137. while self._is_active:
  138. now = get_dht_time()
  139. time_to_wait = max(0.0, self._nearest_expiration_time - now)
  140. interrupted_early = self._update_eviction_evt.wait(time_to_wait if time_to_wait != float("inf") else None)
  141. if interrupted_early:
  142. self._update_eviction_evt.clear()
  143. continue
  144. with self._lock:
  145. self._remove_outdated()
  146. _, entry = super().top()
  147. self._nearest_expiration_time = entry.expiration_time if entry is not None else float("inf")
  148. def _stop_background_thread(self):
  149. with self._lock:
  150. self._is_active = False
  151. self._update_eviction_evt.set()
  152. def store(self, *args, **kwargs) -> ValueError:
  153. raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
  154. def get(self, *args, **kwargs) -> ValueError:
  155. raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
  156. def top(self) -> ValueError:
  157. raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
  158. STREAMING_CHUNK_SIZE_BYTES = 2**16
  159. def split_for_streaming(
  160. serialized_tensor: runtime_pb2.Tensor,
  161. chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
  162. ) -> Iterator[runtime_pb2.Tensor]:
  163. """Split serialized_tensor into multiple chunks for gRPC streaming"""
  164. buffer = memoryview(serialized_tensor.buffer)
  165. num_chunks = len(range(0, len(buffer), chunk_size_bytes))
  166. yield runtime_pb2.Tensor(
  167. compression=serialized_tensor.compression,
  168. buffer=buffer[:chunk_size_bytes].tobytes(),
  169. chunks=num_chunks,
  170. size=serialized_tensor.size,
  171. dtype=serialized_tensor.dtype,
  172. requires_grad=serialized_tensor.requires_grad,
  173. )
  174. for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes):
  175. yield runtime_pb2.Tensor(buffer=buffer[chunk_start : chunk_start + chunk_size_bytes].tobytes())
  176. def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
  177. """Restore a result of split_into_chunks into a single serialized tensor"""
  178. stream = iter(stream)
  179. first_chunk = next(stream)
  180. serialized_tensor = runtime_pb2.Tensor()
  181. serialized_tensor.CopyFrom(first_chunk)
  182. buffer_chunks = [first_chunk.buffer]
  183. for tensor_part in stream:
  184. buffer_chunks.append(tensor_part.buffer)
  185. serialized_tensor.buffer = b"".join(buffer_chunks)
  186. return serialized_tensor
  187. StreamMessage = TypeVar("StreamMessage")
  188. async def gather_from_streaming(
  189. stream: AsyncIterator[StreamMessage],
  190. key: Callable[[StreamMessage], Iterable[runtime_pb2.Tensor]],
  191. deserializer: Callable[[runtime_pb2.Tensor], torch.Tensor],
  192. ) -> List[torch.Tensor]:
  193. tensors = []
  194. parts = []
  195. async for msg in stream:
  196. parts_stream = key(msg)
  197. for part in parts_stream:
  198. if part.dtype and parts:
  199. tensors.append(deserializer(combine_from_streaming(parts)))
  200. parts = []
  201. parts.append(part)
  202. if parts:
  203. tensors.append(deserializer(combine_from_streaming(parts)))
  204. return tensors