from __future__ import annotations

import logging
from dataclasses import dataclass
from threading import Thread, Lock, Event
from typing import Dict, Optional, Iterator

import numpy as np
import torch
from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint

from hivemind.averaging.training import TrainingAverager
from hivemind.dht import DHT
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.optim.base import DecentralizedOptimizerBase
from hivemind.optim.performance_ema import PerformanceEMA
from hivemind.utils import Endpoint, get_dht_time, get_logger

logger = get_logger(__name__)
LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)

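
# CollaborationState is this peer's locally cached snapshot of collaboration-wide training progress.
# It is rebuilt from DHT metadata by CollaborativeOptimizer.fetch_collaboration_state() and refreshed by a
# background thread; ready_for_step fires once the tracked sample count reaches target_batch_size or the
# previously estimated time of the next step has passed.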
@dataclass(frozen=False)
class CollaborationState:
    optimizer_step: int
    samples_accumulated: int
    target_batch_size: int
    num_peers: int
    num_clients: int
    eta_next_step: float
    next_fetch_time: float

    @property
    def ready_for_step(self):
        return self.samples_accumulated >= self.target_batch_size or get_dht_time() >= self.eta_next_step

    def register_step(self, local_step: int):
        self.optimizer_step = max(local_step, self.optimizer_step)
        self.samples_accumulated = 0
        self.eta_next_step = float('inf')


class TrainingState(BaseModel):
    endpoint: Endpoint
    step: conint(ge=0, strict=True)
    samples_accumulated: conint(ge=0, strict=True)
    samples_per_second: confloat(ge=0.0, strict=True)
    time: StrictFloat
    client_mode: StrictBool


class TrainingProgressSchema(BaseModel):
    progress: Dict[BytesWithPublicKey, Optional[TrainingState]]


class CollaborativeOptimizer(DecentralizedOptimizerBase):
    """
    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers.

    These optimizers use the DHT to track how much progress the collaboration has made towards the target batch size.
    Once enough samples are accumulated, the optimizers compute a weighted average of their statistics.

    :note: This optimizer behaves unlike regular pytorch optimizers in two ways:
      - calling .step will periodically zero-out gradients w.r.t. model parameters after each global step
      - it may take multiple .step calls without updating model parameters, waiting for peers to accumulate enough samples

    :param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
    :param dht: a running hivemind.DHT daemon connected to other peers
    :param prefix: a common prefix for all metadata stored by CollaborativeOptimizer in the DHT
    :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
    :param batch_size_per_step: before each call to .step, the user should accumulate gradients over this many samples
    :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
    :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
    :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
    :param expected_drift_peers: assume that this many new peers can join between steps
    :param expected_drift_rate: assume that this fraction of the current collaboration can join/leave between steps
    :note: the expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
      refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
    :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
    :param step_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
    :param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
    :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
    :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto the DHT for this many seconds
    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled
    :param scheduler: if specified, use this scheduler to update the optimizer learning rate
    :param reuse_grad_buffers: if True, use the model's .grad buffers for gradient accumulation.
      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
    :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
    :param kwargs: additional parameters forwarded to DecentralizedAverager
    :note: if you are using CollaborativeOptimizer with an lr_scheduler, it is recommended to pass this scheduler
      explicitly into this class. Otherwise, the scheduler may not be synchronized between peers.
    """

    def __init__(self, opt: torch.optim.Optimizer, *, dht: DHT, prefix: str, target_batch_size: int,
                 batch_size_per_step: Optional[int] = None, scheduler: Optional[LRSchedulerBase] = None,
                 min_refresh_period: float = 0.5, max_refresh_period: float = 30, default_refresh_period: float = 3,
                 expected_drift_peers: float = 3, expected_drift_rate: float = 0.2, performance_ema_alpha: float = 0.1,
                 metadata_expiration: float = 60.0, averaging_timeout: Optional[float] = None, step_tolerance: int = 1,
                 reuse_grad_buffers: bool = False, accumulate_grads_on: Optional[torch.device] = None,
                 client_mode: bool = False, verbose: bool = False, **kwargs):
        super().__init__(opt, dht)

        signature_validator = RSASignatureValidator()
        self._local_public_key = signature_validator.local_public_key
        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix),
                            signature_validator])

        if reuse_grad_buffers and accumulate_grads_on is not None:
            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
        self.prefix, self.scheduler = prefix, scheduler
        self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
        self.min_refresh_period, self.max_refresh_period, self.default_refresh_period = \
            min_refresh_period, max_refresh_period, default_refresh_period
        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
        self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
        self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
        self.client_mode, self.step_tolerance = client_mode, step_tolerance
        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
        self.averager = self._make_averager(**kwargs)

        self.training_progress_key = f"{self.prefix}_progress"
        self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
        self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
        self.last_step_time = None

        self.collaboration_state = self.fetch_collaboration_state()
        self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
        self.lock_local_progress, self.should_report_progress = Lock(), Event()
        self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
        self.progress_reporter.start()
        self.collaboration_state_updater = Thread(target=self.check_collaboration_state_periodically, daemon=True,
                                                  name=f"{self}.collaboration_state_updater")
        self.collaboration_state_updater.start()

    def _make_averager(self, **kwargs):
        return TrainingAverager(self.opt, dht=self.dht, average_parameters=True, average_gradients=True,
                                prefix=f"{self.prefix}_averaging", allreduce_timeout=self.averaging_timeout,
                                listen=not self.client_mode, **kwargs)

    @property
    def local_step(self) -> int:
        return self.averager.local_step

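    # A peer counts as synchronized if its local step is no more than step_tolerance steps behind the
    # collaboration-wide optimizer step; step() sends out-of-sync peers through load_state_from_peers()
    # instead of applying a local update.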
    @property
    def is_synchronized(self) -> bool:
        return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance

    def is_alive(self) -> bool:
        return self.averager.is_alive()

    def load_state_from_peers(self, **kwargs):
        """ Attempt to fetch the newest collaboration state from other peers """
        with self.lock_collaboration_state:
            self.averager.load_state_from_peers(**kwargs)
            self.local_samples_accumulated = self.local_steps_accumulated = 0
            self.reset_accumulated_grads_()
            self.update_scheduler()

    def step(self, batch_size: Optional[int] = None, **kwargs):
        """
        Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters

        :param batch_size: optional override for batch_size_per_step from init
        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
        """
        if self.batch_size_per_step is None:
            if batch_size is None:
                raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
            logger.log(self.status_loglevel, f"Setting default batch_size_per_step to {batch_size}")
            self.batch_size_per_step = batch_size
        batch_size = batch_size if batch_size is not None else self.batch_size_per_step

        if not self.is_synchronized:
            logger.log(self.status_loglevel, "Peer is out of sync.")
            self.load_state_from_peers()
            return

        if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
            logger.warning(f"Training step took {get_dht_time() - self.last_step_time:.2f} s, "
                           f"but metadata expires after {self.metadata_expiration} s.")

        self.accumulate_grads_(batch_size)

        with self.lock_local_progress:
            self.local_samples_accumulated += batch_size
            self.local_steps_accumulated += 1
            self.performance_ema.update(num_processed=batch_size)
            self.should_report_progress.set()

        if not self.collaboration_state.ready_for_step:
            return

        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
        self.collaboration_state = self.fetch_collaboration_state()
        self.collaboration_state_updated.set()

        if not self.is_synchronized:
            self.load_state_from_peers()
            return

        with self.performance_ema.pause(), self.lock_collaboration_state:
            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
            self.apply_accumulated_grads_(scale_by=1. / self.local_steps_accumulated)
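            # The averaging weight below makes each peer's contribution proportional to the number of samples it
            # actually processed this round: a peer that accumulated exactly its "fair share"
            # (target_batch_size / num_peers samples) gets weight 1.0. For illustration (numbers are hypothetical):
            # with target_batch_size=4096 and 8 peers, the mean share is 512 samples, so a peer that accumulated
            # 1024 samples averages with weight 2.0.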
            current_step, group_info = self.averager.local_step, None

            if self.collaboration_state.num_peers > 1:
                mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                weight = self.local_samples_accumulated / mean_samples_per_worker
                try:
                    group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
                    if group_info:
                        logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
                except BaseException as e:
                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
            else:
                logger.log(self.status_loglevel, f"Skipped averaging: collaboration consists of "
                                                 f"{self.collaboration_state.num_peers} peer(s).")

            self.opt.step()
            self.reset_accumulated_grads_()
            self.local_samples_accumulated = self.local_steps_accumulated = 0
            self.collaboration_state.register_step(current_step + 1)
            self.averager.local_step = current_step + 1
            self.collaboration_state_updated.set()
            self.update_scheduler()

        logger.log(self.status_loglevel, "Optimizer step: done!")
        return group_info

    def step_aux(self, **kwargs):
        """
        Find and assist other peers in averaging without sending local gradients.

        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
        """
        if not self.collaboration_state.ready_for_step:
            return

        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
        self.collaboration_state = self.fetch_collaboration_state()
        self.collaboration_state_updated.set()

        with self.lock_collaboration_state:
            current_step, group_info = self.averager.local_step, None
            try:
                group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
                if group_info:
                    logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
            except BaseException as e:
                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")

            self.collaboration_state.register_step(current_step + 1)
            self.averager.local_step = current_step + 1
            self.collaboration_state_updated.set()

        logger.log(self.status_loglevel, "Optimizer step: done!")
        return group_info

    def _grad_buffers(self) -> Iterator[torch.Tensor]:
        """ pytorch-internal gradient buffers """
        for param_group in self.opt.param_groups:
            for param in param_group['params']:
                if param.grad is None:
                    yield torch.zeros_like(param)
                else:
                    yield param.grad

    @torch.no_grad()
    def accumulated_grads(self) -> Iterator[torch.Tensor]:
        """ local gradient accumulators """
        if self.reuse_grad_buffers:
            # with reuse_grad_buffers, the .grad buffers themselves serve as the accumulators
            yield from self._grad_buffers()
            return
        if self._grads is None:
            self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
        yield from self._grads

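    # accumulate_grads_ scales each micro-batch by alpha = batch_size / batch_size_per_step, so partial batches
    # contribute proportionally to the accumulators; apply_accumulated_grads_(scale_by=1 / local_steps_accumulated)
    # then recovers the average gradient before opt.step() (see the comment in step()). A hypothetical example:
    # with batch_size_per_step=32, two calls with 32 samples and one with 16 accumulate with alphas 1.0, 1.0 and 0.5.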
    @torch.no_grad()
    def accumulate_grads_(self, batch_size: int):
        """ add current gradients to grad accumulators (if any) """
        if self.reuse_grad_buffers:
            return  # user is responsible for accumulating gradients in .grad buffers
        alpha = float(batch_size) / self.batch_size_per_step
        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
            grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)

    @torch.no_grad()
    def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
        if self.reuse_grad_buffers:
            return
        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
            grad_buf[...] = grad_acc.to(grad_buf.device)
            if scale_by is not None:
                grad_buf.mul_(scale_by)

    @torch.no_grad()
    def reset_accumulated_grads_(self):
        if self.reuse_grad_buffers:
            self.opt.zero_grad()
        else:
            for grad_buf in self.accumulated_grads():
                grad_buf.zero_()

    def report_training_progress(self):
        """ Periodically publish metadata and the current number of samples accumulated towards the next step """
        while self.is_alive():
            self.should_report_progress.wait()
            self.should_report_progress.clear()
            with self.lock_local_progress:
                current_time = get_dht_time()
                local_state_info = TrainingState(
                    endpoint=self.averager.endpoint,
                    step=self.local_step,
                    samples_accumulated=self.local_samples_accumulated,
                    samples_per_second=self.performance_ema.samples_per_second,
                    time=current_time,
                    client_mode=not self.averager.listen)

            self.dht.store(key=self.training_progress_key, subkey=self._local_public_key,
                           value=local_state_info.dict(),
                           expiration_time=current_time + self.metadata_expiration,
                           return_future=True)

    def check_collaboration_state_periodically(self):
        """
        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
        """
        while self.is_alive():
            time_to_next_update = max(0.0, self.collaboration_state.next_fetch_time - get_dht_time())
            if self.collaboration_state_updated.wait(time_to_next_update):
                self.collaboration_state_updated.clear()
                continue  # if state was updated externally, reset the timer

            with self.lock_collaboration_state:
                self.collaboration_state = self.fetch_collaboration_state()

    def fetch_collaboration_state(self) -> CollaborationState:
        """ Read performance statistics reported by peers, estimate progress towards next batch """
        response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float('inf'))
        current_time = get_dht_time()

        if not isinstance(response, dict) or len(response) == 0:
            logger.log(self.status_loglevel, f"Found no active peers: {response}")
            local_eta_next_step = max(0, self.target_batch_size - self.local_samples_accumulated
                                      ) / self.performance_ema.samples_per_second
            return CollaborationState(self.local_step, self.local_samples_accumulated, self.target_batch_size,
                                      num_peers=0, num_clients=0, eta_next_step=current_time + local_eta_next_step,
                                      next_fetch_time=current_time + self.default_refresh_period)

        valid_peer_states = [TrainingState.parse_obj(peer_state.value)
                             for peer_state in response.values()
                             if peer_state.value is not None]

        num_peers = len(valid_peer_states)
        num_clients = sum(state.client_mode for state in valid_peer_states)
        global_optimizer_step = self.local_step
        for state in valid_peer_states:
            if not state.client_mode:
                global_optimizer_step = max(global_optimizer_step, state.step)

        total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0

        for state in valid_peer_states:
            total_samples_per_second += state.samples_per_second
            if state.step == global_optimizer_step:
                total_samples_accumulated += state.samples_accumulated
                estimated_current_samples += (state.samples_accumulated +
                                              max(0, current_time - state.time) * state.samples_per_second)
            # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
            # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.
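
        # The refresh interval shrinks the ETA by num_peers / expected_max_peers so that likely newcomers do not
        # make the collaboration overshoot target_batch_size between fetches. A hypothetical example: with 10 peers,
        # expected_drift_peers=3 and expected_drift_rate=0.2, expected_max_peers = max(13, 12) = 13, so an ETA of
        # 26 seconds is refreshed after ~20 seconds (clipped to [min_refresh_period, max_refresh_period]).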
        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
        estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second

        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
        time_to_next_fetch = float(np.clip(a=estimated_time_to_next_step * num_peers / expected_max_peers,
                                           a_min=self.min_refresh_period, a_max=self.max_refresh_period))
        logger.log(self.status_loglevel, f"Collaboration accumulated {total_samples_accumulated} samples from "
                                         f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
                                         f"(refresh in {time_to_next_fetch:.2f}s)")
        return CollaborationState(
            global_optimizer_step, total_samples_accumulated, target_batch_size=self.target_batch_size,
            num_peers=num_peers, num_clients=num_clients, eta_next_step=current_time + estimated_time_to_next_step,
            next_fetch_time=current_time + time_to_next_fetch)

    def zero_grad(self, *args, **kwargs):
        if self.reuse_grad_buffers:
            raise ValueError(f"When running {self.__class__.__name__} with reuse_grad_buffers=True, the user should "
                             f"never call zero_grad manually. Gradients will be refreshed internally.")
        return self.opt.zero_grad(*args, **kwargs)

    def update_scheduler(self):
        if self.scheduler:
            while self.scheduler._step_count < self.local_step:
                self.scheduler.step()

    def shutdown(self):
        logger.debug("Shutting down averager...")
        self.averager.shutdown()
        logger.debug("Sending goodbye to peers...")
        self.dht.store(self.training_progress_key, subkey=self._local_public_key, value=None,
                       expiration_time=get_dht_time() + self.metadata_expiration)
        logger.debug(f"{self.__class__.__name__} is shut down.")

    def __del__(self):
        self.shutdown()