소스 검색

better shutdown script

justheuristic 3 년 전
부모
커밋
6377ea7eec
1개의 변경된 파일23개의 추가작업 그리고 10개의 파일을 삭제
  1. 23 10
      hivemind/optim/experimental/optimizer.py

+ 23 - 10
hivemind/optim/experimental/optimizer.py

@@ -582,8 +582,7 @@ class Optimizer(torch.optim.Optimizer):
 
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
-        if self.scheduled_grads is not None and not self.scheduled_grads.done():
-            self.scheduled_grads.cancel()
+        self._finish_scheduled_averaging()
 
         with self.tracker.pause_updates():
             while True:
@@ -605,18 +604,32 @@ class Optimizer(torch.optim.Optimizer):
             if not self.client_mode:
                 self.state_averager.state_sharing_priority = self.local_epoch
 
-            self._cancel_scheduled_averaging()
-
             if self.use_gradient_averaging:
                 self.grad_averager.reset_accumulated_grads_()
                 if not self.client_mode:
                     self.grad_averager.state_sharing_priority = self.local_epoch
 
-    def _cancel_scheduled_averaging(self):
-        if self.scheduled_grads is not None and not self.scheduled_grads.done():
-            self.scheduled_grads.cancel()
-        if self.scheduled_state is not None and not self.scheduled_state.done():
-            self.scheduled_state.cancel()
+    def _finish_scheduled_averaging(self):
+        if self.scheduled_grads is not None:
+            self.scheduled_grads.weight = 0
+            self.scheduled_grads.allow_allreduce()
+        if self.scheduled_state is not None:
+            self.scheduled_state.weight = 0
+            self.scheduled_state.allow_allreduce()
+        if self.scheduled_grads is not None:
+            try:
+                self.scheduled_grads.result(timeout=max(0.0, self.scheduled_grads.deadline - get_dht_time()))
+            except BaseException as e:
+                logger.warning(self.status_loglevel, f"Caught {e} while averaging gradients")
+            if not self.scheduled_grads.done():
+                self.scheduled_grads.cancel()
+        if self.scheduled_state is not None:
+            try:
+                self.scheduled_state.result(timeout=max(0.0, self.scheduled_state.deadline - get_dht_time()))
+            except BaseException as e:
+                logger.warning(self.status_loglevel, f"Caught {e} while averaging state")
+            if not self.scheduled_state.done():
+                self.scheduled_state.cancel()
 
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()
@@ -659,7 +672,7 @@ class Optimizer(torch.optim.Optimizer):
 
     def shutdown(self):
         logger.debug("Sending goodbye to peers...")
-        self._cancel_scheduled_averaging()
+        self._finish_scheduled_averaging()
         self.tracker.shutdown(self.shutdown_timeout)
         logger.debug("Shutting down averagers...")
         self.state_averager.step(wait_for_delayed_updates=True)