Michael Diskin před 4 roky
rodič
revize
5b28a9814b
1 změnil soubory, kde provedl 11 přidání a 0 odebrání
  1. 11 0
      hivemind/optim/collaborative.py

+ 11 - 0
hivemind/optim/collaborative.py

@@ -190,6 +190,17 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
+    def state_dict(self) -> dict:
+        state_dict = super().state_dict()
+        state_dict["state"]["collaborative_step"] = self.local_step
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        if "collaborative_step" in state_dict["state"]:
+            self.averager.local_step = state_dict["state"]["collaborative_step"]
+            del state_dict["state"]["collaborative_step"]
+        return super().load_state_dict(state_dict)
+
     def step(self, batch_size: Optional[int] = None, **kwargs):
         """
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters