@@ -582,8 +582,7 @@ class Optimizer(torch.optim.Optimizer):
 
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
-        if self.scheduled_grads is not None and not self.scheduled_grads.done():
-            self.scheduled_grads.cancel()
+        self._finish_scheduled_averaging()
 
         with self.tracker.pause_updates():
             while True:
@@ -605,18 +604,32 @@ class Optimizer(torch.optim.Optimizer):
             if not self.client_mode:
                 self.state_averager.state_sharing_priority = self.local_epoch
 
-        self._cancel_scheduled_averaging()
-
         if self.use_gradient_averaging:
             self.grad_averager.reset_accumulated_grads_()
             if not self.client_mode:
                 self.grad_averager.state_sharing_priority = self.local_epoch
 
-    def _cancel_scheduled_averaging(self):
-        if self.scheduled_grads is not None and not self.scheduled_grads.done():
-            self.scheduled_grads.cancel()
-        if self.scheduled_state is not None and not self.scheduled_state.done():
-            self.scheduled_state.cancel()
+    def _finish_scheduled_averaging(self):
+        if self.scheduled_grads is not None:
+            self.scheduled_grads.weight = 0
+            self.scheduled_grads.allow_allreduce()
+        if self.scheduled_state is not None:
+            self.scheduled_state.weight = 0
+            self.scheduled_state.allow_allreduce()
+        if self.scheduled_grads is not None:
+            try:
+                self.scheduled_grads.result(timeout=max(0.0, self.scheduled_grads.deadline - get_dht_time()))
+            except BaseException as e:
+                logger.log(self.status_loglevel, f"Caught {e} while averaging gradients")
+            if not self.scheduled_grads.done():
+                self.scheduled_grads.cancel()
+        if self.scheduled_state is not None:
+            try:
+                self.scheduled_state.result(timeout=max(0.0, self.scheduled_state.deadline - get_dht_time()))
+            except BaseException as e:
+                logger.log(self.status_loglevel, f"Caught {e} while averaging state")
+            if not self.scheduled_state.done():
+                self.scheduled_state.cancel()
 
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()
@@ -659,7 +672,7 @@ class Optimizer(torch.optim.Optimizer):
 
     def shutdown(self):
         logger.debug("Sending goodbye to peers...")
-        self._cancel_scheduled_averaging()
+        self._finish_scheduled_averaging()
         self.tracker.shutdown(self.shutdown_timeout)
         logger.debug("Shutting down averagers...")
         self.state_averager.step(wait_for_delayed_updates=True)
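
The essence of this change: instead of cancelling an in-flight scheduled averaging round outright (which would leave the matched peers waiting for a contribution that never arrives), the optimizer now finishes the round with zero weight, and cancels it only if it still has not completed by its deadline. Below is a minimal sketch of that pattern as a standalone helper; `finish_gracefully` is a hypothetical name, and the round object is assumed to expose the same attributes the diff relies on (weight, allow_allreduce(), deadline, result(), done(), cancel()), as hivemind's StepControl does.

    from hivemind.utils import get_dht_time

    def finish_gracefully(scheduled_round, logger, level) -> None:
        """Finish a scheduled averaging round without contributing to it."""
        if scheduled_round is None:
            return
        # Contribute nothing (weight 0), but let the group's all-reduce start
        # so the other matched peers are not stalled by this one.
        scheduled_round.weight = 0
        scheduled_round.allow_allreduce()
        try:
            # Wait at most until the round's own deadline.
            scheduled_round.result(timeout=max(0.0, scheduled_round.deadline - get_dht_time()))
        except BaseException as e:
            logger.log(level, f"Caught {e} while finishing the averaging round")
        # Fall back to cancellation only if the round is still incomplete.
        if not scheduled_round.done():
            scheduled_round.cancel()

Note that `_finish_scheduled_averaging` above triggers both pending rounds (gradients and state) before waiting on either, so the two all-reduces can proceed concurrently rather than back to back.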