@@ -232,9 +232,43 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.collaboration_state_updated.set()
             self.update_scheduler()
 
-            logger.log(self.status_loglevel, f"Optimizer step: done!")
+        logger.log(self.status_loglevel, f"Optimizer step: done!")
 
-            return group_info
+        return group_info
+
+    def step_aux(self, **kwargs):
+        """
+        Find and assist other peers in averaging without sending local gradients.
+
+        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
+        """
+
+        if not self.collaboration_state.ready_for_step:
+            return
+
+        logger.log(self.status_loglevel,
+                   f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self.fetch_collaboration_state()
+        self.collaboration_state_updated.set()
+
+        with self.lock_collaboration_state:
+            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
+            current_step, group_info = self.averager.local_step, None
+            try:
+                group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
+                if group_info:
+                    logger.log(self.status_loglevel,
+                               f"Averaged tensors successfully with {len(group_info)} peers")
+            except BaseException as e:
+                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
+
+            self.collaboration_state.register_step(current_step + 1)
+            self.averager.local_step = current_step + 1
+            self.collaboration_state_updated.set()
+
+        logger.log(self.status_loglevel, f"Optimizer step: done!")
+
+        return group_info
 
     def _grad_buffers(self) -> Iterator[torch.Tensor]:
         """ pytorch-internal gradient buffers """
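For context, a minimal usage sketch of the method this hunk adds: an auxiliary peer that holds no local gradients can poll `step_aux()` in a loop and assist whichever averaging rounds are ready. The `run_aux_peer` helper, the `opt` instance, and the `poll_interval` argument below are illustrative assumptions, not part of this change; only the `step_aux()` call itself comes from the diff above.

```python
import time

def run_aux_peer(opt, poll_interval: float = 1.0):
    """Hypothetical helper: assist averaging rounds from a peer with no local gradients.

    `opt` is assumed to be an already-constructed CollaborativeOptimizer configured
    for an auxiliary peer; construction arguments are deliberately omitted here.
    """
    while True:
        # step_aux() returns the averaging group info on success and None when the
        # collaboration is not yet ready for a step (or the averaging round failed).
        group_info = opt.step_aux()
        if group_info is None:
            time.sleep(poll_interval)  # back off briefly before offering to help again
```

Because `step_aux` returns early when `collaboration_state.ready_for_step` is false, a polling loop like this stays cheap between global steps.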