Michael Diskin 4 anos atrás
pai
commit
f904141807
1 arquivos alterados com 11 adições e 0 exclusões
  1. 11 0
      hivemind/optim/collaborative.py

+ 11 - 0
hivemind/optim/collaborative.py

@@ -201,6 +201,17 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
+    def state_dict(self) -> dict:
+        state_dict = super().state_dict()
+        state_dict["state"]["collaborative_step"] = self.local_step
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        if "collaborative_step" in state_dict["state"]:
+            self.averager.local_step = state_dict["state"]["collaborative_step"]
+            del state_dict["state"]["collaborative_step"]
+        return super().load_state_dict(state_dict)
+
     def step(self, batch_size: Optional[int] = None, **kwargs):
         """
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters