
Process-wide channel cache for gRPC+aio (#120)

Changelist:
* Added a process-wide ChannelCache for gRPC channels and stubs
* Switched dht/protocol.py, client/allreduce.py, and client/expert.py to the new cache
* Moved TimedStorage to utils
* Fixed an edge case where a TimedStorage with a fixed maxsize sometimes evicted too many entries (size was measured as len(self.expiration_heap) when it should be len(self.data)); see the sketch below
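
A minimal repro of that edge case (key names and expiration offsets are illustrative; assumes TimedStorage as moved to hivemind/utils/timed_storage.py in this PR):

from hivemind.utils.timed_storage import TimedStorage, get_dht_time

storage: TimedStorage[str, bytes] = TimedStorage(maxsize=2)
now = get_dht_time()
storage.store('key_a', b'a', now + 5)
storage.store('key_b', b'b', now + 10)
storage.store('key_b', b'b2', now + 20)  # re-storing leaves a stale "ghost" heap entry for key_b

# the heap now holds 3 entries (one of them a ghost) for only 2 stored keys;
# with size measured as len(self.expiration_heap), the live key_a got evicted here,
# whereas len(self.data) == 2 == maxsize means nothing should be removed
assert len(storage) == 2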

Before this commit, each time we queried a remote DHT peer over gRPC, we created a new channel. In contrast, gRPC best practices recommend reusing channels across multiple RPC calls.

hivemind/client/expert.py already introduced channel caching (via functools.lru_cache), but it had the side effect of keeping channels open forever.
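
For context, the removed pattern looked roughly like this (simplified from the expert.py diff below):

from functools import lru_cache

import grpc
from hivemind.proto import runtime_pb2_grpc as runtime_grpc

@lru_cache(maxsize=None)  # cache entries are never evicted, so...
def _get_expert_stub(endpoint: str):
    # ...every endpoint ever contacted keeps its channel open for the process lifetime
    return runtime_grpc.ConnectionHandlerStub(grpc.insecure_channel(endpoint))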

In this PR we implement a process-wide ChannelCache object that keeps track of open channels. The code is largely inspired by https://github.com/grpc/grpc/blob/master/src/python/grpcio/grpc/_simple_stubs.py, but with added support for grpc.aio channels.
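
As a minimal usage sketch (the endpoint string is illustrative; pass aio=True from asyncio call sites to get a grpc.aio stub instead):

from hivemind.proto import dht_pb2_grpc
from hivemind.utils.grpc import ChannelCache

# before: every RPC created its own channel, e.g.
#   stub = dht_pb2_grpc.DHTStub(grpc.insecure_channel('localhost:1337'))

# after: the process-wide cache creates each channel once and reuses it; calls with
# the same (target, aio, options, credentials, compression) return the same stub
stub_a = ChannelCache.get_stub('localhost:1337', dht_pb2_grpc.DHTStub, aio=False)
stub_b = ChannelCache.get_stub('localhost:1337', dht_pb2_grpc.DHTStub, aio=False)
assert stub_a is stub_b  # evicted in the background ~EVICTION_PERIOD_SECONDS after last use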


Co-authored-by: Vsevolod-pl <vsevolod-pl@yandex.ru>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic, 4 years ago
Parent commit: 1754792aad

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.13'
+__version__ = '0.8.14'

+ 6 - 8
hivemind/client/allreduce.py

@@ -11,7 +11,7 @@ import torch
 
 from hivemind.dht import DHTID, DHTExpiration
 from hivemind.utils import Endpoint, get_logger, MSGPackSerializer
-from hivemind.utils import TensorDescriptor, deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils import TensorDescriptor, deserialize_torch_tensor, serialize_torch_tensor, ChannelCache
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
 
 logger = get_logger(__name__)
@@ -149,10 +149,8 @@ class GroupAllReduce:
 
         return await self.averaged_part
 
-    def _get(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
-        """ TODO this function is deprecated and will be replaced by a shared channel cache """
-        channel = grpc.aio.insecure_channel(peer)
-        return averaging_pb2_grpc.DecentralizedAveragingStub(channel)
+    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
+        return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
 
     async def handle_join_request(self, request: averaging_pb2.PeerInfo
                                   ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
@@ -220,7 +218,7 @@ class GroupAllReduce:
         assert self.state == ProtocolState.LOOKING_FOR_GROUP
         try:
             async with self.concurrent_requests_lock:
-                stream = self._get(leader).rpc_group_allreduce(self.info)
+                stream = self._get_peer_stub(leader).rpc_group_allreduce(self.info)
                 message = await stream.read()
                 logger.debug(f"{self} - requested {leader} to be my leader, received "
                              f"{averaging_pb2.MessageCode.Name(message.code)}")
@@ -259,7 +257,7 @@ class GroupAllReduce:
                 self.average_tensor_parts[peer_endpoint] = await self.accumulate(peer_endpoint, local_part)
             else:
                 serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
-                response = await self._get(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
+                response = await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
                     group_id=self.group_id, endpoint=self.info.endpoint, tensor_part=serialized_tensor_part))
 
                 if response.code == averaging_pb2.ACCEPTED:
@@ -279,7 +277,7 @@ class GroupAllReduce:
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
 
             async def send_error_to_peer(peer_endpoint):
-                await self._get(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
+                await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
                     group_id=self.group_id, endpoint=self.info.endpoint, code=code))
             for peer_endpoint in ordered_group_endpoints:
                 asyncio.create_task(send_error_to_peer(peer_endpoint))

+ 3 - 8
hivemind/client/expert.py

@@ -1,26 +1,21 @@
 import pickle
-from functools import lru_cache
 from typing import Tuple, Optional, Any, Dict
 
-import grpc
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
-from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor, ChannelCache
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
 
-@lru_cache(maxsize=None)
 def _get_expert_stub(endpoint: Endpoint, *extra_options: Tuple[str, Any]):
     """ Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache """
-    channel_options = [
-        ('grpc.max_send_message_length', -1), ('grpc.max_receive_message_length', -1)
-    ] + list(extra_options)
-    return runtime_grpc.ConnectionHandlerStub(grpc.insecure_channel(endpoint, options=channel_options))
+    channel_options = (('grpc.max_send_message_length', -1), ('grpc.max_receive_message_length', -1)) + extra_options
+    return ChannelCache.get_stub(endpoint, runtime_grpc.ConnectionHandlerStub, aio=False, options=channel_options)
 
 
 class RemoteExpert(nn.Module):

+ 1 - 1
hivemind/dht/__init__.py

@@ -26,7 +26,7 @@ import uvloop
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time, DHTValue
-from hivemind.dht.storage import ValueWithExpiration
+from hivemind.utils.timed_storage import ValueWithExpiration
 from hivemind.utils import MPFuture, Endpoint, get_logger
 
 logger = get_logger(__name__)

+ 2 - 2
hivemind/dht/node.py

@@ -11,9 +11,9 @@ from sortedcontainers import SortedList
 
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
-from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue, ValueWithExpiration
+from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
-from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
+from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase, ValueWithExpiration
 
 logger = get_logger(__name__)
 

+ 7 - 8
hivemind/dht/protocol.py

@@ -7,9 +7,9 @@ from typing import Optional, List, Tuple, Dict, Any, Sequence, Union, Collection
 import grpc
 
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
-from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue, ValueWithExpiration
+from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
-from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer
+from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer, ChannelCache, ValueWithExpiration
 
 logger = get_logger(__name__)
 
@@ -78,10 +78,9 @@ class DHTProtocol(dht_grpc.DHTServicer):
         else:
             logger.warning("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")
 
-    def _get(self, peer: Endpoint) -> dht_grpc.DHTStub:
+    def _get_dht_stub(self, peer: Endpoint) -> dht_grpc.DHTStub:
         """ get a DHTStub that sends requests to a given peer """
-        channel = grpc.aio.insecure_channel(peer, options=self.channel_options)
-        return dht_grpc.DHTStub(channel)
+        return ChannelCache.get_stub(peer, dht_grpc.DHTStub, aio=True, options=self.channel_options)
 
     async def call_ping(self, peer: Endpoint) -> Optional[DHTID]:
         """
@@ -93,7 +92,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         """
         try:
             async with self.rpc_semaphore:
-                peer_info = await self._get(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
+                peer_info = await self._get_dht_stub(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
         except grpc.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to ping {peer}: {error.code()}")
             peer_info = None
@@ -155,7 +154,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                                              expiration_time=expiration_time, in_cache=in_cache, peer=self.node_info)
         try:
             async with self.rpc_semaphore:
-                response = await self._get(peer).rpc_store(store_request, timeout=self.wait_timeout)
+                response = await self._get_dht_stub(peer).rpc_store(store_request, timeout=self.wait_timeout)
             if response.peer and response.peer.node_id:
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
@@ -203,7 +202,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes, keys)), peer=self.node_info)
         try:
             async with self.rpc_semaphore:
-                response = await self._get(peer).rpc_find(find_request, timeout=self.wait_timeout)
+                response = await self._get_dht_stub(peer).rpc_find(find_request, timeout=self.wait_timeout)
             if response.peer and response.peer.node_id:
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))

+ 2 - 5
hivemind/dht/routing.py

@@ -5,15 +5,12 @@ import hashlib
 import heapq
 import os
 import random
-import time
 from collections.abc import Iterable
 from itertools import chain
 from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
+from hivemind.utils import Endpoint, PickleSerializer, get_dht_time, DHTExpiration
 
-from hivemind.utils import Endpoint, PickleSerializer
-
-DHTKey, Subkey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, Any, Any, float, bytes, bytes
-get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time
+DHTKey, Subkey, DHTValue, BinaryDHTID, BinaryDHTValue, = Any, Any, Any, bytes, bytes
 
 
 class RoutingTable:

+ 3 - 105
hivemind/dht/storage.py

@@ -1,111 +1,9 @@
 from __future__ import annotations
-import heapq
-from contextlib import contextmanager
-from typing import Generic, Optional, Dict, Tuple, List, Iterator, TypeVar, Union, NamedTuple
+from typing import Optional, Union
 
-from hivemind.dht.routing import DHTID, DHTExpiration, get_dht_time, BinaryDHTValue, Subkey
+from hivemind.dht.routing import DHTID, DHTExpiration, BinaryDHTValue, Subkey
 from hivemind.utils.serializer import MSGPackSerializer
-
-KeyType = TypeVar('KeyType')
-ValueType = TypeVar('ValueType')
-ROOT = 0
-
-
-class ValueWithExpiration(NamedTuple, Generic[ValueType]):
-    value: ValueType
-    expiration_time: DHTExpiration
-
-
-class HeapEntry(NamedTuple, Generic[KeyType]):
-    expiration_time: DHTExpiration
-    key: KeyType
-
-
-class TimedStorage(Generic[KeyType, ValueType]):
-    """ A dictionary that maintains up to :maxsize: key-value-expiration tuples until their expiration_time """
-    frozen = False  # can be set to True. If true, do not remove outdated elements
-
-    def __init__(self, maxsize: Optional[int] = None):
-        self.maxsize = maxsize or float("inf")
-        self.data: Dict[KeyType, ValueWithExpiration[ValueType]] = dict()
-        self.expiration_heap: List[HeapEntry[KeyType]] = []
-        self.key_to_heap: Dict[KeyType, HeapEntry[KeyType]] = dict()
-
-    def _remove_outdated(self):
-        while not self.frozen and self.expiration_heap and (self.expiration_heap[ROOT].expiration_time < get_dht_time()
-                                                            or len(self.expiration_heap) > self.maxsize):
-            heap_entry = heapq.heappop(self.expiration_heap)
-            if self.key_to_heap.get(heap_entry.key) == heap_entry:
-                del self.data[heap_entry.key], self.key_to_heap[heap_entry.key]
-
-    def store(self, key: KeyType, value: ValueType, expiration_time: DHTExpiration) -> bool:
-        """
-        Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
-        :returns: True if new value was stored, False if it was rejected (current value is newer)
-        """
-        if expiration_time < get_dht_time() and not self.frozen:
-            return False
-        self.key_to_heap[key] = HeapEntry(expiration_time, key)
-        heapq.heappush(self.expiration_heap, self.key_to_heap[key])
-        if key in self.data:
-            if self.data[key].expiration_time < expiration_time:
-                self.data[key] = ValueWithExpiration(value, expiration_time)
-                return True
-            return False
-        self.data[key] = ValueWithExpiration(value, expiration_time)
-        self._remove_outdated()
-        return True
-
-    def get(self, key: KeyType) -> Optional[ValueWithExpiration[ValueType]]:
-        """ Get a value corresponding to a key if that (key, value) pair was previously stored under this key. """
-        self._remove_outdated()
-        if key in self.data:
-            return self.data[key]
-        return None
-
-    def items(self) -> Iterator[Tuple[KeyType, ValueWithExpiration[ValueType]]]:
-        """ Iterate over (key, value, expiration_time) tuples stored in this storage """
-        self._remove_outdated()
-        return ((key, value_and_expiration) for key, value_and_expiration in self.data.items())
-
-    def top(self) -> Tuple[Optional[KeyType], Optional[ValueWithExpiration[ValueType]]]:
-        """ Return the entry with earliest expiration or None if there isn't any """
-        self._remove_outdated()
-        if self.data:
-            # skip leftover "ghost" entries until first real entry
-            while self.key_to_heap.get(self.expiration_heap[ROOT].key) != self.expiration_heap[ROOT]:
-                heapq.heappop(self.expiration_heap)
-            top_key = self.expiration_heap[ROOT].key
-            return top_key, self.data[top_key]
-        return None, None
-
-    def __contains__(self, key: KeyType):
-        self._remove_outdated()
-        return key in self.data
-
-    def __len__(self):
-        self._remove_outdated()
-        return len(self.data)
-
-    def __delitem__(self, key: KeyType):
-        if key in self.key_to_heap:
-            del self.data[key], self.key_to_heap[key]
-        # note: key may still be in self.expiration_heap, but it will not be used and will eventually be removed by ._remove_outdated()
-
-    def __bool__(self):
-        return bool(self.data)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.data})"
-
-    @contextmanager
-    def freeze(self):
-        """ Temporarily cease to ._remove_outdated() elements inside this context to ensure consistency """
-        prev_frozen, self.frozen = self.frozen, True
-        try:
-            yield self
-        finally:
-            self.frozen = prev_frozen
+from hivemind.utils.timed_storage import KeyType, ValueType, TimedStorage
 
 
 @MSGPackSerializer.ext_serializable(0x50)

+ 1 - 0
hivemind/utils/__init__.py

@@ -5,4 +5,5 @@ from hivemind.utils.serializer import *
 from hivemind.utils.mpfuture import *
 from hivemind.utils.threading import *
 from hivemind.utils.grpc import *
+from hivemind.utils.timed_storage import *
 from hivemind.utils.logging import get_logger

+ 132 - 1
hivemind/utils/grpc.py

@@ -1,17 +1,148 @@
 """
 Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
 """
+from __future__ import annotations
+import os
+import threading
+from typing import NamedTuple, Sequence, Tuple, Optional, Union, Any, Dict, TypeVar, Type, NoReturn
 
+import grpc
 import numpy as np
 import torch
 
 from hivemind.proto import runtime_pb2
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.timed_storage import TimedStorage, get_dht_time, DHTExpiration, ValueWithExpiration
+from hivemind.utils.networking import Endpoint
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__file__)
+
+Stub = TypeVar("Stub")
+
+
+class ChannelInfo(NamedTuple):
+    target: Endpoint
+    aio: bool
+    options: Tuple[Tuple[str, str], ...]
+    credentials: Optional[grpc.ChannelCredentials]
+    compression: Optional[grpc.Compression]
+
+
+class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.Channel], Dict]]):
+    """
+    A process-wide cache of gRPC channels, supports both normal and aio channels, secure/insecure channels, etc
+    Based on grpcio internal channel cache by Richard Belleville and Lidi Zheng (thanks!)
+    Unlike TimedStorage, ChannelCache actively evicts stale channels even if the cache is not accessed
+    Unlike grpc._simple_stubs.ChannelCache, this implementation supports aio and does not forcibly close active channels
+    """
+    MAXIMUM_CHANNELS = os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM", 4096)
+    EVICTION_PERIOD_SECONDS = os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS", 10 * 60)
+    logger.debug(f"Eviction period = {EVICTION_PERIOD_SECONDS}s, max channels = {MAXIMUM_CHANNELS}")
+
+    _singleton: Optional[ChannelCache] = None
+    _singleton_pid: int = os.getpid()
+    _lock: threading.RLock = threading.RLock()
+    _update_eviction_evt: threading.Event = threading.Event()
+
+    def __init__(self, _created_as_singleton=False):
+        assert _created_as_singleton, f"Please use {self.__class__.__name__}.get_singleton()"
+        super().__init__(maxsize=self.MAXIMUM_CHANNELS)
+        self._is_active = True
+        self._nearest_expiration_time = float('inf')
+        self._eviction_thread = threading.Thread(target=self._evict_stale_channels_in_background, daemon=True)
+        self._eviction_thread.start()
+
+    @classmethod
+    def get_singleton(cls):
+        """ Get or create the channel cache for the current process """
+        with cls._lock:
+            if cls._singleton is None or cls._singleton_pid != os.getpid():
+                if cls._singleton is not None:
+                    cls._singleton._stop_background_thread()
+                cls._singleton, cls._singleton_pid = cls(_created_as_singleton=True), os.getpid()
+            return cls._singleton
+
+    @classmethod
+    def get_stub(cls, target: Endpoint, stub_type: Type[Stub], *, aio: bool, options: Sequence[Tuple[str, Any]] = (),
+                 channel_credentials: Optional[grpc.ChannelCredentials] = None,
+                 compression: Optional[grpc.Compression] = None) -> Stub:
+        """
+        Create a grpc channel with given options or reuse pre-existing one
+
+        :param target: the recipient's address and port
+        :param stub_type: a gRPC stub (client) to be instantiated
+        :param aio: if True, returns grpc.aio.Channel, otherwise returns grpc.Channel
+        :param options: see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
+        :param channel_credentials: if specified, create a secure channel using these credentials (default = insecure)
+        :param compression: see https://github.com/grpc/grpc/tree/master/examples/python/compression
+        """
+        cache = cls.get_singleton()
+        with cls._lock:
+            key = ChannelInfo(target, aio, tuple(options or ()), channel_credentials, compression)
+            entry: ValueWithExpiration = super(cls, cache).get(key)
+            channel, stubs = entry.value if entry is not None else (cls._create_channel(*key), {})
+            if stub_type not in stubs:
+                stubs[stub_type] = stub_type(channel)
+
+            # either cache channel or update expiration of an existing channel
+            expiration_time = get_dht_time() + cls.EVICTION_PERIOD_SECONDS
+            super(cls, cache).store(key, (channel, stubs), expiration_time)
+
+            if expiration_time < cache._nearest_expiration_time:
+                cache._nearest_expiration_time = expiration_time
+                cls._update_eviction_evt.set()
+
+            return stubs[stub_type]
+
+    @classmethod
+    def _create_channel(cls, target: Endpoint, aio: bool, options: Sequence[Tuple[str, Any]],
+                        channel_credentials: Optional[grpc.ChannelCredentials],
+                        compression: Optional[grpc.Compression]) -> Union[grpc.Channel, grpc.aio.Channel]:
+        namespace = grpc.aio if aio else grpc
+        if channel_credentials is None:
+            logger.debug(f"Creating insecure {namespace} channel with options '{options}' "
+                         f"and compression '{compression}'")
+            return namespace.insecure_channel(target, options=options, compression=compression)
+        else:
+            logger.debug(f"Creating secure {namespace} channel with credentials '{channel_credentials}', "
+                         f"options '{options}' and compression '{compression}'")
+            return namespace.secure_channel(target, credentials=channel_credentials,
+                                            options=options, compression=compression)
+
+    def _evict_stale_channels_in_background(self):
+        while self._is_active:
+            now = get_dht_time()
+            time_to_wait = max(0.0, self._nearest_expiration_time - now)
+            interrupted_early = self._update_eviction_evt.wait(time_to_wait if time_to_wait != float('inf') else None)
+            if interrupted_early:
+                self._update_eviction_evt.clear()
+                continue
+
+            with self._lock:
+                self._remove_outdated()
+                _, entry = super().top()
+                self._nearest_expiration_time = entry.expiration_time if entry is not None else float('inf')
+
+    def _stop_background_thread(self):
+        with self._lock:
+            self._is_active = False
+            self._update_eviction_evt.set()
+
+    def store(self, *args, **kwargs) -> NoReturn:
+        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
+
+    def get(self, *args, **kwargs) -> NoReturn:
+        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
+
+    def top(self) -> NoReturn:
+        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
+
 
 FP16_MAX = 65_504
 
 
-def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE, 
+def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE,
                            allow_inplace=False) -> runtime_pb2.Tensor:
     if compression_type == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
         assert tensor.dtype == torch.float32

+ 109 - 0
hivemind/utils/timed_storage.py

@@ -0,0 +1,109 @@
+""" A dictionary-like storage that stores items until a specified expiration time or up to a limited size """
+from __future__ import annotations
+import heapq
+import time
+from contextlib import contextmanager
+from typing import TypeVar, NamedTuple, Generic, Optional, Dict, List, Iterator, Tuple
+
+KeyType = TypeVar('KeyType')
+ValueType = TypeVar('ValueType')
+get_dht_time = time.time  # a global (weakly synchronized) time
+DHTExpiration = float
+ROOT = 0
+
+
+class ValueWithExpiration(NamedTuple, Generic[ValueType]):
+    value: ValueType
+    expiration_time: DHTExpiration
+
+
+class HeapEntry(NamedTuple, Generic[KeyType]):
+    expiration_time: DHTExpiration
+    key: KeyType
+
+
+class TimedStorage(Generic[KeyType, ValueType]):
+    """ A dictionary that maintains up to :maxsize: key-value-expiration tuples until their expiration_time """
+    frozen = False  # can be set to True. If true, do not remove outdated elements
+
+    def __init__(self, maxsize: Optional[int] = None):
+        self.maxsize = maxsize or float("inf")
+        self.data: Dict[KeyType, ValueWithExpiration[ValueType]] = dict()
+        self.expiration_heap: List[HeapEntry[KeyType]] = []
+        self.key_to_heap: Dict[KeyType, HeapEntry[KeyType]] = dict()
+
+    def _remove_outdated(self):
+        while not self.frozen and self.expiration_heap and (self.expiration_heap[ROOT].expiration_time < get_dht_time()
+                                                            or len(self.data) > self.maxsize):
+            heap_entry = heapq.heappop(self.expiration_heap)
+            if self.key_to_heap.get(heap_entry.key) == heap_entry:
+                del self.data[heap_entry.key], self.key_to_heap[heap_entry.key]
+
+    def store(self, key: KeyType, value: ValueType, expiration_time: DHTExpiration) -> bool:
+        """
+        Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
+        :returns: True if new value was stored, False if it was rejected (current value is newer)
+        """
+        if expiration_time < get_dht_time() and not self.frozen:
+            return False
+        self.key_to_heap[key] = HeapEntry(expiration_time, key)
+        heapq.heappush(self.expiration_heap, self.key_to_heap[key])
+        if key in self.data:
+            if self.data[key].expiration_time < expiration_time:
+                self.data[key] = ValueWithExpiration(value, expiration_time)
+                return True
+            return False
+        self.data[key] = ValueWithExpiration(value, expiration_time)
+        self._remove_outdated()
+        return True
+
+    def get(self, key: KeyType) -> Optional[ValueWithExpiration[ValueType]]:
+        """ Get a value corresponding to a key if that (key, value) pair was previously stored under this key. """
+        self._remove_outdated()
+        if key in self.data:
+            return self.data[key]
+        return None
+
+    def items(self) -> Iterator[Tuple[KeyType, ValueWithExpiration[ValueType]]]:
+        """ Iterate over (key, value, expiration_time) tuples stored in this storage """
+        self._remove_outdated()
+        return ((key, value_and_expiration) for key, value_and_expiration in self.data.items())
+
+    def top(self) -> Tuple[Optional[KeyType], Optional[ValueWithExpiration[ValueType]]]:
+        """ Return the entry with earliest expiration or None if there isn't any """
+        self._remove_outdated()
+        if self.data:
+            # skip leftover "ghost" entries until first real entry
+            while self.key_to_heap.get(self.expiration_heap[ROOT].key) != self.expiration_heap[ROOT]:
+                heapq.heappop(self.expiration_heap)
+            top_key = self.expiration_heap[ROOT].key
+            return top_key, self.data[top_key]
+        return None, None
+
+    def __contains__(self, key: KeyType):
+        self._remove_outdated()
+        return key in self.data
+
+    def __len__(self):
+        self._remove_outdated()
+        return len(self.data)
+
+    def __delitem__(self, key: KeyType):
+        if key in self.key_to_heap:
+            del self.data[key], self.key_to_heap[key]
+        # note: key may still be in self.expiration_heap, but it will not be used and will eventually be removed by ._remove_outdated()
+
+    def __bool__(self):
+        return bool(self.data)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.data})"
+
+    @contextmanager
+    def freeze(self):
+        """ Temporarily cease to ._remove_outdated() elements inside this context to ensure consistency """
+        prev_frozen, self.frozen = self.frozen, True
+        try:
+            yield self
+        finally:
+            self.frozen = prev_frozen

+ 1 - 0
requirements.txt

@@ -7,4 +7,5 @@ sortedcontainers
 uvloop>=0.14.0
 grpcio>=1.33.2
 grpcio-tools>=1.33.2
+protobuf>=3.12.2
 configargparse>=1.2.3

+ 2 - 1
tests/test_averaging.py

@@ -91,7 +91,8 @@ async def test_allreduce_protocol():
     ]
 
     assert len(averaged_tensors) == len(reference_tensors)
-    assert all(map(torch.allclose, averaged_tensors, reference_tensors))
+    assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
+               for our, ref in zip(averaged_tensors, reference_tensors))
 
 
 @pytest.mark.forked

+ 6 - 3
tests/test_dht_storage.py

@@ -37,11 +37,14 @@ def test_change_expiration_time():
 
 
 def test_maxsize_cache():
-    d = DHTLocalStorage(maxsize=1)
-    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
+    d = DHTLocalStorage(maxsize=2)
+    d.store(DHTID.generate("key1a"), b"val1a", get_dht_time() + 1)
+    d.store(DHTID.generate("key1b"), b"val1b", get_dht_time() + 1)
+    d.store(DHTID.generate("key1a"), b"val1a2", get_dht_time() + 2)
     d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
     assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
-    assert d.get(DHTID.generate("key1")) is None, "Value with less exp time, must be deleted"
+    assert d.get(DHTID.generate("key1a"))[0] == b"val1a2", "Value with bigger exp. time must be kept"
+    assert d.get(DHTID.generate("key1b")) is None, "Value with less exp time, must be deleted"
 
 
 def test_localstorage_top():

+ 4 - 6
tests/test_moe.py

@@ -30,7 +30,6 @@ def test_call_many():
     backward_k_min = 1
     forward_timeout = None
     backward_timeout = None
-    rtol = 1e-3
     atol = 1e-5
 
     with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=8, hidden_dim=64,
@@ -61,7 +60,7 @@ def test_call_many():
         reference_outputs[2, 0] = e1(inputs_clone[2:3])
         reference_outputs[2, 2] = e3(inputs_clone[2:3])
 
-        assert torch.allclose(expert_outputs, reference_outputs, rtol, atol)
+        assert torch.allclose(expert_outputs, reference_outputs, atol=atol, rtol=0)
         proj = torch.randn(4, 64)
         loss = (expert_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
         loss.backward()
@@ -70,7 +69,7 @@ def test_call_many():
         reference_loss = (reference_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
         reference_loss.backward()
         reference_grad = inputs_clone.grad.data.cpu().clone()
-        assert torch.allclose(our_grad, reference_grad, rtol, atol)
+        assert torch.allclose(our_grad, reference_grad, atol=atol, rtol=0)
 
 
 @pytest.mark.forked
@@ -125,7 +124,6 @@ def test_beam_search_correctness():
 
 @pytest.mark.forked
 def test_determinism():
-    rtol = 0
     atol = 1e-5
 
     xx = torch.randn(32, 1024, requires_grad=True)
@@ -141,8 +139,8 @@ def test_determinism():
         grad, = torch.autograd.grad(out.sum(), xx, retain_graph=True)
         grad_rerun, = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)
 
-    assert torch.allclose(out, out_rerun, rtol, atol), "Dropout layer outputs are non-deterministic."
-    assert torch.allclose(grad, grad_rerun, rtol, atol), "Gradients are non-deterministic."
+    assert torch.allclose(out, out_rerun, atol=atol, rtol=0), "Dropout layer outputs are non-deterministic."
+    assert torch.allclose(grad, grad_rerun, atol=atol, rtol=0), "Gradients are non-deterministic."
 
 
 @pytest.mark.forked

+ 38 - 1
tests/test_util_modules.py

@@ -3,7 +3,8 @@ import torch
 
 import pytest
 import hivemind
-
+from hivemind.proto.dht_pb2_grpc import DHTStub
+from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from concurrent.futures import CancelledError
 
 
@@ -129,3 +130,39 @@ def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
     assert error.square().mean() < alpha
 
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_channel_cache():
+    hivemind.ChannelCache.MAXIMUM_CHANNELS = 3
+    hivemind.ChannelCache.EVICTION_PERIOD_SECONDS = 0.1
+
+    c1 = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
+    c2 = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
+    c3 = hivemind.ChannelCache.get_stub('localhost:1338', DHTStub, aio=False)
+    c3_again = hivemind.ChannelCache.get_stub('localhost:1338', DHTStub, aio=False)
+    c1_again = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
+    c4 = hivemind.ChannelCache.get_stub('localhost:1339', DHTStub, aio=True)
+    c2_anew = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
+    c1_yetagain = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
+
+    await asyncio.sleep(0.2)
+    c1_anew = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=DHTStub)
+    c1_anew_again = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=DHTStub)
+    c1_otherstub = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=ConnectionHandlerStub)
+    await asyncio.sleep(0.05)
+    c1_otherstub_again = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False,
+                                                        stub_type=ConnectionHandlerStub)
+    all_channels = [c1, c2, c3, c4, c3_again, c1_again, c2_anew, c1_yetagain, c1_anew, c1_anew_again, c1_otherstub]
+
+    assert all(isinstance(c, DHTStub) for c in all_channels[:-1])
+    assert isinstance(all_channels[-1], ConnectionHandlerStub)
+    assert 'aio' in repr(c2.rpc_find)
+    assert 'aio' not in repr(c1.rpc_find)
+
+    duplicates = {(c1, c1_again), (c1, c1_yetagain), (c1_again, c1_yetagain), (c3, c3_again),
+                  (c1_anew, c1_anew_again), (c1_otherstub, c1_otherstub_again)}
+    for i in range(len(all_channels)):
+        for j in range(i + 1, len(all_channels)):
+            ci, cj = all_channels[i], all_channels[j]
+            assert (ci is cj) == ((ci, cj) in duplicates), (i, j)