
DHT caching and TransformerEncoderLayer fixes (#124)

* Remove transpose_ from TransformerEncoderLayer

* Fix cache refreshing
Max Ryabinin, 4 years ago
Parent commit 34585bafac

+ 0 - 1
hivemind/dht/__init__.py

@@ -26,7 +26,6 @@ import uvloop
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time, DHTValue
-from hivemind.utils.timed_storage import ValueWithExpiration
 from hivemind.utils import MPFuture, Endpoint, get_logger
 
 logger = get_logger(__name__)

+ 1 - 1
hivemind/dht/node.py

@@ -513,7 +513,7 @@ class DHTNode:
                 # step 2: find all keys that we should already refresh and remove them from queue
                 current_time = get_dht_time()
                 keys_to_refresh = {key_id}
-                max_expiration_time = self.protocol.cache.get(key_id)[1] or current_time
+                max_expiration_time = nearest_refresh_time
                 del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
                 while self.cache_refresh_queue and len(keys_to_refresh) < self.chunk_size:
                     key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()
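
Note on this hunk: the replaced line re-queried the protocol's cache and indexed the result, and as the rpc_find hunk further down shows, cache.get returns None once an entry is gone, so that lookup could raise instead of falling back to current_time. nearest_refresh_time is already unpacked from the refresh queue (see the loop right below the change), so it is used directly. A minimal sketch of the two variants, with `cache` standing in for self.protocol.cache:

    # Sketch only: assumes cache.get returns None for missing keys and a
    # (value, expiration_time) pair otherwise, as the protocol.py hunk below implies.
    def old_variant(cache, key_id, current_time):
        entry = cache.get(key_id)
        return entry[1] or current_time   # 'NoneType' object is not subscriptable if the entry was evicted

    def new_variant(nearest_refresh_time):
        return nearest_refresh_time       # already taken from the refresh queue; no second lookup needed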

+ 5 - 4
hivemind/dht/protocol.py

@@ -59,7 +59,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
             self.port = found_port
             await self.server.start()
         else:  # not listening to incoming requests, client-only mode
-            # note: use empty node_info so peers wont add you to their routing tables
+            # note: use empty node_info so peers won't add you to their routing tables
             self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
             if listen_on != '0.0.0.0:*' or len(kwargs) != 0:
                 logger.warning(f"DHTProtocol has no server (due to listen=False), listen_on"
@@ -186,8 +186,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 response.store_ok.append(storage.store_subkey(key_id, subkey, value_bytes, expiration_time))
         return response
 
-    async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> Optional[
-            Dict[DHTID, Tuple[Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]], Dict[DHTID, Endpoint]]]]:
+    async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> Optional[Dict[
+        DHTID, Tuple[Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]], Dict[DHTID, Endpoint]]]]:
         """
         Request keys from a peer. For each key, look for its (value, expiration time) locally and
          k additional peers that are most likely to have this key (ranked by XOR distance)
@@ -238,7 +238,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
         for i, key_id in enumerate(map(DHTID.from_bytes, request.keys)):
             maybe_item = self.storage.get(key_id)
             cached_item = self.cache.get(key_id)
-            if cached_item is not None and (maybe_item is None or cached_item.expiration_time > maybe_item.expiration_time):
+            if cached_item is not None and (maybe_item is None
+                                            or cached_item.expiration_time > maybe_item.expiration_time):
                 maybe_item = cached_item
 
             if maybe_item is None:  # value not found
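
Note on this hunk: the call_find change only re-wraps the return annotation for line length, and the rpc_find change re-wraps the existing rule that a cached entry shadows the stored one when its expiration_time is later. Reading the annotation, a caller receives either None (the RPC failed) or, for each requested key, an optional (value, expiration_time) entry plus a dict of nearest peers. A hedged sketch of consuming that result, assuming ValueWithExpiration unpacks like a two-element tuple (its indexing in the node.py hunk above suggests it does); `protocol`, `peer` and `keys` are placeholders:

    async def fetch(protocol, peer, keys):
        found = await protocol.call_find(peer, keys)
        if found is None:                 # peer unreachable or response malformed
            return {}
        results = {}
        for key_id, (maybe_value, nearest_peers) in found.items():
            if maybe_value is not None:
                value, expiration_time = maybe_value   # local hit at that peer
                results[key_id] = (value, expiration_time)
            else:                                      # miss: nearest_peers maps DHTID -> Endpoint to query next
                results[key_id] = nearest_peers
        return results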

+ 0 - 2
hivemind/server/layers/common.py

@@ -40,14 +40,12 @@ class TransformerEncoderLayer(nn.Module):
         self.activation = torch.nn.GELU()
 
     def forward(self, src):
-        src.transpose_(0, 1)
         src2 = self.self_attn(src, src, src)[0]
         src = src + self.dropout1(src2)
         src = self.norm1(src)
         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
         src = src + self.dropout2(src2)
         src = self.norm2(src)
-        src.transpose_(0, 1)
         return src
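
Note on this hunk: dropping the transpose_ calls means forward() no longer mutates the caller's tensor or flips its layout twice per pass; src is consumed in whatever layout the surrounding model provides (nn.MultiheadAttention's default is sequence-first). A short sketch of the in-place behaviour the removed lines had, using a plain tensor with illustrative sizes:

    import torch

    src = torch.randn(4, 16, 512)   # e.g. (batch, seq_len, hid_dim); sizes are assumptions
    src.transpose_(0, 1)            # what the removed lines did: in-place transpose
    print(src.shape)                # torch.Size([16, 4, 512]) -- the caller's tensor changed shape

With the transposes gone, the calling code is responsible for feeding src in the layout self_attn expects, and the original tensor keeps its shape and contents after the layer runs.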