浏览代码

Copy dht.client_mode to averager.client_mode unless it is explicitly defined

Aleksandr Borzunov 4 年之前
父节点
当前提交
06fdf1fcda
共有 2 个文件被更改,包括 19 次插入3 次删除
  1. 8 3
      hivemind/averaging/averager.py
  2. 11 0
      hivemind/dht/__init__.py

+ 8 - 3
hivemind/averaging/averager.py

@@ -61,8 +61,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param bandwidth: if specified, this value represents the network bandwidth available to averager.
     :param bandwidth: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
           By default, the averager is assumed to have the average bandwidth of his group.
           If bandwidth == 0, averager will rely on its groupmates to do all the averaging.
           If bandwidth == 0, averager will rely on its groupmates to do all the averaging.
-    :param client_mode: if False (default), this averager will accept incoming requests from other peers
-            if True, the averager will only join existing groups where at least one peer has client_mode=False
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+          if True, the averager will only join existing groups where at least one peer has client_mode=False.
+          By default, this flag is copied from DHTNode inside the ``dht`` instance.
     :param auxiliary: if this flag is specified, averager.step will only assist others without sending
     :param auxiliary: if this flag is specified, averager.step will only assist others without sending
           local tensors for averaging
           local tensors for averaging
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
@@ -106,7 +107,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         min_vector_size: int = 0,
         min_vector_size: int = 0,
         auxiliary: bool = False,
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
         allow_state_sharing: Optional[bool] = None,
-        client_mode: bool = False,
+        client_mode: Optional[bool] = None,
         daemon: bool = True,
         daemon: bool = True,
         shutdown_timeout: float = 5,
         shutdown_timeout: float = 5,
     ):
     ):
@@ -121,7 +122,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
 
         super().__init__()
         super().__init__()
         self.dht = dht
         self.dht = dht
+
+        if client_mode is None:
+            client_mode = dht.client_mode
         self.client_mode = client_mode
         self.client_mode = client_mode
+
         self._parent_pid = os.getpid()
         self._parent_pid = os.getpid()
         if self.client_mode:
         if self.client_mode:
             self.mode = AveragingMode.CLIENT
             self.mode = AveragingMode.CLIENT

+ 11 - 0
hivemind/dht/__init__.py

@@ -89,7 +89,9 @@ class DHT(mp.Process):
         self.ready = mp.Event()
         self.ready = mp.Event()
         self.daemon = daemon
         self.daemon = daemon
 
 
+        # These values will be fetched from the child process when requested
         self._peer_id = None
         self._peer_id = None
+        self._client_mode = None
         self._p2p_replica = None
         self._p2p_replica = None
 
 
         if start:
         if start:
@@ -270,6 +272,15 @@ class DHT(mp.Process):
     async def _get_peer_id(self, node: DHTNode) -> PeerID:
     async def _get_peer_id(self, node: DHTNode) -> PeerID:
         return node.peer_id
         return node.peer_id
 
 
+    @property
+    def client_mode(self) -> bool:
+        if self._client_mode is None:
+            self._client_mode = self.run_coroutine(DHT._get_client_mode)
+        return self._client_mode
+
+    async def _get_client_mode(self, node: DHTNode) -> bool:
+        return node.protocol.client_mode
+
     def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
     def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
         """
         """
         Get multiaddrs of the current DHT node that should be accessible by other peers.
         Get multiaddrs of the current DHT node that should be accessible by other peers.