# collaborative.py

from __future__ import annotations

import logging
from dataclasses import dataclass
from threading import Event, Lock, Thread
from typing import Dict, Iterator, Optional

import numpy as np
import torch
from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint

from hivemind.averaging.training import TrainingAverager
from hivemind.dht import DHT
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.optim.base import DecentralizedOptimizerBase
from hivemind.optim.performance_ema import PerformanceEMA
from hivemind.utils import Endpoint, get_dht_time, get_logger

logger = get_logger(__name__)
LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)

@dataclass(frozen=False)
class CollaborationState:
    optimizer_step: int
    samples_accumulated: int
    target_batch_size: int
    num_peers: int
    num_clients: int
    eta_next_step: float
    next_fetch_time: float

    @property
    def ready_for_step(self):
        return self.samples_accumulated >= self.target_batch_size or get_dht_time() >= self.eta_next_step

    def register_step(self, local_step: int):
        self.optimizer_step = max(local_step, self.optimizer_step)
        self.samples_accumulated = 0
        self.eta_next_step = float("inf")

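# Illustrative example of `ready_for_step` (hypothetical numbers, not part of the original module):
# with target_batch_size=4096, a state with samples_accumulated=4100 is ready immediately, while
# samples_accumulated=1000 stays not-ready until get_dht_time() reaches eta_next_step.
#
#     state = CollaborationState(optimizer_step=10, samples_accumulated=4100, target_batch_size=4096,
#                                num_peers=8, num_clients=2, eta_next_step=get_dht_time() + 30,
#                                next_fetch_time=get_dht_time() + 5)
#     assert state.ready_for_step
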
class TrainingState(BaseModel):
    peer_id: bytes
    step: conint(ge=0, strict=True)
    samples_accumulated: conint(ge=0, strict=True)
    samples_per_second: confloat(ge=0.0, strict=True)
    time: StrictFloat
    client_mode: StrictBool


class TrainingProgressSchema(BaseModel):
    progress: Dict[BytesWithPublicKey, Optional[TrainingState]]

class CollaborativeOptimizer(DecentralizedOptimizerBase):
    """
    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers.

    These optimizers use the DHT to track how much progress the collaboration has made towards the target batch size.
    Once enough samples are accumulated, the optimizers compute a weighted average of their statistics.

    :note: This optimizer behaves unlike regular pytorch optimizers in two ways:

      * calling .step will periodically zero-out gradients w.r.t. model parameters after each step
      * it may take multiple .step calls without updating model parameters, waiting for peers to accumulate enough samples

    :param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
    :param dht: a running hivemind.DHT daemon connected to other peers
    :param prefix: a common prefix for all metadata stored by CollaborativeOptimizer in the DHT
    :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
    :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
    :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
    :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
    :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
    :param expected_drift_peers: assume that this many new peers can join between steps
    :param expected_drift_rate: assume that this fraction of the current collaboration can join/leave between steps
    :note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer
      refreshes the collaboration-wide statistics (to avoid missing the moment when the next step should run)
    :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
    :param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
    :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
    :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled
    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
    :param scheduler: if specified, use this scheduler to update optimizer learning rate
    :param reuse_grad_buffers: if True, use the model's .grad buffers for gradient accumulation.
      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
    :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
    :param kwargs: additional parameters forwarded to DecentralizedAverager
    :note: If you are using CollaborativeOptimizer with an lr_scheduler, it is recommended to pass this scheduler
      explicitly into this class. Otherwise, the scheduler may not be synchronized between peers.
    """

    def __init__(
        self,
        opt: torch.optim.Optimizer,
        *,
        dht: DHT,
        prefix: str,
        target_batch_size: int,
        batch_size_per_step: Optional[int] = None,
        scheduler: Optional[LRSchedulerBase] = None,
        min_refresh_period: float = 0.5,
        max_refresh_period: float = 30,
        default_refresh_period: float = 3,
        expected_drift_peers: float = 3,
        expected_drift_rate: float = 0.2,
        performance_ema_alpha: float = 0.1,
        metadata_expiration: float = 60.0,
        averaging_timeout: Optional[float] = None,
        load_state_timeout: float = 600.0,
        reuse_grad_buffers: bool = False,
        accumulate_grads_on: Optional[torch.device] = None,
        client_mode: bool = False,
        verbose: bool = False,
        **kwargs,
    ):
        super().__init__(opt, dht)

        signature_validator = RSASignatureValidator()
        self._local_public_key = signature_validator.local_public_key
        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])

        if reuse_grad_buffers and accumulate_grads_on is not None:
            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
        self.prefix, self.scheduler = prefix, scheduler
        self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
        self.min_refresh_period, self.max_refresh_period, self.default_refresh_period = (
            min_refresh_period,
            max_refresh_period,
            default_refresh_period,
        )
        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
        self.averaging_timeout = averaging_timeout
        self.load_state_timeout = load_state_timeout
        self.metadata_expiration = metadata_expiration
        self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
        self.client_mode = client_mode
        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
        self.averager = self._make_averager(**kwargs)

        self.training_progress_key = f"{self.prefix}_progress"
        self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
        self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
        self.last_step_time = None

        self.collaboration_state = self._fetch_state()
        self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
        self.lock_local_progress, self.should_report_progress = Lock(), Event()
        self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
        self.progress_reporter.start()
        self.collaboration_state_updater = Thread(
            target=self.check_collaboration_state_periodically, daemon=True, name=f"{self}.collaboration_state_updater"
        )
        self.collaboration_state_updater.start()

    def _make_averager(self, **kwargs):
        return TrainingAverager(
            self.opt,
            dht=self.dht,
            average_parameters=True,
            average_gradients=True,
            prefix=f"{self.prefix}_averaging",
            allreduce_timeout=self.averaging_timeout,
            client_mode=self.client_mode,
            **kwargs,
        )

    @property
    def local_step(self) -> int:
        return self.averager.local_step

    @property
    def is_synchronized(self) -> bool:
        return self.local_step >= self.collaboration_state.optimizer_step

    def is_alive(self) -> bool:
        return self.averager.is_alive()

    def load_state_from_peers(self, **kwargs):
        """Attempt to fetch the newest collaboration state from other peers"""
        with self.lock_collaboration_state:
            while True:
                try:
                    self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                    break
                except BaseException as e:
                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                    continue

            self.local_samples_accumulated = self.local_steps_accumulated = 0
            self.reset_accumulated_grads_()
            self.update_scheduler()

    def step(self, batch_size: Optional[int] = None, **kwargs):
        """
        Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters

        :param batch_size: optional override for batch_size_per_step from init
        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
        """
        if self.batch_size_per_step is None:
            if batch_size is None:
                raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
            logger.log(self.status_loglevel, f"Setting default batch_size_per_step to {batch_size}")
            self.batch_size_per_step = batch_size
        batch_size = batch_size if batch_size is not None else self.batch_size_per_step

        if not self.is_synchronized:
            logger.log(self.status_loglevel, "Peer is out of sync.")
            self.load_state_from_peers()
            return

        if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
            logger.warning(
                f"Training step took {get_dht_time() - self.last_step_time}, "
                f"but metadata expired in {self.metadata_expiration} s."
            )

        self.accumulate_grads_(batch_size)

        with self.lock_local_progress:
            self.local_samples_accumulated += batch_size
            self.local_steps_accumulated += 1
            self.performance_ema.update(num_processed=batch_size)
            self.should_report_progress.set()

        if not self.collaboration_state.ready_for_step:
            return

        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
        self.collaboration_state = self._fetch_state()
        self.collaboration_state_updated.set()

        if not self.is_synchronized:
            self.load_state_from_peers()
            return

        with self.performance_ema.pause(), self.lock_collaboration_state:
            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
            self.apply_accumulated_grads_(scale_by=1.0 / self.local_steps_accumulated)
            current_step, group_info = self.averager.local_step, None

            if self.collaboration_state.num_peers > 1:
                mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                weight = self.local_samples_accumulated / mean_samples_per_worker
                try:
                    group_info = self.averager.step(
                        weight=weight, gather=current_step, timeout=self.averaging_timeout, **kwargs
                    )
                    if group_info:
                        logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
                        # update our current step if we averaged with another peer that was at a more recent step
                        for peer, peer_step in group_info.items():
                            if isinstance(peer_step, int):
                                current_step = max(current_step, peer_step)
                            else:
                                logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")

                except BaseException as e:
                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")

            else:
                logger.log(
                    self.status_loglevel,
                    f"Skipped averaging: collaboration consists of {self.collaboration_state.num_peers} peer(s).",
                )

            self.opt.step()
            self.reset_accumulated_grads_()
            self.local_samples_accumulated = self.local_steps_accumulated = 0
            self.collaboration_state.register_step(current_step + 1)
            self.averager.local_step = current_step + 1
            self.collaboration_state_updated.set()
            self.update_scheduler()

        logger.log(self.status_loglevel, "Optimizer step: done!")
        return group_info

    def step_aux(self, **kwargs):
        """
        Find and assist other peers in averaging without sending local gradients.

        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
        """
        if not self.collaboration_state.ready_for_step:
            return

        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
        self.collaboration_state = self._fetch_state()
        self.collaboration_state_updated.set()

        with self.lock_collaboration_state:
            current_step, group_info = self.averager.local_step, None
            try:
                group_info = self.averager.step(timeout=self.averaging_timeout, gather=current_step, **kwargs)
                if group_info:
                    logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
                    # update our current step if we averaged with another peer that was at a more recent step
                    for peer, peer_step in group_info.items():
                        if isinstance(peer_step, int):
                            current_step = max(current_step, peer_step)
                        else:
                            logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")

            except BaseException as e:
                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")

            self.collaboration_state.register_step(current_step + 1)
            self.averager.local_step = current_step + 1
            self.collaboration_state_updated.set()

        logger.log(self.status_loglevel, "Optimizer step: done!")
        return group_info

    def _grad_buffers(self) -> Iterator[torch.Tensor]:
        """pytorch-internal gradient buffers"""
        for param_group in self.opt.param_groups:
            for param in param_group["params"]:
                if param.grad is None:
                    yield torch.zeros_like(param)
                else:
                    yield param.grad

    @torch.no_grad()
    def accumulated_grads(self) -> Iterator[torch.Tensor]:
        """local gradient accumulators"""
        if self.reuse_grad_buffers:
            yield from self._grad_buffers()
            return
        if self._grads is None:
            with torch.no_grad():
                self._grads = [
                    torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()
                ]
        yield from self._grads

    @torch.no_grad()
    def accumulate_grads_(self, batch_size: int):
        """add current gradients to grad accumulators (if any)"""
        if self.reuse_grad_buffers:
            return  # user is responsible for accumulating gradients in .grad buffers
        alpha = float(batch_size) / self.batch_size_per_step
        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
            grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)

    @torch.no_grad()
    def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
        if self.reuse_grad_buffers:
            return
        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
            grad_buf[...] = grad_acc.to(grad_buf.device)
            if scale_by is not None:
                grad_buf.mul_(scale_by)

    @torch.no_grad()
    def reset_accumulated_grads_(self):
        if self.reuse_grad_buffers:
            self.opt.zero_grad()
        else:
            for grad_buf in self.accumulated_grads():
                grad_buf.zero_()

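    # How the accumulators above are used by .step (summary of the method defined earlier):
    # accumulate_grads_(batch_size) adds the current .grad buffers into the accumulators on every call,
    # apply_accumulated_grads_(scale_by=1.0 / local_steps_accumulated) copies the averaged accumulators
    # back into .grad right before opt.step(), and reset_accumulated_grads_() clears them afterwards.
    # With reuse_grad_buffers=True, .grad itself is the accumulator: the first two become no-ops and
    # the reset falls back to opt.zero_grad().
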
    def report_training_progress(self):
        """Periodically publish metadata and the current number of samples accumulated towards the next step"""
        while self.is_alive():
            self.should_report_progress.wait()
            self.should_report_progress.clear()
            with self.lock_local_progress:
                current_time = get_dht_time()
                local_state_info = TrainingState(
                    peer_id=self.averager.peer_id.to_bytes(),
                    step=self.local_step,
                    samples_accumulated=self.local_samples_accumulated,
                    samples_per_second=self.performance_ema.samples_per_second,
                    time=current_time,
                    client_mode=self.averager.client_mode,
                )

            self.dht.store(
                key=self.training_progress_key,
                subkey=self._local_public_key,
                value=local_state_info.dict(),
                expiration_time=current_time + self.metadata_expiration,
                return_future=True,
            )

    def check_collaboration_state_periodically(self):
        """
        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
        """
        while self.is_alive():
            time_to_next_update = max(0.0, self.collaboration_state.next_fetch_time - get_dht_time())
            if self.collaboration_state_updated.wait(time_to_next_update):
                self.collaboration_state_updated.clear()
                continue  # if state was updated externally, reset timer

            with self.lock_collaboration_state:
                self.collaboration_state = self._fetch_state()

    def _fetch_state(self) -> CollaborationState:
        """Read performance statistics reported by peers, estimate progress towards next batch"""
        response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
        current_time = get_dht_time()

        if not isinstance(response, dict) or len(response) == 0:
            logger.log(self.status_loglevel, f"Found no active peers: {response}")
            # samples (not steps) remaining until the target batch, divided by this peer's throughput
            local_eta_next_step = (
                max(0, self.target_batch_size - self.local_samples_accumulated) / self.performance_ema.samples_per_second
            )
            return CollaborationState(
                self.local_step,
                self.local_samples_accumulated,
                self.target_batch_size,
                num_peers=0,
                num_clients=0,
                eta_next_step=current_time + local_eta_next_step,
                next_fetch_time=current_time + self.default_refresh_period,
            )

        valid_peer_states = [
            TrainingState.parse_obj(peer_state.value)
            for peer_state in response.values()
            if peer_state.value is not None
        ]

        num_peers = len(valid_peer_states)
        num_clients = sum(state.client_mode for state in valid_peer_states)
        global_optimizer_step = self.local_step
        for state in valid_peer_states:
            if not state.client_mode:
                global_optimizer_step = max(global_optimizer_step, state.step)

        total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0

        for state in valid_peer_states:
            total_samples_per_second += state.samples_per_second
            if state.step >= global_optimizer_step:
                total_samples_accumulated += state.samples_accumulated
                estimated_current_samples += (
                    state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
                )
            # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
            # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.

        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
        estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second

        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
        time_to_next_fetch = float(
            np.clip(
                a=estimated_time_to_next_step * num_peers / expected_max_peers,
                a_min=self.min_refresh_period,
                a_max=self.max_refresh_period,
            )
        )
        logger.log(
            self.status_loglevel,
            f"{self.prefix} accumulated {total_samples_accumulated} samples from "
            f"{num_peers} peers for step #{global_optimizer_step}. "
            f"ETA {estimated_time_to_next_step:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
        )
        return CollaborationState(
            global_optimizer_step,
            total_samples_accumulated,
            target_batch_size=self.target_batch_size,
            num_peers=num_peers,
            num_clients=num_clients,
            eta_next_step=current_time + estimated_time_to_next_step,
            next_fetch_time=current_time + time_to_next_fetch,
        )

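    # Worked example of the refresh interval computed above (hypothetical numbers): with num_peers=10,
    # expected_drift_peers=3 and expected_drift_rate=0.2, expected_max_peers = max(10 + 3, 10 * 1.2) = 13.
    # If the estimated time to the next step is 26 s, the state is re-fetched in 26 * 10 / 13 = 20 s,
    # clipped to [min_refresh_period, max_refresh_period].
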
    def zero_grad(self, *args, **kwargs):
        if self.reuse_grad_buffers:
            raise ValueError(
                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
                f"call zero_grad manually. Gradients will be refreshed internally."
            )
        return self.opt.zero_grad(*args, **kwargs)

    def update_scheduler(self):
        if self.scheduler:
            while self.scheduler._step_count < self.local_step:
                self.scheduler.step()

    def shutdown(self):
        logger.debug("Shutting down averager...")
        self.averager.shutdown()
        logger.debug("Sending goodbye to peers...")
        self.dht.store(
            self.training_progress_key,
            subkey=self._local_public_key,
            value=None,
            expiration_time=get_dht_time() + self.metadata_expiration,
        )
        logger.debug(f"{self.__class__.__name__} is shut down.")

    def __del__(self):
        self.shutdown()