@@ -14,8 +14,9 @@ from hivemind.dht import DHT
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.grad_scaler import HivemindGradScaler
 from hivemind.optim.performance_ema import PerformanceEMA
-from hivemind.utils import Endpoint, get_dht_time, get_logger
+from hivemind.utils import get_dht_time, get_logger
 
 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
@@ -147,6 +148,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.averager = self._make_averager(**kwargs)
 
+        self._step_supports_amp_scaling = self.reuse_grad_buffers  # enable custom execution with torch GradScaler
+
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
         self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
@@ -197,6 +200,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 try:
                     self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                     break
+                except KeyboardInterrupt:
+                    raise
                 except BaseException as e:
                     logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                     continue
@@ -205,13 +210,16 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
-    def step(self, batch_size: Optional[int] = None, **kwargs):
+    def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
         """
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
 
         :param batch_size: optional override for batch_size_per_step from init
+        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
+        if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
+            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler).")
         if self.batch_size_per_step is None:
             if batch_size is None:
                 raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
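
For context, a minimal sketch of how the new `grad_scaler` argument is meant to be used in a mixed-precision training loop. The model, dataloader, optimizer, and batch size below are hypothetical placeholders; only `HivemindGradScaler` and `CollaborativeOptimizer.step(..., grad_scaler=...)` come from this diff, and `scaler.scale(...)` is assumed to follow the standard `torch.cuda.amp.GradScaler` interface:

```python
import torch
from hivemind.optim.grad_scaler import HivemindGradScaler

scaler = HivemindGradScaler()  # hivemind-aware replacement for torch.cuda.amp.GradScaler

for batch in dataloader:  # `dataloader`, `model`, `opt`, `batch_size` are hypothetical placeholders
    with torch.cuda.amp.autocast():
        loss = model(**batch).loss
    scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
    # instead of calling scaler.step(opt) / scaler.update() directly, hand the scaler to the optimizer:
    # CollaborativeOptimizer defers unscale_/step/update until a collaborative update actually runs
    opt.step(batch_size=batch_size, grad_scaler=scaler)
```
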
@@ -227,6 +235,13 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.averager.local_step = self.collaboration_state.optimizer_step
             logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
 
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
+            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
+            self.reset_accumulated_grads_()
+            self.should_report_progress.set()
+            return
+
         if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
             logger.warning(
                 f"Training step took {get_dht_time() - self.last_step_time}, "
@@ -251,6 +266,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.unscale_(self)
+
             current_step, group_info = self.averager.local_step, None
 
             if self.collaboration_state.num_peers > 1:
@@ -279,13 +298,21 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s).",
                 )
 
-            self.opt.step()
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.step(self)
+            else:
+                self.opt.step()
+
             self.reset_accumulated_grads_()
             self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.collaboration_state.register_step(current_step + 1)
             self.averager.local_step = current_step + 1
             self.collaboration_state_updated.set()
             self.update_scheduler()
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.update()
 
         logger.log(self.status_loglevel, f"Optimizer step: done!")
 
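
For comparison, a single-process AMP update with a plain `torch.cuda.amp.GradScaler` performs the same three calls back to back. The hunk above splits them apart (`unscale_` before gradient averaging, `step` in place of `opt.step()`, `update` after the collaborative step completes) and wraps each in `grad_scaler.running_global_step()` so they only take effect when a global update actually happens. The snippet below is the standard non-collaborative baseline, not hivemind code; `opt` and `scaler` are placeholder torch objects:

```python
# standard torch.cuda.amp flow for a single worker (for contrast only):
scaler.unscale_(opt)   # divide the accumulated .grad tensors by the current loss scale
scaler.step(opt)       # calls opt.step() unless any gradient is inf/nan
scaler.update()        # grow or shrink the loss scale for the next iteration
```
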
@@ -344,38 +371,36 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         """local gradient accumulators"""
         if self.reuse_grad_buffers:
             yield from self._grad_buffers()
-        elif self._grads is None:
-            with torch.no_grad():
-                self._grads = [
-                    torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()
-                ]
+            return
+
+        if self._grads is None:
+            self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
         yield from self._grads
 
     @torch.no_grad()
     def accumulate_grads_(self, batch_size: int):
         """add current gradients to grad accumulators (if any)"""
         if self.reuse_grad_buffers:
-            return  # user is responsible for accumulating gradients in .grad buffers
-        alpha = float(batch_size) / self.batch_size_per_step
-        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-            grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
+            # user is responsible for accumulating gradients in .grad buffers
+            assert batch_size == self.batch_size_per_step, "Custom batch size is not supported if reuse_grad_buffers"
+        else:
+            alpha = float(batch_size) / self.batch_size_per_step
+            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
 
     @torch.no_grad()
     def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
-        if self.reuse_grad_buffers:
-            return
-        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-            grad_buf[...] = grad_acc.to(grad_buf.device)
-            if scale_by is not None:
+        if not self.reuse_grad_buffers:
+            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+                grad_buf.copy_(grad_acc.to(grad_buf.device), non_blocking=True)
+        if scale_by is not None:
+            for grad_buf in self._grad_buffers():
                 grad_buf.mul_(scale_by)
 
     @torch.no_grad()
     def reset_accumulated_grads_(self):
-        if self.reuse_grad_buffers:
-            self.opt.zero_grad()
-        else:
-            for grad_buf in self.accumulated_grads():
-                grad_buf.zero_()
+        for grad_buf in self.accumulated_grads():
+            grad_buf.zero_()
 
     def report_training_progress(self):
         """Periodically publish metadata and the current number of samples accumulated towards the next step"""
|