|
@@ -61,8 +61,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
:param bandwidth: if specified, this value represents the network bandwidth available to averager.
|
|
|
By default, the averager is assumed to have the average bandwidth of his group.
|
|
|
If bandwidth == 0, averager will rely on its groupmates to do all the averaging.
|
|
|
- :param client_mode: if False (default), this averager will accept incoming requests from other peers
|
|
|
- if True, the averager will only join existing groups where at least one peer has client_mode=False
|
|
|
+ :param client_mode: if False, this averager will accept incoming requests from other peers.
|
|
|
+ if True, the averager will only join existing groups where at least one peer has client_mode=False.
|
|
|
+ By default, this flag is copied from DHTNode inside the ``dht`` instance.
|
|
|
:param auxiliary: if this flag is specified, averager.step will only assist others without sending
|
|
|
local tensors for averaging
|
|
|
:param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
|
|
@@ -106,7 +107,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
min_vector_size: int = 0,
|
|
|
auxiliary: bool = False,
|
|
|
allow_state_sharing: Optional[bool] = None,
|
|
|
- client_mode: bool = False,
|
|
|
+ client_mode: Optional[bool] = None,
|
|
|
daemon: bool = True,
|
|
|
shutdown_timeout: float = 5,
|
|
|
):
|
|
@@ -121,7 +122,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
|
|
|
super().__init__()
|
|
|
self.dht = dht
|
|
|
+
|
|
|
+ if client_mode is None:
|
|
|
+ client_mode = dht.client_mode
|
|
|
self.client_mode = client_mode
|
|
|
+
|
|
|
self._parent_pid = os.getpid()
|
|
|
if self.client_mode:
|
|
|
self.mode = AveragingMode.CLIENT
|