# collaborative.py (26 KB)
  1. from __future__ import annotations
  2. import logging
  3. from dataclasses import dataclass
  4. from threading import Event, Lock, Thread
  5. from typing import Dict, Iterator, Optional
  6. import numpy as np
  7. import torch
  8. from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
  9. from hivemind.dht import DHT
  10. from hivemind.dht.crypto import RSASignatureValidator
  11. from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
  12. from hivemind.optim.base import DecentralizedOptimizerBase
  13. from hivemind.optim.grad_scaler import HivemindGradScaler
  14. from hivemind.optim.training_averager import TrainingAverager
  15. from hivemind.utils import get_dht_time, get_logger
  16. from hivemind.utils.performance_ema import PerformanceEMA
  17. logger = get_logger(__name__)
  18. LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
  19. @dataclass(frozen=False)
  20. class CollaborationState:
  21. optimizer_step: int
  22. samples_accumulated: int
  23. target_batch_size: int
  24. num_peers: int
  25. num_clients: int
  26. eta_next_step: float
  27. next_fetch_time: float
  28. @property
  29. def ready_for_step(self):
  30. return self.samples_accumulated >= self.target_batch_size or get_dht_time() >= self.eta_next_step
  31. def register_step(self, local_step: int):
  32. self.optimizer_step = max(local_step, self.optimizer_step)
  33. self.samples_accumulated = 0
  34. self.eta_next_step = float("inf")
class TrainingState(BaseModel):
    """
    Progress record that each peer publishes to the DHT (built in ``CollaborativeOptimizer.report_training_progress``).

    These fields and their constraints are part of the schema that peers validate against, so changing them
    affects compatibility with other peers.
    """

    peer_id: bytes  # byte form of the publishing averager's peer id
    step: conint(ge=0, strict=True)  # publisher's local optimizer step
    samples_accumulated: conint(ge=0, strict=True)  # samples accumulated since the publisher's last optimizer update
    samples_per_second: confloat(ge=0.0, strict=True)  # publisher's throughput estimate (from PerformanceEMA)
    time: StrictFloat  # dht time at which the record was created
    client_mode: StrictBool  # True if the publisher runs without incoming connections
