|
@@ -318,8 +318,9 @@ class Optimizer(torch.optim.Optimizer):
|
|
|
averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
|
|
|
delay_averaging=not self.auxiliary,
|
|
|
grad_scaler=grad_scaler,
|
|
|
- wait_for_trigger=partial(
|
|
|
- self._average_gradients_and_load_into_optimizer, self.scheduled_round) if self.delay_grad_averaging else None,
|
|
|
+ wait_for_trigger=partial(self._average_gradients_and_load_into_optimizer, self.scheduled_round)
|
|
|
+ if self.delay_grad_averaging
|
|
|
+ else None,
|
|
|
averaging_opts=dict(
|
|
|
scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
|
|
|
)
|