|
@@ -395,6 +395,8 @@ class Optimizer(torch.optim.Optimizer):
|
|
|
averaging_control=self.scheduled_state if should_average_state else None,
|
|
|
averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
|
|
|
)
|
|
|
+ if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
|
|
|
+ self.scheduled_state.cancel()
|
|
|
|
|
|
self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
|
|
|
self.scheduled_grads = self.scheduled_state = None
|