class TrainingProgressSchema(BaseModel):
    """
    DHT schema for training progress: one entry per peer, keyed by that peer's public key.

    A ``None`` value is what a peer stores when it shuts down. The ``progress`` field is presumably matched to the
    ``{prefix}_progress`` DHT key by ``SchemaValidator(prefix=...)`` — NOTE(review): confirm against SchemaValidator.
    """

    progress: Dict[BytesWithPublicKey, Optional[TrainingState]]
  44. class CollaborativeOptimizer(DecentralizedOptimizerBase):
  45. """
  46. An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers.
  47. These optimizers use DHT to track how much progress did the collaboration make towards target batch size.
  48. Once enough samples were accumulated, optimizers will compute a weighted average of their statistics.
  49. :note: **For new projects, please use hivemind.Optimizer**. CollaborativeOptimizer is an older version of that.
  50. Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and many advanced ones.
  51. CollaborativeOptimizer will still be supported for a while, but it will be deprecated in v1.1.0.
  52. :note: This optimizer behaves unlike regular pytorch optimizers in two ways:
  53. * calling .step will periodically zero-out gradients w.r.t. model parameters after each step
  54. * it may take multiple .step calls without updating model parameters, waiting for peers to accumulate enough samples
  55. :param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
  56. :param dht: a running hivemind.DHT daemon connected to other peers
  57. :param prefix: a common prefix for all metadata stored by CollaborativeOptimizer in the DHT
  58. :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
  59. :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
  60. :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
  61. :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
  62. :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
  63. :param expected_drift_peers: assume that this many new peers can join between steps
  64. :param expected_drift_rate: assumes that this fraction of current collaboration can join/leave between steps
  65. :note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
  66. refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
  67. :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
  68. :param step_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
  69. :param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
  70. :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
  71. :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
  72. :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
  73. :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
  74. :param scheduler: if specified, use this scheduler to update optimizer learning rate
  75. :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
  76. This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
  77. :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
  78. device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
  79. the cost of extra time per step. If reuse_gradient_accumulators is True, this parameter has no effect.
  80. :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
  81. :param kwargs: additional parameters forwarded to DecentralizedAverager
  82. :note: If you are using CollaborativeOptimizer with lr_scheduler, it is recommended to pass this scheduler
  83. explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
  84. """
  85. def __init__(
  86. self,
  87. opt: torch.optim.Optimizer,
  88. *,
  89. dht: DHT,
  90. prefix: str,
  91. target_batch_size: int,
  92. batch_size_per_step: Optional[int] = None,
  93. scheduler: Optional[LRSchedulerBase] = None,
  94. min_refresh_period: float = 0.5,
  95. max_refresh_period: float = 30,
  96. default_refresh_period: float = 3,
  97. expected_drift_peers: float = 3,
  98. expected_drift_rate: float = 0.2,
  99. performance_ema_alpha: float = 0.1,
  100. metadata_expiration: float = 60.0,
  101. averaging_timeout: Optional[float] = None,
  102. load_state_timeout: float = 600.0,
  103. step_tolerance: int = 1,
  104. reuse_grad_buffers: bool = False,
  105. accumulate_grads_on: Optional[torch.device] = None,
  106. client_mode: bool = False,
  107. verbose: bool = False,
  108. **kwargs,
  109. ):
  110. super().__init__(opt, dht)
  111. signature_validator = RSASignatureValidator()
  112. self._local_public_key = signature_validator.local_public_key
  113. dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])
  114. if reuse_grad_buffers and accumulate_grads_on is not None:
  115. logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
  116. self.prefix, self.scheduler = prefix, scheduler
  117. self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
  118. self.min_refresh_period, self.max_refresh_period, self.default_refresh_period = (
  119. min_refresh_period,
  120. max_refresh_period,
  121. default_refresh_period,
  122. )
  123. self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
  124. self.averaging_timeout = averaging_timeout
  125. self.load_state_timeout = load_state_timeout
  126. self.metadata_expiration = metadata_expiration
  127. self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
  128. self.client_mode, self.step_tolerance = client_mode, step_tolerance
  129. self.status_loglevel = logging.INFO if verbose else logging.DEBUG
  130. self.averager = self._make_averager(**kwargs)
  131. self._step_supports_amp_scaling = self.reuse_grad_buffers # enable custom execution with torch GradScaler
  132. self.training_progress_key = f"{self.prefix}_progress"
  133. self.local_samples_accumulated = 0 # a number of local samples accumulated since last optimizer update
  134. self.local_updates_accumulated = 0 # a number of calls to step() since last optimizer update
  135. self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
  136. self.last_step_time = None
  137. self.collaboration_state = self._fetch_state()
  138. self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
  139. self.lock_local_progress, self.should_report_progress = Lock(), Event()
  140. self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
  141. self.progress_reporter.start()
  142. self.collaboration_state_updater = Thread(
  143. target=self.check_collaboration_state_periodically, daemon=True, name=f"{self}.collaboration_state_updater"
  144. )
  145. self.collaboration_state_updater.start()
  146. def _make_averager(self, **kwargs):
  147. return TrainingAverager(
  148. self.opt,
  149. dht=self.dht,
  150. average_parameters=True,
  151. average_gradients=True,
  152. prefix=f"{self.prefix}_averaging",
  153. allreduce_timeout=self.averaging_timeout,
  154. client_mode=self.client_mode,
  155. **kwargs,
  156. )
  157. @property
  158. def local_step(self) -> int:
  159. return self.averager.local_step
  160. @property
  161. def is_synchronized(self) -> bool:
  162. return self.local_step >= self.collaboration_state.optimizer_step
  163. @property
  164. def is_within_tolerance(self) -> bool:
  165. return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
  166. def is_alive(self) -> bool:
  167. return self.averager.is_alive()
  168. def load_state_from_peers(self, **kwargs):
  169. """Attempt to fetch the newest collaboration state from other peers"""
  170. with self.lock_collaboration_state:
  171. while True:
  172. try:
  173. self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
  174. break
  175. except KeyboardInterrupt:
  176. raise
  177. except BaseException as e:
  178. logger.exception(f"Failed to load state from peers: {e}, retrying ...")
  179. continue
  180. self.local_samples_accumulated = self.local_updates_accumulated = 0
  181. self.reset_accumulated_grads_()
  182. self.update_scheduler()
  183. def state_dict(self) -> dict:
  184. state_dict = super().state_dict()
  185. state_dict["state"]["collaborative_step"] = self.local_step
  186. return state_dict
  187. def load_state_dict(self, state_dict: dict):
  188. if "collaborative_step" in state_dict["state"]:
  189. self.averager.local_step = state_dict["state"].pop("collaborative_step")
  190. return super().load_state_dict(state_dict)
  191. def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
  192. """
  193. Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
  194. :param batch_size: optional override for batch_size_per_step from init
  195. :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
  196. :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
  197. """
  198. if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
  199. raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler)")
  200. if self.batch_size_per_step is None:
  201. if batch_size is None:
  202. raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
  203. logger.log(self.status_loglevel, f"Setting default batch_size_per_step to {batch_size}")
  204. self.batch_size_per_step = batch_size
  205. batch_size = batch_size if batch_size is not None else self.batch_size_per_step
  206. if not self.is_synchronized and not self.is_within_tolerance:
  207. logger.log(self.status_loglevel, "Peer is out of sync")
  208. self.load_state_from_peers()
  209. return
  210. elif not self.is_synchronized and self.is_within_tolerance:
  211. self.averager.local_step = self.collaboration_state.optimizer_step
  212. logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}")
  213. if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
  214. logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
  215. self.local_samples_accumulated = self.local_steps_accumulated = 0
  216. self.reset_accumulated_grads_()
  217. self.should_report_progress.set()
  218. return
  219. if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
  220. logger.warning(
  221. f"Training step took {get_dht_time() - self.last_step_time}, "
  222. f"but metadata expired in {self.metadata_expiration} s."
  223. )
  224. self.accumulate_grads_(batch_size)
  225. with self.lock_local_progress:
  226. self.local_samples_accumulated += batch_size
  227. self.local_updates_accumulated += 1
  228. self.performance_ema.update(task_size=batch_size)
  229. self.should_report_progress.set()
  230. if not self.collaboration_state.ready_for_step:
  231. return
  232. logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
  233. with self.performance_ema.pause(), self.lock_collaboration_state:
  234. self.collaboration_state = self._fetch_state()
  235. self.collaboration_state_updated.set()
  236. # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
  237. self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
  238. if grad_scaler is not None:
  239. with grad_scaler.running_global_step():
  240. assert grad_scaler.unscale_(self)
  241. current_step, group_info = self.averager.local_step, None
  242. if self.collaboration_state.num_peers > 1:
  243. mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
  244. weight = self.local_samples_accumulated / mean_samples_per_worker
  245. try:
  246. group_info = self.averager.step(
  247. weight=weight, gather=current_step, timeout=self.averaging_timeout, **kwargs
  248. )
  249. if group_info:
  250. logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
  251. # update our current step if we averaged with another peer that was at a more recent step
  252. for peer, peer_step in group_info.items():
  253. if isinstance(peer_step, int):
  254. current_step = max(current_step, peer_step)
  255. else:
  256. logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
  257. except BaseException as e:
  258. logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}")
  259. else:
  260. logger.log(
  261. self.status_loglevel,
  262. f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s)",
  263. )
  264. if grad_scaler is not None:
  265. with grad_scaler.running_global_step():
  266. assert grad_scaler.step(self)
  267. else:
  268. self.opt.step()
  269. self.reset_accumulated_grads_()
  270. self.local_samples_accumulated = self.local_updates_accumulated = 0
  271. self.collaboration_state.register_step(current_step + 1)
  272. self.averager.local_step = current_step + 1
  273. self.collaboration_state_updated.set()
  274. self.update_scheduler()
  275. if grad_scaler is not None:
  276. with grad_scaler.running_global_step():
  277. assert grad_scaler.update()
  278. if not self.averager.client_mode:
  279. self.averager.state_sharing_priority = self.local_step
  280. logger.log(self.status_loglevel, f"Optimizer step: done!")
  281. return group_info
  282. def step_aux(self, **kwargs):
  283. """
  284. Find and assist other peers in averaging without sending local gradients.
  285. :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
  286. """
  287. if not self.collaboration_state.ready_for_step:
  288. return
  289. logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
  290. self.collaboration_state = self._fetch_state()
  291. self.collaboration_state_updated.set()
  292. with self.lock_collaboration_state:
  293. current_step, group_info = self.averager.local_step, None
  294. try:
  295. group_info = self.averager.step(timeout=self.averaging_timeout, gather=current_step, **kwargs)
  296. if group_info:
  297. logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
  298. # update our current step if we averaged with another peer that was at a more recent step
  299. for peer, peer_step in group_info.items():
  300. if isinstance(peer_step, int):
  301. current_step = max(current_step, peer_step)
  302. else:
  303. logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
  304. except BaseException as e:
  305. logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}")
  306. self.collaboration_state.register_step(current_step + 1)
  307. self.averager.local_step = current_step + 1
  308. self.collaboration_state_updated.set()
  309. logger.log(self.status_loglevel, f"Optimizer step: done!")
  310. return group_info
  311. def _grad_buffers(self) -> Iterator[torch.Tensor]:
  312. """pytorch-internal gradient buffers"""
  313. for param_group in self.opt.param_groups:
  314. for param in param_group["params"]:
  315. if param.grad is None:
  316. yield torch.zeros_like(param)
  317. else:
  318. yield param.grad
  319. @torch.no_grad()
  320. def accumulated_grads(self) -> Iterator[torch.Tensor]:
  321. """local gradient accumulators"""
  322. if self.reuse_grad_buffers:
  323. yield from self._grad_buffers()
  324. return
  325. if self._grads is None:
  326. self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
  327. yield from self._grads
  328. @torch.no_grad()
  329. def accumulate_grads_(self, batch_size: int):
  330. """add current gradients to grad accumulators (if any)"""
  331. if self.reuse_grad_buffers:
  332. # user is responsible for accumulating gradients in .grad buffers
  333. assert batch_size == self.batch_size_per_step, "Custom batch size is not supported if reuse_grad_buffers"
  334. else:
  335. alpha = float(batch_size) / self.batch_size_per_step
  336. for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
  337. grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
  338. @torch.no_grad()
  339. def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
  340. if not self.reuse_grad_buffers:
  341. for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
  342. grad_buf.copy_(grad_acc.to(grad_buf.device), non_blocking=True)
  343. if scale_by is not None:
  344. for grad_buf in self._grad_buffers():
  345. grad_buf.mul_(scale_by)
  346. @torch.no_grad()
  347. def reset_accumulated_grads_(self):
  348. for grad_buf in self.accumulated_grads():
  349. grad_buf.zero_()
  350. def report_training_progress(self):
  351. """Periodically publish metadata and the current number of samples accumulated towards the next step"""
  352. while self.is_alive():
  353. self.should_report_progress.wait()
  354. self.should_report_progress.clear()
  355. with self.lock_local_progress:
  356. current_time = get_dht_time()
  357. local_state_info = TrainingState(
  358. peer_id=self.averager.peer_id.to_bytes(),
  359. step=self.local_step,
  360. samples_accumulated=self.local_samples_accumulated,
  361. samples_per_second=self.performance_ema.samples_per_second,
  362. time=current_time,
  363. client_mode=self.averager.client_mode,
  364. )
  365. self.dht.store(
  366. key=self.training_progress_key,
  367. subkey=self._local_public_key,
  368. value=local_state_info.dict(),
  369. expiration_time=current_time + self.metadata_expiration,
  370. return_future=True,
  371. )
  372. def check_collaboration_state_periodically(self):
  373. """
  374. Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
  375. """
  376. while self.is_alive():
  377. time_to_next_update = max(0.0, self.collaboration_state.next_fetch_time - get_dht_time())
  378. if self.collaboration_state_updated.wait(time_to_next_update):
  379. self.collaboration_state_updated.clear()
  380. continue # if state was updated externally, reset timer
  381. with self.lock_collaboration_state:
  382. self.collaboration_state = self._fetch_state()
  383. def _fetch_state(self) -> CollaborationState:
  384. """Read performance statistics reported by peers, estimate progress towards next batch"""
  385. response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
  386. current_time = get_dht_time()
  387. if not isinstance(response, dict) or len(response) == 0:
  388. logger.log(self.status_loglevel, f"Found no active peers: {response}")
  389. samples_left_to_target_batch_size = max(0, self.target_batch_size - self.local_samples_accumulated)
  390. local_eta_next_step = samples_left_to_target_batch_size / self.performance_ema.samples_per_second
  391. return CollaborationState(
  392. self.local_step,
  393. self.local_samples_accumulated,
  394. self.target_batch_size,
  395. num_peers=0,
  396. num_clients=0,
  397. eta_next_step=current_time + local_eta_next_step,
  398. next_fetch_time=current_time + self.default_refresh_period,
  399. )
  400. valid_peer_states = [
  401. TrainingState.parse_obj(peer_state.value)
  402. for peer_state in response.values()
  403. if peer_state.value is not None
  404. ]
  405. num_peers = len(valid_peer_states)
  406. num_clients = sum(state.client_mode for state in valid_peer_states)
  407. global_optimizer_step = self.local_step
  408. for state in valid_peer_states:
  409. if not state.client_mode:
  410. global_optimizer_step = max(global_optimizer_step, state.step)
  411. total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0
  412. for state in valid_peer_states:
  413. total_samples_per_second += state.samples_per_second
  414. if state.step == global_optimizer_step:
  415. total_samples_accumulated += state.samples_accumulated
  416. estimated_current_samples += (
  417. state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
  418. )
  419. # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
  420. # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.
  421. estimated_samples_remaining = self.target_batch_size - estimated_current_samples
  422. estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second
  423. expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
  424. time_to_next_fetch = float(
  425. np.clip(
  426. a=estimated_time_to_next_step * num_peers / expected_max_peers,
  427. a_min=self.min_refresh_period,
  428. a_max=self.max_refresh_period,
  429. )
  430. )
  431. logger.log(
  432. self.status_loglevel,
  433. f"{self.prefix} accumulated {total_samples_accumulated} samples from "
  434. f"{num_peers} peers for step #{global_optimizer_step}. "
  435. f"ETA {estimated_time_to_next_step:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
  436. )
  437. return CollaborationState(
  438. global_optimizer_step,
  439. total_samples_accumulated,
  440. target_batch_size=self.target_batch_size,
  441. num_peers=num_peers,
  442. num_clients=num_clients,
  443. eta_next_step=current_time + estimated_time_to_next_step,
  444. next_fetch_time=current_time + time_to_next_fetch,
  445. )
  446. def zero_grad(self, *args, **kwargs):
  447. if self.reuse_grad_buffers:
  448. raise ValueError(
  449. f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
  450. f"call zero_grad manually. Gradients will be refreshed internally."
  451. )
  452. return self.opt.zero_grad(*args, **kwargs)
  453. def update_scheduler(self):
  454. if self.scheduler:
  455. while self.scheduler._step_count < self.local_step:
  456. self.scheduler.step()
  457. def shutdown(self):
  458. logger.debug("Shutting down averager...")
  459. self.averager.shutdown()
  460. logger.debug("Sending goodbye to peers...")
  461. self.dht.store(
  462. self.training_progress_key,
  463. subkey=self._local_public_key,
  464. value=None,
  465. expiration_time=get_dht_time() + self.metadata_expiration,
  466. )
  467. logger.debug(f"{self.__class__.__name__} is shut down")
  468. def __del__(self):
  469. self.shutdown()