@@ -67,7 +67,7 @@ class Optimizer(torch.optim.Optimizer):
     :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
       This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
-    :param epoch_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
+    :param average_state_every: average state (parameters, chosen opt statistics) with peers every this many epochs
     :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
     :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
     :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
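For context, a minimal construction sketch showing where the new `average_state_every` knob sits among the keyword arguments documented above. The placeholder model, the DHT setup, the `params`/`optimizer` arguments, the `hivemind.Optimizer` import path, and all concrete values are illustrative assumptions, not part of this diff:

    import torch
    import hivemind

    model = torch.nn.Linear(64, 10)   # placeholder model for illustration
    dht = hivemind.DHT(start=True)    # standalone DHT; a real run would pass initial_peers

    opt = hivemind.Optimizer(
        dht=dht,
        prefix="my_run",                              # hypothetical experiment prefix
        params=model.parameters(),                    # assumed, not shown in this diff
        optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),  # assumed inner optimizer factory
        target_batch_size=4096,                       # samples per global epoch
        batch_size_per_step=32,                       # samples this peer contributes per step()
        average_state_every=4,                        # new: average parameters/statistics every 4th epoch
        reuse_grad_buffers=False,
        verbose=True,
    )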
@@ -95,19 +95,20 @@ class Optimizer(torch.optim.Optimizer):
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         matchmaking_time: Optional[float] = 15.0,
         averaging_timeout: Optional[float] = 300.0,
+        average_state_every: int = 1,
         load_state_timeout: float = 600.0,
         reuse_grad_buffers: bool = False,
-        epoch_tolerance: int = 1,
         delay_optimizer_step: bool = False,
         client_mode: bool = None,
         averager_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
         verbose: bool = False,
     ):
-        self.dht, self.prefix, self.epoch_tolerance = dht, prefix, epoch_tolerance
+        self.dht, self.prefix = dht, prefix
         self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
         self.matchmaking_time, self.delay_optimizer_step = matchmaking_time, delay_optimizer_step
+        self.average_state_every = average_state_every
 
         self.client_mode = client_mode if client_mode is not None else self.dht.client_mode
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
@@ -118,6 +119,7 @@ class Optimizer(torch.optim.Optimizer):
         )
         self.grad_averager = self._make_gradient_averager(reuse_grad_buffers=reuse_grad_buffers, **averager_opts or {})
         self.tracker = self._make_progress_tracker(target_batch_size, **tracker_opts or {})
+        self._last_synchronized_time = get_dht_time()
         self._schema_hash = self._compute_schema_hash()
         self._parent_pid = os.getpid()
 
@@ -182,10 +184,12 @@ class Optimizer(torch.optim.Optimizer):
|
|
|
def local_epoch(self) -> int:
|
|
|
return self.state_averager.local_epoch
|
|
|
|
|
|
- @property
|
|
|
- def should_load_state_from_peers(self) -> bool:
|
|
|
+ def should_load_state_from_peers(self, new_epoch: bool = False) -> bool:
|
|
|
"""If true, peer will discard local progress and attempt to download state from peers."""
|
|
|
- return self.local_epoch < self.tracker.global_epoch - self.epoch_tolerance
|
|
|
+ if new_epoch:
|
|
|
+ return self.local_epoch != self.tracker.global_epoch
|
|
|
+ else:
|
|
|
+ return self.local_epoch < self.tracker.global_epoch - 1
|
|
|
|
|
|
def step(
|
|
|
self,
|
|
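The predicate now has two regimes: an exact check used right after an epoch transition (`new_epoch=True`) and a one-epoch slack used on ordinary steps. A standalone toy restatement with made-up epoch numbers (plain functions, not hivemind code), just to spell out the semantics:

    def should_load_state(local_epoch: int, global_epoch: int, new_epoch: bool = False) -> bool:
        # right after averaging, any mismatch means the peer must re-download state
        if new_epoch:
            return local_epoch != global_epoch
        # on ordinary steps, lagging by exactly one epoch is still tolerated
        return local_epoch < global_epoch - 1

    assert should_load_state(local_epoch=9, global_epoch=10) is False                   # one behind: OK mid-step
    assert should_load_state(local_epoch=9, global_epoch=10, new_epoch=True) is True    # but not after averaging
    assert should_load_state(local_epoch=8, global_epoch=10) is True                    # two behind: always resync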
@@ -213,7 +217,8 @@ class Optimizer(torch.optim.Optimizer):
             with torch.enable_grad():
                 loss = closure()
 
-        if self.should_load_state_from_peers:
+        if self.should_load_state_from_peers():
+            logger.log(self.status_loglevel, "Peer is out of sync.")
             self.load_state_from_peers()
             return loss
 
@@ -277,7 +282,7 @@ class Optimizer(torch.optim.Optimizer):
                 optimizer_step=True,
                 delay_optimizer_step=self.delay_optimizer_step,
                 grad_scaler=grad_scaler,
-                averaging_round=need_averaging,
+                averaging_round=need_averaging and self.tracker.global_epoch % self.average_state_every == 0,
                 delay_averaging=True,
                 averaging_opts=dict(
                     scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
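With this change, a state-averaging round is only attempted on epochs that are multiples of `average_state_every` (gradient accumulation and averaging are handled separately by the GradientAverager). A tiny illustration of which epochs would schedule state averaging, assuming `need_averaging` is true on each of them:

    average_state_every = 4
    print([epoch for epoch in range(12) if epoch % average_state_every == 0])  # [0, 4, 8]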
@@ -288,6 +293,12 @@ class Optimizer(torch.optim.Optimizer):
 
             self.grad_averager.reset_accumulated_grads_()
             self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
+
+            if self.should_load_state_from_peers(new_epoch=True):
+                logger.log(self.status_loglevel, "Peer ended up out of sync after averaging.")
+                self.load_state_from_peers()
+                return loss
+
             logger.log(self.status_loglevel, f"Optimizer step done! Beginning next epoch {self.local_epoch}.")
             return loss
 
@@ -330,7 +341,7 @@ class Optimizer(torch.optim.Optimizer):
                     logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                     continue
 
-            if self.tracker.global_epoch - self.epoch_tolerance <= self.local_epoch < self.tracker.global_epoch:
+            if self.tracker.global_epoch - 1 <= self.local_epoch < self.tracker.global_epoch:
                 logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}.")
                 self.state_averager.local_epoch = self.tracker.global_epoch
 