
switch from explicit __del__ to shutdown,
switch from warnings to logger.warning

justheuristic, 4 years ago
commit 43be165803
3 files changed, 24 additions and 18 deletions
  1. hivemind/dht/node.py (+3 -5)
  2. hivemind/dht/protocol.py (+3 -5)
  3. hivemind/p2p/p2p_daemon.py (+18 -8)
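
Both changes named in the commit message follow the same idea: resource cleanup becomes an explicit, awaitable shutdown() instead of an implicit __del__ finalizer, and diagnostics go through a module-level logger instead of warnings.warn. A minimal standalone sketch of the target pattern (using the standard logging module in place of hivemind's get_logger; the Service class is illustrative, not hivemind code):

```python
import asyncio
import logging

logger = logging.getLogger(__name__)  # stands in for hivemind.utils.logging.get_logger


class Service:
    """Illustrative resource owner exposing explicit async cleanup."""

    def __init__(self):
        self.is_alive = True

    async def shutdown(self, timeout=None):
        # explicit, awaitable cleanup instead of relying on __del__, which cannot
        # await coroutines and may only run (if at all) during garbage collection
        if not self.is_alive:
            logger.warning("Service is already shut down")  # logger.warning instead of warnings.warn
            return
        self.is_alive = False
        await asyncio.sleep(0)  # stand-in for closing connections / stopping servers


asyncio.run(Service().shutdown())
```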

hivemind/dht/node.py (+3 -5)

@@ -65,11 +65,12 @@ class DHTNode:
 
     """
     # fmt:off
-    node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
+    #TODO remove port
+    node_id: DHTID; is_alive: bool; endpoint: Endpoint; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
     chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
     cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedSet[_SearchState]]
     cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
-    blacklist: Blacklist
+    blacklist: Blacklist;
     # fmt:on
 
     @classmethod
@@ -182,9 +183,6 @@ class DHTNode:
         assert _initialized_with_create, " Please use DHTNode.create coroutine to spawn new node instances "
         super().__init__()
 
-    def __del__(self):
-        self.protocol.__del__()
-
     async def shutdown(self, timeout=None):
         """ Process existing requests, close all connections and stop the server """
         self.is_alive = False
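
With DHTNode.__del__ removed, teardown is no longer tied to garbage collection and must be driven by whoever owns the node. A hedged usage sketch (the DHTNode.create arguments are elided here; real calls typically pass initial peers and other options):

```python
import asyncio

from hivemind.dht.node import DHTNode


async def run_node():
    node = await DHTNode.create()  # illustrative: arguments elided
    try:
        pass  # use the node: store/get values, traverse the DHT, etc.
    finally:
        await node.shutdown()  # explicit cleanup replaces the removed __del__


asyncio.run(run_node())
```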

hivemind/dht/protocol.py (+3 -5)

@@ -22,7 +22,7 @@ logger = get_logger(__name__)
 class DHTProtocol(dht_grpc.DHTServicer):
     # fmt:off
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
-    channel_options: Tuple[Tuple[str, Any]]; server: P2P
+    channel_options: Tuple[Tuple[str, Any]]; client: P2P; server: Optional[P2P]
     storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
     record_validator: Optional[RecordValidatorBase]
     # fmt:on
@@ -89,17 +89,15 @@ class DHTProtocol(dht_grpc.DHTServicer):
         assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
         super().__init__()
 
-    def __del__(self):
-        self.client.__del__()
-
     async def shutdown(self, timeout=None):
         """ Process existing requests, close all connections and stop the server """
         if self.server:
             await self.server.stop_listening()
+            await self.client.shutdown(timeout)
         else:
             logger.warning("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")
 
-    class DHTStub:
+    class DHTStub: #TODO refactor this
         def __init__(self, protocol: DHTProtocol, peer: Endpoint):
             self.protocol = protocol
             self.peer = peer
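
The removed __del__ forwarded to self.client.__del__(), but a finalizer cannot await the client's asynchronous cleanup; that is why the client teardown now happens inside the shutdown coroutine, after the server stops listening. A minimal sketch of the ownership chain (class names are illustrative, not the hivemind API):

```python
import asyncio


class Client:
    async def shutdown(self, timeout=None):
        await asyncio.sleep(0)  # stand-in for asynchronously tearing down the p2p daemon


class Protocol:
    """Owns a Client; cleanup is a coroutine the caller must await."""

    def __init__(self, client):
        self.client = client

    # A __del__ here could not `await self.client.shutdown()`: finalizers are
    # synchronous, so the owner exposes its own shutdown coroutine instead.
    async def shutdown(self, timeout=None):
        await self.client.shutdown(timeout)


asyncio.run(Protocol(Client()).shutdown())
```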

hivemind/p2p/p2p_daemon.py (+18 -8)

@@ -5,7 +5,6 @@ import pickle
 import signal
 import subprocess
 import typing as tp
-import warnings
 
 import google.protobuf
 from multiaddr import Multiaddr
@@ -14,6 +13,10 @@ from hivemind.p2p.p2p_daemon_bindings.datastructures import ID, StreamInfo
 from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure
 
 from hivemind.utils.networking import find_open_port
+from hivemind.utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 class P2PContext(object):
@@ -68,8 +71,8 @@ class P2P(object):
             try:
                 self._initialize(proc_args)
                 await self._identify_client(P2P.RETRY_DELAY * (2 ** try_count))
-            except Exception as exc:
-                warnings.warn("Failed to initialize p2p daemon: " + str(exc), RuntimeWarning)
+            except Exception as e:
+                logger.debug(f"Failed to initialize p2p daemon: {e}")
                 self._kill_child()
                 if try_count == P2P.NUM_RETRIES - 1:
                     raise
@@ -173,7 +176,8 @@ class P2P(object):
             try:
                 request = await P2P.receive_data(reader)
             except P2P.IncompleteRead:
-                warnings.warn("Incomplete read while receiving request from peer", RuntimeWarning)
+                if self.is_alive:
+                    logger.warning("Incomplete read while receiving request from peer")
                 writer.close()
                 return
             try:
@@ -200,11 +204,10 @@ class P2P(object):
                 try:
                     request = await P2P.receive_protobuf(in_proto_type, reader)
                 except P2P.IncompleteRead:
-                    warnings.warn("Incomplete read while receiving request from peer",
-                                  RuntimeWarning)
+                    logger.warning("Incomplete read while receiving request from peer")
                     return
                 except google.protobuf.message.DecodeError as error:
-                    warnings.warn(repr(error), RuntimeWarning)
+                    logger.warning(repr(error))
                     return
 
                 context.peer_id, context.peer_addr = stream_info.peer_id, stream_info.addr
@@ -293,6 +296,13 @@ class P2P(object):
     def __del__(self):
         self._kill_child()
 
+    @property
+    def is_alive(self):
+        return self._child is not None and self._child.poll() is None
+
+    async def shutdown(self, timeout=None):
+        await asyncio.get_event_loop().run_in_executor(None, self._kill_child)
+
     def _kill_child(self):
         if self._child is not None and self._child.poll() is None:
             self._child.kill()
@@ -303,7 +313,7 @@ class P2P(object):
             self._kill_child()
 
         signal.signal(signal.SIGTERM, _handler)
-        await asyncio.Event().wait()
+        await asyncio.Event().wait() #TODO justify this
 
     def _make_process_args(self, *args, **kwargs) -> tp.List[str]:
         proc_args = []
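
The new P2P.shutdown wraps the blocking _kill_child in run_in_executor so that killing the daemon process does not stall the event loop. A self-contained sketch of the same pattern (the DaemonHandle class and its constructor are illustrative, not hivemind's API; as in the diff, the timeout argument is accepted but not yet enforced):

```python
import asyncio
import subprocess


class DaemonHandle:
    """Illustrative wrapper around a daemon subprocess."""

    def __init__(self, args):
        self._child = subprocess.Popen(args)

    def _kill_child(self):
        # blocking: terminate the child process if it is still running
        if self._child is not None and self._child.poll() is None:
            self._child.kill()
            self._child.wait()

    async def shutdown(self, timeout=None):
        # offload the blocking kill to the default thread-pool executor,
        # mirroring the run_in_executor call added in this commit
        await asyncio.get_event_loop().run_in_executor(None, self._kill_child)
```

Callers that need a hard bound on teardown time could additionally wrap the await in asyncio.wait_for(handle.shutdown(), timeout).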