
Process-wide channel cache for gRPC+aio (#120)

Changelist:
* Added a process-wide ChannelCache for gRPC channels and stubs
* Switched dht/protocol.py, client/allreduce.py and client/expert.py to the new cache
* Moved TimedStorage to hivemind/utils
* Fixed an edge case where TimedStorage with a fixed maxsize sometimes evicted too many entries: the size check used len(self.expiration_heap), which also counts stale heap entries, instead of len(self.data) (see the sketch below)
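
To illustrate the TimedStorage fix, here is a minimal sketch (the keys and values are made up; TimedStorage and get_dht_time come from the new hivemind/utils/timed_storage.py):

    from hivemind.utils.timed_storage import TimedStorage, get_dht_time

    storage = TimedStorage(maxsize=2)
    storage.store("key_a", b"v1", get_dht_time() + 5)
    storage.store("key_b", b"v2", get_dht_time() + 3)
    storage.store("key_a", b"v3", get_dht_time() + 10)  # re-storing leaves a stale heap entry for key_a
    # on the next access, the old code compared len(self.expiration_heap) == 3 against maxsize
    # and evicted the live key_b; with len(self.data) == 2 both keys survive
    assert len(storage) == 2 and "key_b" in storage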

Before this commit, we created a new gRPC channel every time we queried a remote DHT peer.
In contrast, gRPC best practices recommend reusing channels across multiple RPC calls.

hivemind/client/expert.py already cached its channels (via lru_cache), but this had the side effect of keeping them open forever.

In this PR we implement a process-wide ChannelCache object that keeps track of open channels and evicts them after a period of inactivity. The code is largely inspired by https://github.com/grpc/grpc/blob/master/src/python/grpcio/grpc/_simple_stubs.py , but with added support for grpc.aio channels.
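
For reference, a minimal usage sketch of the new cache (the endpoint below is a placeholder; gRPC channels connect lazily, so no server needs to be listening for the stubs to be created):

    from hivemind.utils.grpc import ChannelCache
    from hivemind.proto.dht_pb2_grpc import DHTStub

    # repeated calls with the same endpoint, stub type and options reuse one cached channel
    # and return the same stub object (until the entry is evicted after a period of inactivity)
    stub = ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
    same_stub = ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
    assert stub is same_stub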


Co-authored-by: Vsevolod-pl <vsevolod-pl@yandex.ru>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic committed 4 years ago
commit 1754792aad

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.13'
+__version__ = '0.8.14'

+ 6 - 8
hivemind/client/allreduce.py

@@ -11,7 +11,7 @@ import torch
 
 from hivemind.dht import DHTID, DHTExpiration
 from hivemind.utils import Endpoint, get_logger, MSGPackSerializer
-from hivemind.utils import TensorDescriptor, deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils import TensorDescriptor, deserialize_torch_tensor, serialize_torch_tensor, ChannelCache
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
 
 logger = get_logger(__name__)
@@ -149,10 +149,8 @@ class GroupAllReduce:
 
         return await self.averaged_part
 
-    def _get(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
-        """ TODO this function is deprecated and will be replaced by a shared channel cache """
-        channel = grpc.aio.insecure_channel(peer)
-        return averaging_pb2_grpc.DecentralizedAveragingStub(channel)
+    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
+        return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
 
     async def handle_join_request(self, request: averaging_pb2.PeerInfo
                                   ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
@@ -220,7 +218,7 @@ class GroupAllReduce:
         assert self.state == ProtocolState.LOOKING_FOR_GROUP
         try:
             async with self.concurrent_requests_lock:
-                stream = self._get(leader).rpc_group_allreduce(self.info)
+                stream = self._get_peer_stub(leader).rpc_group_allreduce(self.info)
                 message = await stream.read()
                 logger.debug(f"{self} - requested {leader} to be my leader, received "
                              f"{averaging_pb2.MessageCode.Name(message.code)}")
@@ -259,7 +257,7 @@ class GroupAllReduce:
                 self.average_tensor_parts[peer_endpoint] = await self.accumulate(peer_endpoint, local_part)
             else:
                 serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
-                response = await self._get(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
+                response = await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
                     group_id=self.group_id, endpoint=self.info.endpoint, tensor_part=serialized_tensor_part))
 
                 if response.code == averaging_pb2.ACCEPTED:
@@ -279,7 +277,7 @@ class GroupAllReduce:
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
 
             async def send_error_to_peer(peer_endpoint):
-                await self._get(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
+                await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
                     group_id=self.group_id, endpoint=self.info.endpoint, code=code))
             for peer_endpoint in ordered_group_endpoints:
                 asyncio.create_task(send_error_to_peer(peer_endpoint))

+ 3 - 8
hivemind/client/expert.py

@@ -1,26 +1,21 @@
 import pickle
-from functools import lru_cache
 from typing import Tuple, Optional, Any, Dict
 
-import grpc
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
-from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor, ChannelCache
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
 
-@lru_cache(maxsize=None)
 def _get_expert_stub(endpoint: Endpoint, *extra_options: Tuple[str, Any]):
     """ Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache """
-    channel_options = [
-        ('grpc.max_send_message_length', -1), ('grpc.max_receive_message_length', -1)
-    ] + list(extra_options)
-    return runtime_grpc.ConnectionHandlerStub(grpc.insecure_channel(endpoint, options=channel_options))
+    channel_options = (('grpc.max_send_message_length', -1), ('grpc.max_receive_message_length', -1)) + extra_options
+    return ChannelCache.get_stub(endpoint, runtime_grpc.ConnectionHandlerStub, aio=False, options=channel_options)
 
 
 class RemoteExpert(nn.Module):

+ 1 - 1
hivemind/dht/__init__.py

@@ -26,7 +26,7 @@ import uvloop
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time, DHTValue
-from hivemind.dht.storage import ValueWithExpiration
+from hivemind.utils.timed_storage import ValueWithExpiration
 from hivemind.utils import MPFuture, Endpoint, get_logger
 
 logger = get_logger(__name__)

+ 2 - 2
hivemind/dht/node.py

@@ -11,9 +11,9 @@ from sortedcontainers import SortedList
 
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
-from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue, ValueWithExpiration
+from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
-from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
+from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase, ValueWithExpiration
 
 logger = get_logger(__name__)
 

+ 7 - 8
hivemind/dht/protocol.py

@@ -7,9 +7,9 @@ from typing import Optional, List, Tuple, Dict, Any, Sequence, Union, Collection
 import grpc
 
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
-from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue, ValueWithExpiration
+from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
-from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer
+from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer, ChannelCache, ValueWithExpiration
 
 logger = get_logger(__name__)
 
@@ -78,10 +78,9 @@ class DHTProtocol(dht_grpc.DHTServicer):
         else:
             logger.warning("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")
 
-    def _get(self, peer: Endpoint) -> dht_grpc.DHTStub:
+    def _get_dht_stub(self, peer: Endpoint) -> dht_grpc.DHTStub:
         """ get a DHTStub that sends requests to a given peer """
-        channel = grpc.aio.insecure_channel(peer, options=self.channel_options)
-        return dht_grpc.DHTStub(channel)
+        return ChannelCache.get_stub(peer, dht_grpc.DHTStub, aio=True, options=self.channel_options)
 
     async def call_ping(self, peer: Endpoint) -> Optional[DHTID]:
         """
@@ -93,7 +92,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         """
         try:
             async with self.rpc_semaphore:
-                peer_info = await self._get(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
+                peer_info = await self._get_dht_stub(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
         except grpc.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to ping {peer}: {error.code()}")
             peer_info = None
@@ -155,7 +154,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                                              expiration_time=expiration_time, in_cache=in_cache, peer=self.node_info)
         try:
             async with self.rpc_semaphore:
-                response = await self._get(peer).rpc_store(store_request, timeout=self.wait_timeout)
+                response = await self._get_dht_stub(peer).rpc_store(store_request, timeout=self.wait_timeout)
             if response.peer and response.peer.node_id:
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
@@ -203,7 +202,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes, keys)), peer=self.node_info)
         try:
             async with self.rpc_semaphore:
-                response = await self._get(peer).rpc_find(find_request, timeout=self.wait_timeout)
+                response = await self._get_dht_stub(peer).rpc_find(find_request, timeout=self.wait_timeout)
             if response.peer and response.peer.node_id:
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))

+ 2 - 5
hivemind/dht/routing.py

@@ -5,15 +5,12 @@ import hashlib
 import heapq
 import os
 import random
-import time
 from collections.abc import Iterable
 from itertools import chain
 from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
+from hivemind.utils import Endpoint, PickleSerializer, get_dht_time, DHTExpiration
 
-from hivemind.utils import Endpoint, PickleSerializer
-
-DHTKey, Subkey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, Any, Any, float, bytes, bytes
-get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time
+DHTKey, Subkey, DHTValue, BinaryDHTID, BinaryDHTValue, = Any, Any, Any, bytes, bytes
 
 
 class RoutingTable:

+ 3 - 105
hivemind/dht/storage.py

@@ -1,111 +1,9 @@
 from __future__ import annotations
-import heapq
-from contextlib import contextmanager
-from typing import Generic, Optional, Dict, Tuple, List, Iterator, TypeVar, Union, NamedTuple
+from typing import Optional, Union
 
-from hivemind.dht.routing import DHTID, DHTExpiration, get_dht_time, BinaryDHTValue, Subkey
+from hivemind.dht.routing import DHTID, DHTExpiration, BinaryDHTValue, Subkey
 from hivemind.utils.serializer import MSGPackSerializer
-
-KeyType = TypeVar('KeyType')
-ValueType = TypeVar('ValueType')
-ROOT = 0
-
-
-class ValueWithExpiration(NamedTuple, Generic[ValueType]):
-    value: ValueType
-    expiration_time: DHTExpiration
-
-
-class HeapEntry(NamedTuple, Generic[KeyType]):
-    expiration_time: DHTExpiration
-    key: KeyType
-
-
-class TimedStorage(Generic[KeyType, ValueType]):
-    """ A dictionary that maintains up to :maxsize: key-value-expiration tuples until their expiration_time """
-    frozen = False  # can be set to True. If true, do not remove outdated elements
-
-    def __init__(self, maxsize: Optional[int] = None):
-        self.maxsize = maxsize or float("inf")
-        self.data: Dict[KeyType, ValueWithExpiration[ValueType]] = dict()
-        self.expiration_heap: List[HeapEntry[KeyType]] = []
-        self.key_to_heap: Dict[KeyType, HeapEntry[KeyType]] = dict()
-
-    def _remove_outdated(self):
-        while not self.frozen and self.expiration_heap and (self.expiration_heap[ROOT].expiration_time < get_dht_time()
-                                                            or len(self.expiration_heap) > self.maxsize):
-            heap_entry = heapq.heappop(self.expiration_heap)
-            if self.key_to_heap.get(heap_entry.key) == heap_entry:
-                del self.data[heap_entry.key], self.key_to_heap[heap_entry.key]
-
-    def store(self, key: KeyType, value: ValueType, expiration_time: DHTExpiration) -> bool:
-        """
-        Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
-        :returns: True if new value was stored, False it was rejected (current value is newer)
-        """
-        if expiration_time < get_dht_time() and not self.frozen:
-            return False
-        self.key_to_heap[key] = HeapEntry(expiration_time, key)
-        heapq.heappush(self.expiration_heap, self.key_to_heap[key])
-        if key in self.data:
-            if self.data[key].expiration_time < expiration_time:
-                self.data[key] = ValueWithExpiration(value, expiration_time)
-                return True
-            return False
-        self.data[key] = ValueWithExpiration(value, expiration_time)
-        self._remove_outdated()
-        return True
-
-    def get(self, key: KeyType) -> Optional[ValueWithExpiration[ValueType]]:
-        """ Get a value corresponding to a key if that (key, value) pair was previously stored under this key. """
-        self._remove_outdated()
-        if key in self.data:
-            return self.data[key]
-        return None
-
-    def items(self) -> Iterator[Tuple[KeyType, ValueWithExpiration[ValueType]]]:
-        """ Iterate over (key, value, expiration_time) tuples stored in this storage """
-        self._remove_outdated()
-        return ((key, value_and_expiration) for key, value_and_expiration in self.data.items())
-
-    def top(self) -> Tuple[Optional[KeyType], Optional[ValueWithExpiration[ValueType]]]:
-        """ Return the entry with earliest expiration or None if there isn't any """
-        self._remove_outdated()
-        if self.data:
-            # skip leftover "ghost" entries until first real entry
-            while self.key_to_heap.get(self.expiration_heap[ROOT].key) != self.expiration_heap[ROOT]:
-                heapq.heappop(self.expiration_heap)
-            top_key = self.expiration_heap[ROOT].key
-            return top_key, self.data[top_key]
-        return None, None
-
-    def __contains__(self, key: KeyType):
-        self._remove_outdated()
-        return key in self.data
-
-    def __len__(self):
-        self._remove_outdated()
-        return len(self.data)
-
-    def __delitem__(self, key: KeyType):
-        if key in self.key_to_heap:
-            del self.data[key], self.key_to_heap[key]
-        # note: key may still be in self.expiration_heap, but it will not be used and eventually ._remove_outdated()
-
-    def __bool__(self):
-        return bool(self.data)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.data})"
-
-    @contextmanager
-    def freeze(self):
-        """ Temporarily cease to ._remove_outdated() elements inside this context to ensure consistency """
-        prev_frozen, self.frozen = self.frozen, True
-        try:
-            yield self
-        finally:
-            self.frozen = prev_frozen
+from hivemind.utils.timed_storage import KeyType, ValueType, TimedStorage
 
 
 @MSGPackSerializer.ext_serializable(0x50)

+ 1 - 0
hivemind/utils/__init__.py

@@ -5,4 +5,5 @@ from hivemind.utils.serializer import *
 from hivemind.utils.mpfuture import *
 from hivemind.utils.threading import *
 from hivemind.utils.grpc import *
+from hivemind.utils.timed_storage import *
 from hivemind.utils.logging import get_logger

+ 132 - 1
hivemind/utils/grpc.py

@@ -1,17 +1,148 @@
 """
 Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
 """
+from __future__ import annotations
+import os
+import threading
+from typing import NamedTuple, Sequence, Tuple, Optional, Union, Any, Dict, TypeVar, Type
 
+import grpc
 import numpy as np
 import torch
 
 from hivemind.proto import runtime_pb2
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.timed_storage import TimedStorage, get_dht_time, DHTExpiration, ValueWithExpiration
+from hivemind.utils.networking import Endpoint
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__file__)
+
+Stub = TypeVar("Stub")
+
+
+class ChannelInfo(NamedTuple):
+    target: Endpoint
+    aio: bool
+    options: Tuple[Tuple[str, str], ...]
+    credentials: Optional[grpc.ChannelCredentials]
+    compression: Optional[grpc.Compression]
+
+
+class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.Channel], Dict]]):
+    """
+    A process-wide cache of gRPC channels, supports both normal and aio channels, secure/insecure channels, etc
+    Based on grpcio internal channel cache by Richard Belleville and Lidi Zheng (thanks!)
+    Unlike TimedStorage, ChannelCache actively evicts stale channels even if the cache is not accessed
+    Unlike grpc._simple_stubs.ChannelCache, this implementation supports aio and does not forcibly close active channels
+    """
+    MAXIMUM_CHANNELS = os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM", 4096)
+    EVICTION_PERIOD_SECONDS = os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS", 10 * 60)
+    logger.debug(f"Eviction period = {EVICTION_PERIOD_SECONDS}s, max channels = {MAXIMUM_CHANNELS}")
+
+    _singleton: Optional[ChannelCache] = None
+    _singleton_pid: int = os.getpid()
+    _lock: threading.RLock = threading.RLock()
+    _update_eviction_evt: threading.Event = threading.Event()
+
+    def __init__(self, _created_as_singleton=False):
+        assert _created_as_singleton, f"Please use {self.__class__.__name__}.get_singleton()"
+        super().__init__(maxsize=self.MAXIMUM_CHANNELS)
+        self._is_active = True
+        self._nearest_expiration_time = float('inf')
+        self._eviction_thread = threading.Thread(target=self._evict_stale_channels_in_background, daemon=True)
+        self._eviction_thread.start()
+
+    @classmethod
+    def get_singleton(cls):
+        """ Get or create the channel cache for the current process """
+        with cls._lock:
+            if cls._singleton is None or cls._singleton_pid != os.getpid():
+                if cls._singleton is not None:
+                    cls._singleton._stop_background_thread()
+                cls._singleton, cls._singleton_pid = cls(_created_as_singleton=True), os.getpid()
+            return cls._singleton
+
+    @classmethod
+    def get_stub(cls, target: Endpoint, stub_type: Type[Stub], *, aio: bool, options: Sequence[Tuple[str, Any]] = (),
+                 channel_credentials: Optional[grpc.ChannelCredentials] = None,
+                 compression: Optional[grpc.Compression] = None) -> Stub:
+        """
+        Create a grpc channel with given options or reuse pre-existing one
+
+        :param target: the recipient's address and port
+        :param stub_type: a gRPC stub (client) to be instantiated
+        :param aio: if True, returns a stub on a grpc.aio.Channel, otherwise on a plain grpc.Channel
+        :param options: see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
+        :param channel_credentials: if specified, create a secure channel using these credentials (default = insecure)
+        :param compression: see https://github.com/grpc/grpc/tree/master/examples/python/compression
+        """
+        cache = cls.get_singleton()
+        with cls._lock:
+            key = ChannelInfo(target, aio, tuple(options or ()), channel_credentials, compression)
+            entry: ValueWithExpiration = super(cls, cache).get(key)
+            channel, stubs = entry.value if entry is not None else (cls._create_channel(*key), {})
+            if stub_type not in stubs:
+                stubs[stub_type] = stub_type(channel)
+
+            # either cache channel or update expiration of an existing channel
+            expiration_time = get_dht_time() + cls.EVICTION_PERIOD_SECONDS
+            super(cls, cache).store(key, (channel, stubs), expiration_time)
+
+            if expiration_time < cache._nearest_expiration_time:
+                cache._nearest_expiration_time = expiration_time
+                cls._update_eviction_evt.set()
+
+            return stubs[stub_type]
+
+    @classmethod
+    def _create_channel(cls, target: Endpoint, aio: bool, options: Sequence[Tuple[str, Any]],
+                        channel_credentials: Optional[grpc.ChannelCredentials],
+                        compression: Optional[grpc.Compression]) -> Union[grpc.Channel, grpc.aio.Channel]:
+        namespace = grpc.aio if aio else grpc
+        if channel_credentials is None:
+            logger.debug(f"Creating insecure {namespace} channel with options '{options}' "
+                         f"and compression '{compression}'")
+            return namespace.insecure_channel(target, options=options, compression=compression)
+        else:
+            logger.debug(f"Creating secure {namespace} channel with credentials '{channel_credentials}', "
+                         f"options '{options}' and compression '{compression}'")
+            return namespace.secure_channel(target, credentials=channel_credentials,
+                                            options=options, compression=compression)
+
+    def _evict_stale_channels_in_background(self):
+        while self._is_active:
+            now = get_dht_time()
+            time_to_wait = max(0.0, self._nearest_expiration_time - now)
+            interrupted_early = self._update_eviction_evt.wait(time_to_wait if time_to_wait != float('inf') else None)
+            if interrupted_early:
+                self._update_eviction_evt.clear()
+                continue
+
+            with self._lock:
+                self._remove_outdated()
+                _, entry = super().top()
+                self._nearest_expiration_time = entry.expiration_time if entry is not None else float('inf')
+
+    def _stop_background_thread(self):
+        with self._lock:
+            self._is_active = False
+            self._update_eviction_evt.set()
+
+    def store(self, *args, **kwargs) -> ValueError:
+        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
+
+    def get(self, *args, **kwargs) -> ValueError:
+        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
+
+    def top(self) -> ValueError:
+        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
+
 
 FP16_MAX = 65_504
 
 
-def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE, 
+def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE,
                            allow_inplace=False) -> runtime_pb2.Tensor:
     if compression_type == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
         assert tensor.dtype == torch.float32

+ 109 - 0
hivemind/utils/timed_storage.py

@@ -0,0 +1,109 @@
+""" A dictionary-like storage that stores items until a specified expiration time or up to a limited size """
+from __future__ import annotations
+import heapq
+import time
+from contextlib import contextmanager
+from typing import TypeVar, NamedTuple, Generic, Optional, Dict, List, Iterator, Tuple
+
+KeyType = TypeVar('KeyType')
+ValueType = TypeVar('ValueType')
+get_dht_time = time.time  # a global (weakly synchronized) time
+DHTExpiration = float
+ROOT = 0
+
+
+class ValueWithExpiration(NamedTuple, Generic[ValueType]):
+    value: ValueType
+    expiration_time: DHTExpiration
+
+
+class HeapEntry(NamedTuple, Generic[KeyType]):
+    expiration_time: DHTExpiration
+    key: KeyType
+
+
+class TimedStorage(Generic[KeyType, ValueType]):
+    """ A dictionary that maintains up to :maxsize: key-value-expiration tuples until their expiration_time """
+    frozen = False  # can be set to True. If true, do not remove outdated elements
+
+    def __init__(self, maxsize: Optional[int] = None):
+        self.maxsize = maxsize or float("inf")
+        self.data: Dict[KeyType, ValueWithExpiration[ValueType]] = dict()
+        self.expiration_heap: List[HeapEntry[KeyType]] = []
+        self.key_to_heap: Dict[KeyType, HeapEntry[KeyType]] = dict()
+
+    def _remove_outdated(self):
+        while not self.frozen and self.expiration_heap and (self.expiration_heap[ROOT].expiration_time < get_dht_time()
+                                                            or len(self.data) > self.maxsize):
+            heap_entry = heapq.heappop(self.expiration_heap)
+            if self.key_to_heap.get(heap_entry.key) == heap_entry:
+                del self.data[heap_entry.key], self.key_to_heap[heap_entry.key]
+
+    def store(self, key: KeyType, value: ValueType, expiration_time: DHTExpiration) -> bool:
+        """
+        Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
+        :returns: True if the new value was stored, False if it was rejected (current value is newer)
+        """
+        if expiration_time < get_dht_time() and not self.frozen:
+            return False
+        self.key_to_heap[key] = HeapEntry(expiration_time, key)
+        heapq.heappush(self.expiration_heap, self.key_to_heap[key])
+        if key in self.data:
+            if self.data[key].expiration_time < expiration_time:
+                self.data[key] = ValueWithExpiration(value, expiration_time)
+                return True
+            return False
+        self.data[key] = ValueWithExpiration(value, expiration_time)
+        self._remove_outdated()
+        return True
+
+    def get(self, key: KeyType) -> Optional[ValueWithExpiration[ValueType]]:
+        """ Get a value corresponding to a key if that (key, value) pair was previously stored under this key. """
+        self._remove_outdated()
+        if key in self.data:
+            return self.data[key]
+        return None
+
+    def items(self) -> Iterator[Tuple[KeyType, ValueWithExpiration[ValueType]]]:
+        """ Iterate over (key, value, expiration_time) tuples stored in this storage """
+        self._remove_outdated()
+        return ((key, value_and_expiration) for key, value_and_expiration in self.data.items())
+
+    def top(self) -> Tuple[Optional[KeyType], Optional[ValueWithExpiration[ValueType]]]:
+        """ Return the entry with earliest expiration or None if there isn't any """
+        self._remove_outdated()
+        if self.data:
+            # skip leftover "ghost" entries until first real entry
+            while self.key_to_heap.get(self.expiration_heap[ROOT].key) != self.expiration_heap[ROOT]:
+                heapq.heappop(self.expiration_heap)
+            top_key = self.expiration_heap[ROOT].key
+            return top_key, self.data[top_key]
+        return None, None
+
+    def __contains__(self, key: KeyType):
+        self._remove_outdated()
+        return key in self.data
+
+    def __len__(self):
+        self._remove_outdated()
+        return len(self.data)
+
+    def __delitem__(self, key: KeyType):
+        if key in self.key_to_heap:
+            del self.data[key], self.key_to_heap[key]
+        # note: key may still be in self.expiration_heap, but it will not be used and eventually ._remove_outdated()
+
+    def __bool__(self):
+        return bool(self.data)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.data})"
+
+    @contextmanager
+    def freeze(self):
+        """ Temporarily cease to ._remove_outdated() elements inside this context to ensure consistency """
+        prev_frozen, self.frozen = self.frozen, True
+        try:
+            yield self
+        finally:
+            self.frozen = prev_frozen

+ 1 - 0
requirements.txt

@@ -7,4 +7,5 @@ sortedcontainers
 uvloop>=0.14.0
 grpcio>=1.33.2
 grpcio-tools>=1.33.2
+protobuf>=3.12.2
 configargparse>=1.2.3

+ 2 - 1
tests/test_averaging.py

@@ -91,7 +91,8 @@ async def test_allreduce_protocol():
     ]
 
     assert len(averaged_tensors) == len(reference_tensors)
-    assert all(map(torch.allclose, averaged_tensors, reference_tensors))
+    assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
+               for our, ref in zip(averaged_tensors, reference_tensors))
 
 
 @pytest.mark.forked

+ 6 - 3
tests/test_dht_storage.py

@@ -37,11 +37,14 @@ def test_change_expiration_time():
 
 
 def test_maxsize_cache():
-    d = DHTLocalStorage(maxsize=1)
-    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
+    d = DHTLocalStorage(maxsize=2)
+    d.store(DHTID.generate("key1a"), b"val1a", get_dht_time() + 1)
+    d.store(DHTID.generate("key1b"), b"val1b", get_dht_time() + 1)
+    d.store(DHTID.generate("key1a"), b"val1a2", get_dht_time() + 2)
     d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
     assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
-    assert d.get(DHTID.generate("key1")) is None, "Value with less exp time, must be deleted"
+    assert d.get(DHTID.generate("key1a"))[0] == b"val1a2", "Value with bigger exp. time must be kept"
+    assert d.get(DHTID.generate("key1b")) is None, "Value with less exp time, must be deleted"
 
 
 def test_localstorage_top():

+ 4 - 6
tests/test_moe.py

@@ -30,7 +30,6 @@ def test_call_many():
     backward_k_min = 1
     forward_timeout = None
     backward_timeout = None
-    rtol = 1e-3
     atol = 1e-5
 
     with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=8, hidden_dim=64,
@@ -61,7 +60,7 @@ def test_call_many():
         reference_outputs[2, 0] = e1(inputs_clone[2:3])
         reference_outputs[2, 2] = e3(inputs_clone[2:3])
 
-        assert torch.allclose(expert_outputs, reference_outputs, rtol, atol)
+        assert torch.allclose(expert_outputs, reference_outputs, atol=atol, rtol=0)
         proj = torch.randn(4, 64)
         loss = (expert_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
         loss.backward()
@@ -70,7 +69,7 @@ def test_call_many():
         reference_loss = (reference_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
         reference_loss.backward()
         reference_grad = inputs_clone.grad.data.cpu().clone()
-        assert torch.allclose(our_grad, reference_grad, rtol, atol)
+        assert torch.allclose(our_grad, reference_grad, atol=atol, rtol=0)
 
 
 @pytest.mark.forked
@@ -125,7 +124,6 @@ def test_beam_search_correctness():
 
 @pytest.mark.forked
 def test_determinism():
-    rtol = 0
     atol = 1e-5
 
     xx = torch.randn(32, 1024, requires_grad=True)
@@ -141,8 +139,8 @@ def test_determinism():
         grad, = torch.autograd.grad(out.sum(), xx, retain_graph=True)
         grad_rerun, = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)
 
-    assert torch.allclose(out, out_rerun, rtol, atol), "Dropout layer outputs are non-deterministic."
-    assert torch.allclose(grad, grad_rerun, rtol, atol), "Gradients are non-deterministic."
+    assert torch.allclose(out, out_rerun, atol=atol, rtol=0), "Dropout layer outputs are non-deterministic."
+    assert torch.allclose(grad, grad_rerun, atol=atol, rtol=0), "Gradients are non-deterministic."
 
 
 @pytest.mark.forked

+ 38 - 1
tests/test_util_modules.py

@@ -3,7 +3,8 @@ import torch
 
 import pytest
 import hivemind
-
+from hivemind.proto.dht_pb2_grpc import DHTStub
+from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from concurrent.futures import CancelledError
 
 
@@ -129,3 +130,39 @@ def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
     assert error.square().mean() < alpha
 
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_channel_cache():
+    hivemind.ChannelCache.MAXIMUM_CHANNELS = 3
+    hivemind.ChannelCache.EVICTION_PERIOD_SECONDS = 0.1
+
+    c1 = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
+    c2 = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
+    c3 = hivemind.ChannelCache.get_stub('localhost:1338', DHTStub, aio=False)
+    c3_again = hivemind.ChannelCache.get_stub('localhost:1338', DHTStub, aio=False)
+    c1_again = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
+    c4 = hivemind.ChannelCache.get_stub('localhost:1339', DHTStub, aio=True)
+    c2_anew = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
+    c1_yetagain = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
+
+    await asyncio.sleep(0.2)
+    c1_anew = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=DHTStub)
+    c1_anew_again = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=DHTStub)
+    c1_otherstub = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=ConnectionHandlerStub)
+    await asyncio.sleep(0.05)
+    c1_otherstub_again = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False,
+                                                        stub_type=ConnectionHandlerStub)
+    all_channels = [c1, c2, c3, c4, c3_again, c1_again, c2_anew, c1_yetagain, c1_anew, c1_anew_again, c1_otherstub]
+
+    assert all(isinstance(c, DHTStub) for c in all_channels[:-1])
+    assert isinstance(all_channels[-1], ConnectionHandlerStub)
+    assert 'aio' in repr(c2.rpc_find)
+    assert 'aio' not in repr(c1.rpc_find)
+
+    duplicates = {(c1, c1_again), (c1, c1_yetagain), (c1_again, c1_yetagain), (c3, c3_again),
+                  (c1_anew, c1_anew_again), (c1_otherstub, c1_otherstub_again)}
+    for i in range(len(all_channels)):
+        for j in range(i + 1, len(all_channels)):
+            ci, cj = all_channels[i], all_channels[j]
+            assert (ci is cj) == ((ci, cj) in duplicates), (i, j)