@@ -259,12 +259,14 @@ class Optimizer(torch.optim.Optimizer):
                 )
                 logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
             except BaseException as e:
-                logger.log(self.status_loglevel, f"Averaging failed with {repr(e)}")
+                logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}, using local grads")
+                self.grad_averager.load_accumulators_into_averager_()
 
         else:
             if self.scheduled_round is not None:
                 self.scheduled_round.cancel()
             logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
+            self.grad_averager.load_accumulators_into_averager_()
 
         assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
         with self.grad_averager.use_averaged_gradients(replace_model_gradients=False):
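
For reference, here is a minimal, self-contained sketch of the fallback logic this hunk introduces: if the averaging round fails, or there are no peers to average with, the locally accumulated gradients are reused so the optimizer step is never skipped. The names `run_allreduce`, `load_local_accumulators`, and `average_or_fallback` are hypothetical stand-ins for illustration; only the `load_accumulators_into_averager_()` call in the diff above comes from the actual change.

import random

def run_allreduce() -> int:
    """Hypothetical stand-in for the averaging step: returns the group
    size on success and raises on timeout or failure."""
    if random.random() < 0.5:
        raise TimeoutError("averaging round timed out")
    return 4

def load_local_accumulators() -> None:
    """Hypothetical stand-in for load_accumulators_into_averager_():
    reuse the locally accumulated gradients for the optimizer step."""
    print("loaded local gradient accumulators into the averager")

def average_or_fallback(peers_available: bool) -> None:
    # Mirrors the diff's control flow: try to average with peers; on
    # failure OR when there are no peers, fall back to local gradients.
    if peers_available:
        try:
            group_size = run_allreduce()
            print(f"Averaged gradients with {group_size} peers")
            return  # success: averaged gradients are already in place
        except BaseException as e:
            print(f"Averaging gradients failed with {e!r}, using local grads")
    else:
        print("Skipped averaging: there are no other peers")
    load_local_accumulators()  # both failure paths converge on the fallback

average_or_fallback(peers_available=True)
average_or_fallback(peers_available=False)

Both failure paths converge on the same fallback, so a failed or skipped averaging round degrades to plain local training for that step rather than stalling or discarding the accumulated gradients.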