|
@@ -83,7 +83,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
|
|
batch_size_per_step: Optional[int] = None, scheduler: Optional[LRSchedulerBase] = None,
|
|
batch_size_per_step: Optional[int] = None, scheduler: Optional[LRSchedulerBase] = None,
|
|
min_refresh_period: float = 0.5, max_refresh_period: float = 30, default_refresh_period: float = 3,
|
|
min_refresh_period: float = 0.5, max_refresh_period: float = 30, default_refresh_period: float = 3,
|
|
expected_drift_peers: float = 3, expected_drift_rate: float = 0.2, performance_ema_alpha: float = 0.1,
|
|
expected_drift_peers: float = 3, expected_drift_rate: float = 0.2, performance_ema_alpha: float = 0.1,
|
|
- metadata_expiration: float = 30.0, averaging_timeout: Optional[float] = None, step_tolerance: int = 1,
|
|
|
|
|
|
+ metadata_expiration: float = 60.0, averaging_timeout: Optional[float] = None, step_tolerance: int = 1,
|
|
reuse_grad_buffers: bool = False, accumulate_grads_on: Optional[torch.device] = None,
|
|
reuse_grad_buffers: bool = False, accumulate_grads_on: Optional[torch.device] = None,
|
|
client_mode: bool = False, verbose: bool = False, **kwargs):
|
|
client_mode: bool = False, verbose: bool = False, **kwargs):
|
|
super().__init__(opt, dht)
|
|
super().__init__(opt, dht)
|
|
@@ -193,8 +193,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
|
|
group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
|
|
group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
|
|
if group_info:
|
|
if group_info:
|
|
logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
|
|
logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
|
|
- except Exception as e:
|
|
|
|
- logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {e}.")
|
|
|
|
|
|
+ except BaseException as e:
|
|
|
|
+ logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
|
|
|
|
|
|
else:
|
|
else:
|
|
logger.log(self.status_loglevel, f"Skipped averaging: collaboration consists of "
|
|
logger.log(self.status_loglevel, f"Skipped averaging: collaboration consists of "
|