
[benchmarking WIP] avoid deleting non-triggered StepControl

justheuristic 3 years ago
commit f4e07fedc6
2 changed files with 18 additions and 23 deletions
  1. hivemind/averaging/control.py (+2 -1)
  2. hivemind/optim/experimental/optimizer.py (+16 -22)

+2 -1
hivemind/averaging/control.py

@@ -1,3 +1,4 @@
+import os
 import struct
 from enum import Enum
 from typing import Optional
@@ -145,7 +146,7 @@ class StepControl(MPFuture):
         self._data_for_gather, self._deadline, self._allow_retries = state["immutable_params"]
 
     def __del__(self):
-        if not self.triggered:
+        if os.getpid() == self._origin_pid and not self.triggered:
             logger.warning("Deleted an averaging StepControl, but the step was not triggered. This may cause other "
                            "peers to fail an averaging round via TimeoutError.")
         super().__del__()
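
The PID check matters because StepControl is an MPFuture that may be pickled into other processes (for instance, the averager's background process). When such a copy is garbage-collected outside the process that created it, __del__ fires there as well, so without the os.getpid() == self._origin_pid guard every worker teardown would emit the warning spuriously. Below is a minimal sketch of the effect, assuming only standard multiprocessing semantics; the Handle class is illustrative, while _origin_pid mirrors the attribute used in the diff above:

    import multiprocessing as mp
    import os

    class Handle:
        """Toy stand-in for an MPFuture-like object shared across processes."""

        def __init__(self):
            self._origin_pid = os.getpid()  # pid of the process that created this handle
            self.triggered = False

        def __del__(self):
            # Only warn in the origin process: copies finalized in child
            # processes would otherwise emit the same warning spuriously.
            if os.getpid() == self._origin_pid and not self.triggered:
                print(f"[pid {os.getpid()}] warning: handle deleted but never triggered")

    def worker(handle):
        pass  # the child's copy is finalized on exit, silently

    if __name__ == "__main__":
        handle = Handle()
        child = mp.Process(target=worker, args=(handle,))
        child.start()
        child.join()  # the child's copy is deleted without a warning
        del handle    # the origin process warns exactly once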

+16 -22
hivemind/optim/experimental/optimizer.py

@@ -8,7 +8,7 @@ from typing import Callable, Optional, Sequence, Union
 
 import torch
 
-from hivemind.averaging.control import StepControl
+from hivemind.averaging.control import StepControl, AveragingStage
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
 from hivemind.optim.experimental.grad_averager import GradientAverager
@@ -414,7 +414,6 @@ class Optimizer(torch.optim.Optimizer):
                 self.scheduled_state = None
 
             self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
-            self.scheduled_grads = self.scheduled_state = None
             self._should_check_synchronization_on_update = True
             # the above line ensures that peers check for *strict* synchronization once per epoch
 
@@ -611,26 +610,21 @@ class Optimizer(torch.optim.Optimizer):
                     self.grad_averager.state_sharing_priority = self.local_epoch
 
     def _finish_scheduled_averaging(self):
-        if self.scheduled_grads is not None:
-            self.scheduled_grads.weight = 0
-            self.scheduled_grads.allow_allreduce()
-        if self.scheduled_state is not None:
-            self.scheduled_state.weight = 0
-            self.scheduled_state.allow_allreduce()
-        if self.scheduled_grads is not None:
-            try:
-                self.scheduled_grads.result(timeout=max(0.0, self.scheduled_grads.deadline - get_dht_time()))
-            except BaseException as e:
-                logger.warning(self.status_loglevel, f"Caught {e} while averaging gradients")
-            if not self.scheduled_grads.done():
-                self.scheduled_grads.cancel()
-        if self.scheduled_state is not None:
-            try:
-                self.scheduled_state.result(timeout=max(0.0, self.scheduled_state.deadline - get_dht_time()))
-            except BaseException as e:
-                logger.warning(self.status_loglevel, f"Caught {e} while averaging state")
-            if not self.scheduled_state.done():
-                self.scheduled_state.cancel()
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
+                if scheduled_round.stage == AveragingStage.AWAITING_TRIGGER:
+                    scheduled_round.weight = 0
+                    scheduled_round.allow_allreduce()
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None:
+                try:
+                    scheduled_round.result(timeout=max(0.0, scheduled_round.deadline - get_dht_time()))
+                except BaseException as e:
+                    logger.log(self.status_loglevel, f"Caught {e} while finishing averaging round")
+                if not scheduled_round.done():
+                    scheduled_round.cancel()
 
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()
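
Taken together, the rewritten _finish_scheduled_averaging makes two passes over the pending rounds: first it dispatches on AveragingStage (rounds still looking for a group are cancelled outright; rounds that already formed a group are triggered with weight 0, so the remaining peers can complete their all-reduce instead of hitting a TimeoutError), then it waits on each round until that round's own deadline and cancels whatever is still unfinished. Since StepControl is future-like, the second pass is a standard deadline-bounded wait. A minimal sketch of that pattern on a plain concurrent.futures.Future, with time.time() standing in for get_dht_time() and the helper name finish_pending being illustrative:

    import time
    from concurrent.futures import Future

    def finish_pending(round_future: Future, deadline: float) -> None:
        # Give the round until its own deadline to complete...
        try:
            round_future.result(timeout=max(0.0, deadline - time.time()))
        except BaseException as e:
            print(f"Caught {e!r} while finishing averaging round")
        # ...then cancel anything that is still not done.
        if not round_future.done():
            round_future.cancel()

    if __name__ == "__main__":
        pending = Future()
        finish_pending(pending, deadline=time.time() + 0.5)  # blocks ~0.5 s, then cancels
        assert pending.cancelled()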