@@ -385,7 +385,7 @@ class Optimizer(torch.optim.Optimizer):
             began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
             if not began_averaging_gradients:
                 pass  # failed to start gradient averaging due to an internal error
-            if self.delay_grad_averaging:
+            elif self.delay_grad_averaging:
                 # if using delayed grad averaging, send this to state_averager as a pre-condition for optimizer step
                 wait_for_trigger = partial(self._average_gradients_and_load_into_optimizer, self.scheduled_grads)
             else:
@@ -437,13 +437,13 @@ class Optimizer(torch.optim.Optimizer):
             with grad_scaler.running_global_step():
                 assert grad_scaler.unscale_(self)
 
+        began_averaging_gradients = False
         if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
             logger.log(self.status_loglevel, f"Not using pre-scheduled group for state averaging because it"
                                              f"was already used elsewhere: {self.scheduled_state}")
             self.scheduled_grads = None
 
-        began_averaging_gradients = False
-        if self.tracker.global_progress.num_peers > 1:
+        elif self.tracker.global_progress.num_peers > 1:
             try:
                 self.scheduled_grads = self.grad_averager.step(
                     control=self.scheduled_grads, reset_accumulators=True, wait=False
@@ -582,8 +582,6 @@ class Optimizer(torch.optim.Optimizer):
 
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
-        self.state_averager.step(wait_for_delayed_updates=True)
-        self._finish_background_averaging()
 
         with self.tracker.pause_updates():
             while True: