@@ -2,7 +2,7 @@ from __future__ import annotations
 import warnings
 from dataclasses import dataclass
 from threading import Thread, Lock, Event
-from typing import Optional, Type
+from typing import Optional, Iterator
 import logging
 
 import torch
@@ -11,7 +11,7 @@ import numpy as np
 from hivemind.dht import DHT
 from hivemind.client.optim.base import DecentralizedOptimizerBase
 from hivemind.client.averaging.training import TrainingAverager
-from hivemind.utils import get_logger, get_dht_time, run_in_background, ValueWithExpiration
+from hivemind.utils import get_logger, get_dht_time, ValueWithExpiration
 from hivemind.client.optim.performance_ema import PerformanceEMA
 
 logger = get_logger(__name__)
@@ -47,7 +47,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     :note: This optimizer behaves unlike regular pytorch optimizers in two ways:
 
-      - calling .step will periodially zero-out gradients w.r.t. model parameters after each step
+      - calling .step will periodically zero-out gradients w.r.t. model parameters after each step
       - it may take multiple .step calls without updating model parameters, waiting for peers to accumulate enough samples
 
     :param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
@@ -55,7 +55,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :param prefix: a common prefix for all metadata stored by CollaborativeOptimizer in the DHT
     :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
     :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
-    :param target_group_size: maximum group size for DecentralizedAverager's all-reduce
     :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
     :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
     :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
@@ -69,6 +68,12 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
     :param scheduler: if specified, use this scheduler to update optimizer learning rate
+    :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
+      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
+    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
+      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
+      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
+    :param kwargs: additional parameters forwarded to DecentralizedAverager
     :note: if you are using CollaborativeOptimizer with a lr_scheduler, it is recommended to pass this scheduler
       explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
     """
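
Note (not part of the patch): a minimal usage sketch of the two new parameters documented above. Only the CollaborativeOptimizer arguments come from this diff; the model, base optimizer, DHT node and data loader names are hypothetical placeholders.

    # assumed setup: `model`, `base_opt` (e.g. LAMB) and a running `dht` node already exist
    collab_opt = CollaborativeOptimizer(
        opt=base_opt, dht=dht, prefix="my_experiment", target_batch_size=4096,
        batch_size_per_step=32, reuse_grad_buffers=True, verbose=True)
    # with reuse_grad_buffers=True, gradients accumulate directly in param.grad,
    # so the training loop must NOT call model.zero_grad() or base_opt.zero_grad()
    for batch in dataloader:
        loss = compute_loss(model, batch)  # hypothetical loss helper
        loss.backward()
        collab_opt.step()
    # alternative: keep reuse_grad_buffers=False and hold the accumulators in host memory
    # collab_opt = CollaborativeOptimizer(..., accumulate_grads_on=torch.device('cpu'))
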
@@ -78,14 +83,17 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                  min_refresh_period: float = 0.5, max_refresh_period: float = 30, default_refresh_period: float = 3,
                  expected_drift_peers: float = 3, expected_drift_rate: float = 0.2, performance_ema_alpha: float = 0.1,
                  metadata_expiration: float = 30.0, averaging_timeout: Optional[float] = None, verbose: bool = False,
-                 **kwargs):
+                 reuse_grad_buffers: bool = False, accumulate_grads_on: Optional[torch.device] = None, **kwargs):
         super().__init__(opt, dht)
 
+        if reuse_grad_buffers and accumulate_grads_on is not None:
+            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
         self.prefix, self.scheduler = prefix, scheduler
         self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
         self.min_refresh_period, self.max_refresh_period, self.default_refresh_period =\
             min_refresh_period, max_refresh_period, default_refresh_period
         self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
         self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
+        self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.averager = self._make_averager(**kwargs)
@@ -134,9 +142,12 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         :param batch_size: optional override for batch_size_per_step from init
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
-        if batch_size is not None and self.batch_size_per_step is None:
-            raise ValueError("Please either set batch_size_per_step parameter at init or provide batch_size in .step")
-        batch_size = self.batch_size_per_step if batch_size is None else batch_size
+        if self.batch_size_per_step is None:
+            if batch_size is None:
+                raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
+            logger.log(self.status_loglevel, f"Setting default batch_size_per_step to {batch_size}")
+            self.batch_size_per_step = batch_size
+        batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
         if not self.is_synchronized:
             self.load_state_from_peers()
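
Note (not part of the patch): with the fallback above, the first .step call can establish batch_size_per_step for the rest of training. A short sketch, assuming `collab_opt` was constructed without batch_size_per_step (the numbers are made up):

    collab_opt.step(batch_size=256)  # first call: logs "Setting default batch_size_per_step to 256" and stores it
    collab_opt.step()                # later calls without an argument fall back to the stored 256
    collab_opt.step(batch_size=128)  # an explicit value applies to this call only; the stored default stays 256
    # calling .step() with neither an init-time batch_size_per_step nor a batch_size argument raises ValueError
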
@@ -146,6 +157,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             logger.warning(f"Training step took {get_dht_time() - self.last_step_time}, "
                            f"but metadata expired in {self.metadata_expiration} s.")
 
+        self.accumulate_grads_(batch_size)
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_steps_accumulated += 1
@@ -164,6 +176,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             return
 
         with self.performance_ema.pause(), self.lock_collaboration_state:
+            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
+            self.apply_accumulated_grads_(scale_by=1. / self.local_steps_accumulated)
+
             if self.collaboration_state.num_peers > 1:
                 mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                 weight = self.local_samples_accumulated / mean_samples_per_worker
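
Note (not part of the patch), to make the scaling above concrete with made-up numbers: with batch_size_per_step=32 and four local .step calls of 32 samples each, accumulate_grads_ adds every step's gradient with alpha=1.0, so the accumulator ends up holding the sum of four gradients; scale_by = 1. / self.local_steps_accumulated = 1/4 turns that sum back into the average gradient over the 128 locally processed samples. In the same spirit, with target_batch_size=4096 and 8 peers, mean_samples_per_worker = 512, so a peer that contributed 128 of those samples enters the upcoming all-reduce with weight 128 / 512 = 0.25.
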
@@ -176,6 +191,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             self.opt.step()
             self.opt.zero_grad()
+            self.reset_accumulated_grads_()
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.collaboration_state.register_step()
             self.collaboration_state_updated.set()
@@ -184,6 +200,46 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         logger.log(self.status_loglevel, f"Optimizer step: done!")
         return output
 
+    def _grad_buffers(self) -> Iterator[torch.Tensor]:
+        """ pytorch-internal gradient buffers """
+        for param_group in self.opt.param_groups:
+            for param in param_group['params']:
+                if param.grad is None:
+                    yield torch.zeros_like(param)
+                else:
+                    yield param.grad
+
+    @torch.no_grad()
+    def accumulated_grads(self) -> Iterator[torch.Tensor]:
+        """ local gradient accumulators """
+        if self.reuse_grad_buffers:
+            yield from self._grad_buffers()
+        elif self._grads is None:
+            with torch.no_grad():
+                self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
+        yield from self._grads
+
+    @torch.no_grad()
+    def accumulate_grads_(self, batch_size: int):
+        """ add current gradients to grad accumulators (if any) """
+        if self.reuse_grad_buffers:
+            return  # user is responsible for accumulating gradients in .grad buffers
+        alpha = float(batch_size) / self.batch_size_per_step
+        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+            grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
+
+    @torch.no_grad()
+    def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
+        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+            grad_buf[...] = grad_acc.to(grad_buf.device)
+            if scale_by is not None:
+                grad_buf.mul_(scale_by)
+
+    @torch.no_grad()
+    def reset_accumulated_grads_(self):
+        for grad_buf in self._grad_buffers():
+            grad_buf.zero_()
+
     def report_training_progress(self):
         """ Periodically publish metadata and the current number of samples accumulated towards the next step """
         while self.is_alive():
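
Note (not part of the patch): for reviewers unfamiliar with the pattern, here is a self-contained toy version of the accumulate / apply / reset cycle that the five methods above implement, written in plain PyTorch with hypothetical names and no hivemind dependencies:

    import torch

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    accumulators = [torch.zeros_like(p) for p in model.parameters()]  # counterpart of self._grads

    local_steps = 4
    for _ in range(local_steps):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        for acc, param in zip(accumulators, model.parameters()):
            acc.add_(param.grad)      # counterpart of accumulate_grads_ (alpha=1 for full batches)
        opt.zero_grad()               # allowed here because gradients live in separate accumulators

    # counterpart of apply_accumulated_grads_(scale_by=1/local_steps) followed by the optimizer step
    for acc, param in zip(accumulators, model.parameters()):
        param.grad = acc / local_steps
    opt.step()
    opt.zero_grad()
    for acc in accumulators:          # counterpart of reset_accumulated_grads_ (clear state between global steps)
        acc.zero_()
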
@@ -235,17 +291,17 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             if not is_client:
                 global_optimizer_step = max(global_optimizer_step, opt_step)
 
-        total_samples_accumulated = estimated_curent_samples = total_samples_per_second = 0
+        total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0
 
         for opt_step, samples_accumulated, samples_per_second, timestep, is_client in valid_peer_states:
             total_samples_per_second += samples_per_second
             if opt_step == global_optimizer_step:
                 total_samples_accumulated += samples_accumulated
-                estimated_curent_samples += samples_accumulated + max(0, current_time - timestep) * samples_per_second
+                estimated_current_samples += samples_accumulated + max(0, current_time - timestep) * samples_per_second
             # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
             # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.
 
-        estimated_samples_remaining = self.target_batch_size - estimated_curent_samples
+        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
         estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second
 
         expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
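
Note (not part of the patch): a quick sanity check of the estimate above with made-up numbers. Suppose target_batch_size=4096 and three up-to-date peers together report samples_accumulated totalling 2500 and samples_per_second totalling 300, each having reported 2 seconds ago: estimated_current_samples = 2500 + 2 * 300 = 3100, estimated_samples_remaining = 996, and estimated_time_to_next_step ≈ 996 / 300 ≈ 3.3 s. With the default expected_drift_peers=3 and expected_drift_rate=0.2, expected_max_peers = max(3 + 3, 3 * 1.2) = 6.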