@@ -682,15 +682,6 @@ class Optimizer(torch.optim.Optimizer):
                 if not scheduled_round.triggered:
                     scheduled_round.weight = 0
                     scheduled_round.allow_allreduce()
-        for scheduled_round in self.scheduled_grads, self.scheduled_state:
-            if scheduled_round is not None and not scheduled_round.done():
-                try:
-                    time_to_deadline = scheduled_round.deadline - get_dht_time()
-                    scheduled_round.result(timeout=max(0.0, time_to_deadline))
-                except BaseException as e:
-                    logger.log(self.status_loglevel, f"Caught {e} while averaging gradients")
-                if not scheduled_round.done():
-                    scheduled_round.cancel()
         self.scheduled_grads = self.scheduled_state = None
 
     def state_dict(self) -> dict:
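For context, the removed block is a deadline-bounded wait-then-cancel on a scheduled averaging round: wait for the round until its deadline, and cancel it only if it is still unfinished afterwards. A minimal, self-contained sketch of that pattern, using a plain `concurrent.futures.Future` in place of hivemind's scheduled-round object (the `finish_or_cancel` helper and `deadline` parameter are illustrative names, not hivemind API):

```python
# Sketch only: reproduces the wait-then-cancel pattern with a plain Future,
# not hivemind's actual scheduled-round objects.
import time
from concurrent.futures import Future


def finish_or_cancel(pending: Future, deadline: float) -> None:
    """Wait for `pending` until `deadline` (unix timestamp); cancel it if still unfinished."""
    if pending.done():
        return
    try:
        # Clamp to zero so a deadline already in the past becomes a single poll,
        # never a negative timeout.
        time_to_deadline = deadline - time.time()
        pending.result(timeout=max(0.0, time_to_deadline))
    except BaseException as e:
        print(f"Caught {e} while waiting for the scheduled round")
    if not pending.done():
        pending.cancel()
```

As in the removed loop, cancellation is attempted only after the deadline-bounded wait fails, so a round that finishes just in time is kept rather than discarded.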