@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import logging
 import os
-from typing import Optional, Union
+from typing import Optional, Union, Callable
 
 import torch
 
@@ -187,14 +187,24 @@ class Optimizer(torch.optim.Optimizer):
         """If true, peer will discard local progress and attempt to download state from peers."""
         return self.local_epoch < self.tracker.global_epoch - self.epoch_tolerance
 
-    def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
+    def step(self,
+             closure: Optional[Callable[[], torch.Tensor]] = None,
+             batch_size: Optional[int] = None,
+             grad_scaler: Optional[HivemindGradScaler] = None,
+             **kwargs):
         """
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
 
+        :param closure: A closure that reevaluates the model and returns the loss
         :param batch_size: optional override for batch_size_per_step from init
         :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
         if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
             raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler).")
         if self.batch_size_per_step is None and batch_size is None:
@@ -203,13 +213,13 @@ class Optimizer(torch.optim.Optimizer):
 
         if self.should_load_state_from_peers:
             self.load_state_from_peers()
-            return
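+            # early exits now return the closure loss (or None) instead of a bare return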
+            return loss
 
         if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
             logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
             self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
             self.grad_averager.reset_accumulated_grads_()
-            return
+            return loss
 
         self.grad_averager.accumulate_grads_(batch_size)
         self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
@@ -226,7 +236,7 @@ class Optimizer(torch.optim.Optimizer):
             self.scheduled_round = self.grad_averager.schedule_step(scheduled_time, timeout=self.averaging_timeout)
 
         if not self.tracker.ready_to_update_epoch:
-            return
+            return loss
 
         with self.tracker.pause_updates():
             logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.tracker.global_epoch}")
@@ -275,6 +285,7 @@ class Optimizer(torch.optim.Optimizer):
             self.grad_averager.reset_accumulated_grads_()
             self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
             logger.log(self.status_loglevel, f"Optimizer step done! Beginning next epoch {self.local_epoch}.")
+        return loss
 
     def step_aux(self, **kwargs):
         """