# collaborative.py

from __future__ import annotations

import logging
from dataclasses import dataclass
from threading import Thread, Lock, Event
from typing import Dict, Optional, Iterator

import numpy as np
import torch
from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint

from hivemind.averaging.training import TrainingAverager
from hivemind.dht import DHT
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.optim.base import DecentralizedOptimizerBase
from hivemind.optim.performance_ema import PerformanceEMA
from hivemind.utils import Endpoint, get_dht_time, get_logger

logger = get_logger(__name__)
LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)


@dataclass(frozen=False)
class CollaborationState:
    optimizer_step: int
    samples_accumulated: int
    target_batch_size: int
    num_peers: int
    num_clients: int
    eta_next_step: float
    next_fetch_time: float

    @property
    def ready_for_step(self):
        return self.samples_accumulated >= self.target_batch_size or get_dht_time() >= self.eta_next_step

    def register_step(self, local_step: int):
        self.optimizer_step = max(local_step, self.optimizer_step)
        self.samples_accumulated = 0
        self.eta_next_step = float("inf")


class TrainingState(BaseModel):
    peer_id: bytes
    step: conint(ge=0, strict=True)
    samples_accumulated: conint(ge=0, strict=True)
    samples_per_second: confloat(ge=0.0, strict=True)
    time: StrictFloat
    client_mode: StrictBool


class TrainingProgressSchema(BaseModel):
    progress: Dict[BytesWithPublicKey, Optional[TrainingState]]


class CollaborativeOptimizer(DecentralizedOptimizerBase):
    """
    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers.

    These optimizers use the DHT to track how much progress the collaboration has made towards the target batch size.
    Once enough samples are accumulated, the optimizers compute a weighted average of their statistics.

    :note: This optimizer behaves unlike regular pytorch optimizers in two ways:

      * calling .step will periodically zero-out gradients w.r.t. model parameters after each step
      * it may take multiple .step calls without updating model parameters, waiting for peers to accumulate enough samples

    :param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
    :param dht: a running hivemind.DHT daemon connected to other peers
    :param prefix: a common prefix for all metadata stored by CollaborativeOptimizer in the DHT
    :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
    :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
    :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
    :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
    :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
    :param expected_drift_peers: assume that this many new peers can join between steps
    :param expected_drift_rate: assume that this fraction of the current collaboration can join/leave between steps
    :note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
      refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
    :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
    :param step_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
    :param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
    :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
    :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
    :param scheduler: if specified, use this scheduler to update the optimizer learning rate
    :param reuse_grad_buffers: if True, use the model's .grad buffers for gradient accumulation.
      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
    :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
    :param kwargs: additional parameters forwarded to DecentralizedAverager
    :note: If you are using CollaborativeOptimizer with an lr_scheduler, it is recommended to pass this scheduler
      explicitly into this class. Otherwise, the scheduler may not be synchronized between peers.
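
    A minimal usage sketch (assumptions: ``model``, ``compute_loss`` and ``data_loader`` are placeholders for the
    user's own training code, and ``initial_peers`` would hold the multiaddrs of existing peers; none of these names
    come from this module)::

        dht = DHT(initial_peers=[...], start=True)
        opt = CollaborativeOptimizer(
            opt=torch.optim.Adam(model.parameters()),
            dht=dht,
            prefix="my_experiment",  # shared by every peer of this run
            target_batch_size=4096,  # global number of samples that triggers an optimizer step
            batch_size_per_step=32,  # samples processed locally per call to .step
        )
        for batch in data_loader:
            compute_loss(model, batch).backward()
            opt.step()       # may or may not run a global update, depending on collaboration progress
            opt.zero_grad()  # allowed here because reuse_grad_buffers is False (the default)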
    """

    def __init__(
        self,
        opt: torch.optim.Optimizer,
        *,
        dht: DHT,
        prefix: str,
        target_batch_size: int,
        batch_size_per_step: Optional[int] = None,
        scheduler: Optional[LRSchedulerBase] = None,
        min_refresh_period: float = 0.5,
        max_refresh_period: float = 30,
        default_refresh_period: float = 3,
        expected_drift_peers: float = 3,
        expected_drift_rate: float = 0.2,
        performance_ema_alpha: float = 0.1,
        metadata_expiration: float = 60.0,
        averaging_timeout: Optional[float] = None,
        step_tolerance: int = 1,
        reuse_grad_buffers: bool = False,
        accumulate_grads_on: Optional[torch.device] = None,
        client_mode: bool = False,
        verbose: bool = False,
        **kwargs,
    ):
        super().__init__(opt, dht)

        signature_validator = RSASignatureValidator()
        self._local_public_key = signature_validator.local_public_key
        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])

        if reuse_grad_buffers and accumulate_grads_on is not None:
            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
        self.prefix, self.scheduler = prefix, scheduler
        self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
        self.min_refresh_period, self.max_refresh_period, self.default_refresh_period = (
            min_refresh_period,
            max_refresh_period,
            default_refresh_period,
        )
        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
        self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
        self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
        self.client_mode, self.step_tolerance = client_mode, step_tolerance
        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
        self.averager = self._make_averager(**kwargs)

        self.training_progress_key = f"{self.prefix}_progress"
        self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
        self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
        self.last_step_time = None

        self.collaboration_state = self.fetch_collaboration_state()
        self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
        self.lock_local_progress, self.should_report_progress = Lock(), Event()
        self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
        self.progress_reporter.start()
        self.collaboration_state_updater = Thread(
            target=self.check_collaboration_state_periodically, daemon=True, name=f"{self}.collaboration_state_updater"
        )
        self.collaboration_state_updater.start()

    def _make_averager(self, **kwargs):
        return TrainingAverager(
            self.opt,
            dht=self.dht,
            average_parameters=True,
            average_gradients=True,
            prefix=f"{self.prefix}_averaging",
            allreduce_timeout=self.averaging_timeout,
            client_mode=self.client_mode,
            **kwargs,
        )

    @property
    def local_step(self) -> int:
        return self.averager.local_step

    @property
    def is_synchronized(self) -> bool:
        return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance

    def is_alive(self) -> bool:
        return self.averager.is_alive()

    def load_state_from_peers(self, **kwargs):
        """Attempt to fetch the newest collaboration state from other peers"""
        with self.lock_collaboration_state:
            self.averager.load_state_from_peers(**kwargs)
            self.local_samples_accumulated = self.local_steps_accumulated = 0
            self.reset_accumulated_grads_()
            self.update_scheduler()

    def step(self, batch_size: Optional[int] = None, **kwargs):
        """
        Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters

        :param batch_size: optional override for batch_size_per_step from init
        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
        """
        if self.batch_size_per_step is None:
            if batch_size is None:
                raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
            logger.log(self.status_loglevel, f"Setting default batch_size_per_step to {batch_size}")
            self.batch_size_per_step = batch_size
        batch_size = batch_size if batch_size is not None else self.batch_size_per_step

        if not self.is_synchronized:
            logger.log(self.status_loglevel, "Peer is out of sync.")
            self.load_state_from_peers()
            return

        if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
            logger.warning(
                f"Training step took {get_dht_time() - self.last_step_time}, "
                f"but metadata expired in {self.metadata_expiration} s."
            )

        self.accumulate_grads_(batch_size)

        with self.lock_local_progress:
            self.local_samples_accumulated += batch_size
            self.local_steps_accumulated += 1
            self.performance_ema.update(num_processed=batch_size)
            self.should_report_progress.set()

        if not self.collaboration_state.ready_for_step:
            return

        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
        self.collaboration_state = self.fetch_collaboration_state()
        self.collaboration_state_updated.set()

        if not self.is_synchronized:
            self.load_state_from_peers()
            return

        with self.performance_ema.pause(), self.lock_collaboration_state:
            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
            self.apply_accumulated_grads_(scale_by=1.0 / self.local_steps_accumulated)
            current_step, group_info = self.averager.local_step, None

            if self.collaboration_state.num_peers > 1:
                mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                weight = self.local_samples_accumulated / mean_samples_per_worker
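                # illustrative example (hypothetical numbers): with target_batch_size=4096 and 8 peers,
                # mean_samples_per_worker is 512, so a peer that contributed 1024 samples averages with weight 2.0,
                # while a peer that contributed 256 samples averages with weight 0.5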
                try:
                    group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
                    if group_info:
                        logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
                except BaseException as e:
                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
            else:
                logger.log(
                    self.status_loglevel,
                    f"Skipped averaging: collaboration consists of {self.collaboration_state.num_peers} peer(s).",
                )

            self.opt.step()
            self.reset_accumulated_grads_()
            self.local_samples_accumulated = self.local_steps_accumulated = 0
            self.collaboration_state.register_step(current_step + 1)
            self.averager.local_step = current_step + 1
            self.collaboration_state_updated.set()
            self.update_scheduler()

        logger.log(self.status_loglevel, "Optimizer step: done!")
        return group_info

    def step_aux(self, **kwargs):
        """
        Find and assist other peers in averaging without sending local gradients.

        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
        """
        if not self.collaboration_state.ready_for_step:
            return

        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
        self.collaboration_state = self.fetch_collaboration_state()
        self.collaboration_state_updated.set()

        with self.lock_collaboration_state:
            # note: this peer does not contribute gradients of its own; it only joins the averaging round
            # to assist other peers, hence no accumulator scaling is performed here (unlike in .step)
            current_step, group_info = self.averager.local_step, None
            try:
                group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
                if group_info:
                    logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
            except BaseException as e:
                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")

            self.collaboration_state.register_step(current_step + 1)
            self.averager.local_step = current_step + 1
            self.collaboration_state_updated.set()

        logger.log(self.status_loglevel, "Optimizer step: done!")
        return group_info

    def _grad_buffers(self) -> Iterator[torch.Tensor]:
        """pytorch-internal gradient buffers"""
        for param_group in self.opt.param_groups:
            for param in param_group["params"]:
                if param.grad is None:
                    yield torch.zeros_like(param)
                else:
                    yield param.grad

    @torch.no_grad()
    def accumulated_grads(self) -> Iterator[torch.Tensor]:
        """local gradient accumulators"""
        if self.reuse_grad_buffers:
            yield from self._grad_buffers()
        elif self._grads is None:
            with torch.no_grad():
                self._grads = [
                    torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()
                ]
        yield from self._grads

    @torch.no_grad()
    def accumulate_grads_(self, batch_size: int):
        """add current gradients to grad accumulators (if any)"""
        if self.reuse_grad_buffers:
            return  # user is responsible for accumulating gradients in .grad buffers
        alpha = float(batch_size) / self.batch_size_per_step
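        # illustrative example (hypothetical numbers): with batch_size_per_step=32, a partial batch of 16 samples
        # is added with alpha=0.5, so every sample contributes equally regardless of the actual batch size
        # (assuming .grad holds a per-batch mean gradient)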
        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
            grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)

    @torch.no_grad()
    def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
        if self.reuse_grad_buffers:
            return
        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
            grad_buf[...] = grad_acc.to(grad_buf.device)
            if scale_by is not None:
                grad_buf.mul_(scale_by)

    @torch.no_grad()
    def reset_accumulated_grads_(self):
        if self.reuse_grad_buffers:
            self.opt.zero_grad()
        else:
            for grad_buf in self.accumulated_grads():
                grad_buf.zero_()

    def report_training_progress(self):
        """Periodically publish metadata and the current number of samples accumulated towards the next step"""
        while self.is_alive():
            self.should_report_progress.wait()
            self.should_report_progress.clear()
            with self.lock_local_progress:
                current_time = get_dht_time()
                local_state_info = TrainingState(
                    peer_id=self.averager.peer_id.to_bytes(),
                    step=self.local_step,
                    samples_accumulated=self.local_samples_accumulated,
                    samples_per_second=self.performance_ema.samples_per_second,
                    time=current_time,
                    client_mode=self.averager.client_mode,
                )

            self.dht.store(
                key=self.training_progress_key,
                subkey=self._local_public_key,
                value=local_state_info.dict(),
                expiration_time=current_time + self.metadata_expiration,
                return_future=True,
            )

    def check_collaboration_state_periodically(self):
        """
        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
        """
        while self.is_alive():
            time_to_next_update = max(0.0, self.collaboration_state.next_fetch_time - get_dht_time())
            if self.collaboration_state_updated.wait(time_to_next_update):
                self.collaboration_state_updated.clear()
                continue  # if state was updated externally, reset timer

            with self.lock_collaboration_state:
                self.collaboration_state = self.fetch_collaboration_state()

    def fetch_collaboration_state(self) -> CollaborationState:
        """Read performance statistics reported by peers, estimate progress towards next batch"""
        response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
        current_time = get_dht_time()

        if not isinstance(response, dict) or len(response) == 0:
            logger.log(self.status_loglevel, f"Found no active peers: {response}")
            # ETA is measured in samples (not steps): estimate how long this peer alone would need to reach the target
            local_eta_next_step = (
                max(0, self.target_batch_size - self.local_samples_accumulated) / self.performance_ema.samples_per_second
            )
            return CollaborationState(
                self.local_step,
                self.local_samples_accumulated,
                self.target_batch_size,
                num_peers=0,
                num_clients=0,
                eta_next_step=current_time + local_eta_next_step,
                next_fetch_time=current_time + self.default_refresh_period,
            )

        valid_peer_states = [
            TrainingState.parse_obj(peer_state.value)
            for peer_state in response.values()
            if peer_state.value is not None
        ]

        num_peers = len(valid_peer_states)
        num_clients = sum(state.client_mode for state in valid_peer_states)
        global_optimizer_step = self.local_step
        for state in valid_peer_states:
            if not state.client_mode:
                global_optimizer_step = max(global_optimizer_step, state.step)

        total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0

        for state in valid_peer_states:
            total_samples_per_second += state.samples_per_second
            if state.step == global_optimizer_step:
                total_samples_accumulated += state.samples_accumulated
                estimated_current_samples += (
                    state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
                )
            # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
            # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.

        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
        estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second

        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
        time_to_next_fetch = float(
            np.clip(
                a=estimated_time_to_next_step * num_peers / expected_max_peers,
                a_min=self.min_refresh_period,
                a_max=self.max_refresh_period,
            )
        )
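        # illustrative example (hypothetical numbers): with 4 peers, expected_drift_peers=3 and expected_drift_rate=0.2,
        # expected_max_peers = max(4 + 3, 4 * 1.2) = 7, so the next fetch happens after roughly 4/7 of the estimated
        # time to the next step, clipped to [min_refresh_period, max_refresh_period]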
        logger.log(
            self.status_loglevel,
            f"Collaboration accumulated {total_samples_accumulated} samples from "
            f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
            f"(refresh in {time_to_next_fetch:.2f}s.)",
        )
        return CollaborationState(
            global_optimizer_step,
            total_samples_accumulated,
            target_batch_size=self.target_batch_size,
            num_peers=num_peers,
            num_clients=num_clients,
            eta_next_step=current_time + estimated_time_to_next_step,
            next_fetch_time=current_time + time_to_next_fetch,
        )

    def zero_grad(self, *args, **kwargs):
        if self.reuse_grad_buffers:
            raise ValueError(
                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
                "call zero_grad manually. Gradients will be refreshed internally."
            )
        return self.opt.zero_grad(*args, **kwargs)

    def update_scheduler(self):
        if self.scheduler:
            while self.scheduler._step_count < self.local_step:
                self.scheduler.step()

    def shutdown(self):
        logger.debug("Shutting down averager...")
        self.averager.shutdown()
        logger.debug("Sending goodbye to peers...")
        self.dht.store(
            self.training_progress_key,
            subkey=self._local_public_key,
            value=None,
            expiration_time=get_dht_time() + self.metadata_expiration,
        )
        logger.debug(f"{self.__class__.__name__} is shut down.")

    def __del__(self):
        self.shutdown()