@@ -201,6 +201,17 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.reset_accumulated_grads_()
         self.update_scheduler()
 
+    def state_dict(self) -> dict:
+        state_dict = super().state_dict()
+        state_dict["state"]["collaborative_step"] = self.local_step
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        if "collaborative_step" in state_dict["state"]:
+            self.averager.local_step = state_dict["state"]["collaborative_step"]
+            del state_dict["state"]["collaborative_step"]
+        return super().load_state_dict(state_dict)
+
     def step(self, batch_size: Optional[int] = None, **kwargs):
         """
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
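
For review purposes, a minimal sketch of the checkpoint round-trip these two methods enable. This is an illustration under stated assumptions, not part of the change: make_optimizer() is a hypothetical helper standing in for the usual CollaborativeOptimizer construction (DHT, prefix, target_batch_size, ...), the checkpoint path is arbitrary, and the final assertion assumes local_step proxies self.averager.local_step, as the diff implies.

import torch

# Hypothetical setup helper; stands in for the usual constructor call.
opt = make_optimizer()

# state_dict() now records the collaborative step under
# state_dict["state"]["collaborative_step"], so a plain torch.save
# checkpoint carries it alongside the wrapped optimizer's state.
torch.save(opt.state_dict(), "checkpoint.pt")

# load_state_dict() pops the key back out, restores it onto the
# averager, and only then delegates to the parent class, so the
# wrapped torch optimizer never sees the extra entry.
opt.load_state_dict(torch.load("checkpoint.pt"))
assert opt.local_step == opt.averager.local_step

One side effect worth noting: load_state_dict() deletes "collaborative_step" from the passed-in dict in place, so a caller that keeps a reference to the loaded dict will no longer find the key there afterwards.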