소스 검색

Copy dht.client_mode to averager.client_mode unless it is explicitly defined

Aleksandr Borzunov 4 년 전
부모
커밋
06fdf1fcda
2개의 변경된 파일19개의 추가작업 그리고 3개의 파일을 삭제
  1. 8 3
      hivemind/averaging/averager.py
  2. 11 0
      hivemind/dht/__init__.py

+ 8 - 3
hivemind/averaging/averager.py

@@ -61,8 +61,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param bandwidth: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
           If bandwidth == 0, averager will rely on its groupmates to do all the averaging.
-    :param client_mode: if False (default), this averager will accept incoming requests from other peers
-            if True, the averager will only join existing groups where at least one peer has client_mode=False
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+          if True, the averager will only join existing groups where at least one peer has client_mode=False.
+          By default, this flag is copied from DHTNode inside the ``dht`` instance.
     :param auxiliary: if this flag is specified, averager.step will only assist others without sending
           local tensors for averaging
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
@@ -106,7 +107,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         min_vector_size: int = 0,
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
-        client_mode: bool = False,
+        client_mode: Optional[bool] = None,
         daemon: bool = True,
         shutdown_timeout: float = 5,
     ):
@@ -121,7 +122,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
         super().__init__()
         self.dht = dht
+
+        if client_mode is None:
+            client_mode = dht.client_mode
         self.client_mode = client_mode
+
         self._parent_pid = os.getpid()
         if self.client_mode:
             self.mode = AveragingMode.CLIENT

+ 11 - 0
hivemind/dht/__init__.py

@@ -89,7 +89,9 @@ class DHT(mp.Process):
         self.ready = mp.Event()
         self.daemon = daemon
 
+        # These values will be fetched from the child process when requested
         self._peer_id = None
+        self._client_mode = None
         self._p2p_replica = None
 
         if start:
@@ -270,6 +272,15 @@ class DHT(mp.Process):
     async def _get_peer_id(self, node: DHTNode) -> PeerID:
         return node.peer_id
 
+    @property
+    def client_mode(self) -> bool:
+        if self._client_mode is None:
+            self._client_mode = self.run_coroutine(DHT._get_client_mode)
+        return self._client_mode
+
+    async def _get_client_mode(self, node: DHTNode) -> bool:
+        return node.protocol.client_mode
+
     def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
         """
         Get multiaddrs of the current DHT node that should be accessible by other peers.