瀏覽代碼

use averaging priority

justheuristic 3 年之前
父節點
當前提交
d4b13e085c
共有 1 個文件被更改,包括 7 次插入、0 次刪除
  1. 7 0
      hivemind/optim/experimental/optimizer.py

+ 7 - 0
hivemind/optim/experimental/optimizer.py

@@ -338,6 +338,10 @@ class Optimizer(torch.optim.Optimizer):
                 self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
                 self._should_check_synchronization_on_update = True
 
+            if not self.client_mode:
+                self.grad_averager.state_sharing_priority = self.local_epoch
+                self.state_averager.state_sharing_priority = self.local_epoch
+
             logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}.")
         return loss
 
@@ -398,6 +402,9 @@ class Optimizer(torch.optim.Optimizer):
                 self.state_averager.local_epoch = self.tracker.global_epoch
 
             self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
+            if not self.client_mode:
+                self.grad_averager.state_sharing_priority = self.local_epoch
+                self.state_averager.state_sharing_priority = self.local_epoch
             self.grad_averager.reset_accumulated_grads_()
 
     def state_dict(self) -> dict: