@@ -8,7 +8,7 @@ from typing import Callable, Optional, Sequence, Union

import torch

-from hivemind.averaging.control import StepControl, AveragingStage
+from hivemind.averaging.control import AveragingStage, StepControl
from hivemind.compression import CompressionBase, NoCompression
from hivemind.dht import DHT
from hivemind.optim.experimental.grad_averager import GradientAverager
@@ -234,6 +234,7 @@ class Optimizer(torch.optim.Optimizer):
return TrainingStateAverager(
dht=self.dht,
prefix=f"{self.run_id}_state_averager",
+ min_matchmaking_time=self.matchmaking_time,
allreduce_timeout=self.averaging_timeout,
shutdown_timeout=self.shutdown_timeout,
offload_optimizer=self.offload_optimizer,
@@ -251,6 +252,7 @@ class Optimizer(torch.optim.Optimizer):
dht=self.dht,
prefix=f"{self.run_id}_grad_averager",
parameters=self.state_averager.main_parameters,
+ min_matchmaking_time=self.matchmaking_time,
allreduce_timeout=self.averaging_timeout,
shutdown_timeout=self.shutdown_timeout,
client_mode=self.client_mode,
@@ -409,6 +411,7 @@ class Optimizer(torch.optim.Optimizer):
averaging_control=self.scheduled_state if should_average_state else None,
averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
)
+
if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
self.scheduled_state.cancel()
self.scheduled_state = None
@@ -443,7 +446,6 @@ class Optimizer(torch.optim.Optimizer):
self.scheduled_grads = self.grad_averager.step(
control=self.scheduled_grads, reset_accumulators=True, wait=False
)
- assert self.grad_averager.local_samples_accumulated == 0, "step should have reset accumulators"
began_averaging_gradients = True
except BaseException as e:
logger.exception(e)
@@ -479,10 +481,7 @@ class Optimizer(torch.optim.Optimizer):
eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f}s.")
- scheduled_time = self.tracker.estimated_next_update_time
- if self.client_mode:
- scheduled_time = get_dht_time() + self.averaging_timeout
- self.scheduled_grads = self.grad_averager.schedule_step(scheduled_time, timeout=self.averaging_timeout)
+ self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)

def _maybe_schedule_state_averaging(self) -> None:
"""If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
@@ -496,14 +495,16 @@ class Optimizer(torch.optim.Optimizer):
eta_seconds_to_averaging = estimated_time - get_dht_time()

if eta_seconds_to_averaging <= self.matchmaking_time:
+ if self.delay_state_averaging:
+ # wait for previous averaging to finish before starting a new one
+ self.state_averager.step(wait_for_delayed_updates=True)
if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
+
min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f}s.")
- if self.client_mode:
- estimated_time = get_dht_time() + self.averaging_timeout
self.scheduled_state = self.state_averager.schedule_step(
- estimated_time, gather=next_epoch, timeout=self.averaging_timeout
+ gather=next_epoch, timeout=self.averaging_timeout
)

def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
@@ -582,7 +583,8 @@ class Optimizer(torch.optim.Optimizer):

def load_state_from_peers(self, **kwargs):
"""Attempt to fetch the newest collaboration state from other peers"""
- self._finish_scheduled_averaging()
+ self.state_averager.step(wait_for_delayed_updates=True)
+ self._finish_background_averaging()

with self.tracker.pause_updates():
while True:
@@ -609,22 +611,23 @@ class Optimizer(torch.optim.Optimizer):
if not self.client_mode:
self.grad_averager.state_sharing_priority = self.local_epoch

- def _finish_scheduled_averaging(self):
+ def _finish_background_averaging(self):
for scheduled_round in self.scheduled_grads, self.scheduled_state:
if scheduled_round is not None:
- if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
- scheduled_round.cancel()
- if scheduled_round.stage == AveragingStage.AWAITING_TRIGGER:
+ if not scheduled_round.triggered:
scheduled_round.weight = 0
scheduled_round.allow_allreduce()
+ if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+ scheduled_round.cancel()
for scheduled_round in self.scheduled_grads, self.scheduled_state:
- if scheduled_round is not None:
+ if scheduled_round is not None and not scheduled_round.done():
try:
- scheduled_round.result(timeout=max(0.0, scheduled_round.deadline - get_dht_time()))
+ time_to_deadline = scheduled_round.deadline - get_dht_time()
+ scheduled_round.result(timeout=max(0.0, min(time_to_deadline, self.shutdown_timeout)))
except BaseException as e:
logger.log(self.status_loglevel, f"Caught {e} while averaging gradients")
- if not scheduled_round.done():
- scheduled_round.cancel()
+ if not scheduled_round.done():
+ scheduled_round.cancel()

def state_dict(self) -> dict:
state_dict = self.state_averager.optimizer.state_dict()
@@ -667,10 +670,10 @@ class Optimizer(torch.optim.Optimizer):

def shutdown(self):
logger.debug("Sending goodbye to peers...")
- self._finish_scheduled_averaging()
self.tracker.shutdown(self.shutdown_timeout)
- logger.debug("Shutting down averagers...")
self.state_averager.step(wait_for_delayed_updates=True)
+ self._finish_background_averaging()
+ logger.debug("Shutting down averagers...")
self.state_averager.shutdown()
if self.use_gradient_averaging:
self.grad_averager.shutdown()