@@ -137,7 +137,7 @@ class Optimizer(torch.optim.Optimizer):
         params: Optional[Union[Parameters, ParamGroups]] = None,
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         matchmaking_time: Optional[float] = 15.0,
-        averaging_timeout: Optional[float] = 300.0,
+        averaging_timeout: Optional[float] = 60.0,
         load_state_timeout: float = 600.0,
         reuse_grad_buffers: bool = False,
         offload_optimizer: Optional[bool] = None,
@@ -397,7 +397,12 @@ class Optimizer(torch.optim.Optimizer):
             should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
             should_average_state = swarm_not_empty and next_epoch % self.average_state_every == 0

-            if should_average_state:
+            if should_average_state and self.scheduled_state is not None:
+                if self.scheduled_state.triggered or self.scheduled_state.done():
+                    logger.log(self.status_loglevel, f"Not using pre-scheduled group for state averaging because it "
+                                                     f"was already used elsewhere: {self.scheduled_state}")
+                    self.scheduled_state = None
+
             self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)

             self.state_averager.step(
@@ -412,10 +417,6 @@ class Optimizer(torch.optim.Optimizer):
                 averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
             )

-            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
-                self.scheduled_state.cancel()
-                self.scheduled_state = None
-
             self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
             self._should_check_synchronization_on_update = True
             # the above line ensures that peers check for *strict* synchronization once per epoch
@@ -437,7 +438,8 @@ class Optimizer(torch.optim.Optimizer):
                 assert grad_scaler.unscale_(self)

        if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
-            logger.debug(self.status_loglevel, f"Discarding previous matchmaking results: {self.scheduled_grads}")
+            logger.log(self.status_loglevel, f"Not using pre-scheduled group for gradient averaging because it "
+                                             f"was already used elsewhere: {self.scheduled_grads}")
            self.scheduled_grads = None

        began_averaging_gradients = False
@@ -614,17 +616,16 @@ class Optimizer(torch.optim.Optimizer):
            if not scheduled_round.triggered:
                scheduled_round.weight = 0
                scheduled_round.allow_allreduce()
-                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
-                    scheduled_round.cancel()
        for scheduled_round in self.scheduled_grads, self.scheduled_state:
            if scheduled_round is not None and not scheduled_round.done():
                try:
                    time_to_deadline = scheduled_round.deadline - get_dht_time()
-                    scheduled_round.result(timeout=max(0.0, min(time_to_deadline, self.shutdown_timeout)))
+                    scheduled_round.result(timeout=max(0.0, time_to_deadline))
                except BaseException as e:
                    logger.log(self.status_loglevel, f"Caught {e} while averaging gradients")
                if not scheduled_round.done():
                    scheduled_round.cancel()
+        self.scheduled_grads = self.scheduled_state = None

    def state_dict(self) -> dict:
        state_dict = self.state_averager.optimizer.state_dict()
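
For context, a minimal sketch of how the parameters touched above are passed when constructing the optimizer. The surrounding setup (DHT, run_id, target batch size, base optimizer factory) follows the hivemind quickstart and is not part of this diff; the concrete names and values are illustrative assumptions, not the change itself.

import torch
import hivemind

model = torch.nn.Linear(10, 2)
dht = hivemind.DHT(start=True)  # fresh DHT; a real run would pass initial_peers=[...]

opt = hivemind.Optimizer(
    dht=dht,
    run_id="demo_run",                # peers sharing a run_id train together (illustrative name)
    target_batch_size=4096,           # global samples aggregated before each optimizer step
    batch_size_per_step=32,           # samples this peer contributes per local step() call
    optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
    params=model.parameters(),
    matchmaking_time=15.0,            # default from the signature above
    averaging_timeout=60.0,           # new default introduced by this diff (previously 300.0)
    verbose=True,
)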
|