@@ -85,6 +85,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
     :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
+    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
     :param scheduler: if specified, use this scheduler to update optimizer learning rate
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
@@ -114,6 +115,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         performance_ema_alpha: float = 0.1,
         metadata_expiration: float = 60.0,
         averaging_timeout: Optional[float] = None,
+        load_state_timeout: float = 600.0,
         step_tolerance: int = 1,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
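
For context, a minimal construction sketch using the new parameter. Everything except `averaging_timeout` and `load_state_timeout` (the `prefix`/batch-size arguments, the toy model, and the freshly started DHT node) is an assumption about the surrounding API rather than part of this diff:

```python
import torch
import hivemind

# Toy model, local optimizer and a fresh DHT node, purely for illustration (assumed setup).
model = torch.nn.Linear(16, 1)
base_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
dht = hivemind.DHT(start=True)

opt = hivemind.CollaborativeOptimizer(
    opt=base_opt,
    dht=dht,
    prefix="my_experiment",      # run name shared by all peers (assumed)
    target_batch_size=4096,      # collaboration-wide batch size (assumed)
    batch_size_per_step=32,      # this peer's contribution per step (assumed)
    averaging_timeout=300.0,     # cancel an averaging round that hangs this long
    load_state_timeout=600.0,    # new: give up on load_state_from_peers after 10 minutes
    verbose=True,
)
```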
@@ -137,7 +139,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             default_refresh_period,
         )
         self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
-        self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
+        self.averaging_timeout, self.load_state_timeout, self.metadata_expiration = \
+            averaging_timeout, load_state_timeout, metadata_expiration
         self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
         self.client_mode, self.step_tolerance = client_mode, step_tolerance
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
@@ -185,7 +188,13 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
         with self.lock_collaboration_state:
-            self.averager.load_state_from_peers(**kwargs)
+            while True:
+                try:
+                    self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                except BaseException as e:
+                    logger.exception(f"Failed to load state from peers: {e}, will retry now")
+                    continue
+
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.reset_accumulated_grads_()
             self.update_scheduler()
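
One caveat worth noting: as written, the retry loop never leaves the `while True:` block after a successful call (there is no `break`), so `load_state_from_peers` would keep retrying even once the state has been fetched. A sketch of the likely intended behavior, assuming the loop should exit on success and should not swallow a manual interrupt (this is a suggestion, not part of the diff):

```python
def load_state_from_peers(self, **kwargs):
    """Attempt to fetch the newest collaboration state from other peers"""
    with self.lock_collaboration_state:
        while True:
            try:
                self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                break  # state loaded successfully, stop retrying
            except KeyboardInterrupt:
                raise  # let the user abort instead of retrying forever
            except BaseException as e:
                logger.exception(f"Failed to load state from peers: {e}, will retry now")
                continue

        self.local_samples_accumulated = self.local_steps_accumulated = 0
        self.reset_accumulated_grads_()
        self.update_scheduler()
```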