@@ -8,7 +8,7 @@ from typing import Callable, Optional, Sequence, Union
 
 import torch
 
-from hivemind.averaging.control import StepControl
+from hivemind.averaging.control import StepControl, AveragingStage
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
 from hivemind.optim.experimental.grad_averager import GradientAverager
@@ -414,7 +414,6 @@ class Optimizer(torch.optim.Optimizer):
             self.scheduled_state = None
 
             self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
-            self.scheduled_grads = self.scheduled_state = None
             self._should_check_synchronization_on_update = True
             # the above line ensures that peers check for *strict* synchronization once per epoch
 
@@ -611,26 +610,21 @@ class Optimizer(torch.optim.Optimizer):
                     self.grad_averager.state_sharing_priority = self.local_epoch
 
     def _finish_scheduled_averaging(self):
-        if self.scheduled_grads is not None:
-            self.scheduled_grads.weight = 0
-            self.scheduled_grads.allow_allreduce()
-        if self.scheduled_state is not None:
-            self.scheduled_state.weight = 0
-            self.scheduled_state.allow_allreduce()
-        if self.scheduled_grads is not None:
-            try:
-                self.scheduled_grads.result(timeout=max(0.0, self.scheduled_grads.deadline - get_dht_time()))
-            except BaseException as e:
-                logger.warning(self.status_loglevel, f"Caught {e} while averaging gradients")
-            if not self.scheduled_grads.done():
-                self.scheduled_grads.cancel()
-        if self.scheduled_state is not None:
-            try:
-                self.scheduled_state.result(timeout=max(0.0, self.scheduled_state.deadline - get_dht_time()))
-            except BaseException as e:
-                logger.warning(self.status_loglevel, f"Caught {e} while averaging state")
-            if not self.scheduled_state.done():
-                self.scheduled_state.cancel()
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
+                if scheduled_round.stage == AveragingStage.AWAITING_TRIGGER:
+                    scheduled_round.weight = 0
+                    scheduled_round.allow_allreduce()
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None:
+                try:
+                    scheduled_round.result(timeout=max(0.0, scheduled_round.deadline - get_dht_time()))
+                except BaseException as e:
+                    logger.log(self.status_loglevel, f"Caught {e} while finishing a scheduled averaging round")
+                if not scheduled_round.done():
+                    scheduled_round.cancel()
 
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()
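For context, the sketch below (not part of the patch) illustrates the cleanup policy that the new _finish_scheduled_averaging applies to a single pre-scheduled averaging round: a round that is still looking for a group is cancelled outright, a round that already found a group and is awaiting its trigger is released with zero weight via allow_allreduce() so the rest of the group is not stalled, and only then does the caller wait for the result, bounded by the round's deadline. The standalone function name and the print-based logging are illustrative assumptions; StepControl, AveragingStage, get_dht_time and the control-object attributes are taken from the code above.

# Illustrative sketch only, not hivemind's API: mirrors the cleanup logic of
# _finish_scheduled_averaging for one pre-scheduled round (a StepControl).
from typing import Optional

from hivemind.averaging.control import AveragingStage, StepControl
from hivemind.utils import get_dht_time


def finish_scheduled_round(scheduled_round: Optional[StepControl]) -> None:
    if scheduled_round is None:
        return
    if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
        # No group was matched yet: nobody is waiting for this peer, so just cancel.
        scheduled_round.cancel()
    if scheduled_round.stage == AveragingStage.AWAITING_TRIGGER:
        # A group was already matched: contribute with zero weight and let the
        # all-reduce proceed so the remaining peers are not stalled.
        scheduled_round.weight = 0
        scheduled_round.allow_allreduce()
    try:
        # Wait for the round to finish, but no longer than its own deadline.
        scheduled_round.result(timeout=max(0.0, scheduled_round.deadline - get_dht_time()))
    except BaseException as e:
        print(f"Caught {e} while finishing a scheduled averaging round")
    if not scheduled_round.done():
        scheduled_round.cancel()

In the patched optimizer this logic runs over both self.scheduled_grads and self.scheduled_state: both rounds are released first, and only then does the method block on their results, so one round cannot delay the release of the other.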