from __future__ import annotations

import logging
from dataclasses import dataclass
from threading import Event, Lock, Thread
from typing import Dict, Iterator, Optional

import numpy as np
import torch
from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint

from hivemind.dht import DHT
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.optim.base import DecentralizedOptimizerBase
from hivemind.optim.training import TrainingAverager
from hivemind.optim.grad_scaler import HivemindGradScaler
from hivemind.optim.performance_ema import PerformanceEMA
from hivemind.utils import get_dht_time, get_logger

logger = get_logger(__name__)
LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)


@dataclass(frozen=False)
class CollaborationState:
    optimizer_step: int
    samples_accumulated: int
    target_batch_size: int
    num_peers: int
    num_clients: int
    eta_next_step: float
    next_fetch_time: float

    @property
    def ready_for_step(self):
        return self.samples_accumulated >= self.target_batch_size or get_dht_time() >= self.eta_next_step

    def register_step(self, local_step: int):
        self.optimizer_step = max(local_step, self.optimizer_step)
        self.samples_accumulated = 0
        self.eta_next_step = float("inf")


class TrainingState(BaseModel):
    peer_id: bytes
    step: conint(ge=0, strict=True)
    samples_accumulated: conint(ge=0, strict=True)
    samples_per_second: confloat(ge=0.0, strict=True)
    time: StrictFloat
    client_mode: StrictBool


class TrainingProgressSchema(BaseModel):
    progress: Dict[BytesWithPublicKey, Optional[TrainingState]]


class CollaborativeOptimizer(DecentralizedOptimizerBase):
  45. """
  46. An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers
  47. These optimizers use DHT to track how much progress did the collaboration make towards target batch size.
  48. Once enough samples were accumulated, optimizers will compute a weighted average of their statistics.
  49. :note: This optimizer behaves unlike regular pytorch optimizers in two ways:
  50. * calling .step will periodically zero-out gradients w.r.t. model parameters after each step
  51. * it may take multiple .step calls without updating model parameters, waiting for peers to accumulate enough samples
  52. :param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
  53. :param dht: a running hivemind.DHT daemon connected to other peers
  54. :param prefix: a common prefix for all metadata stored by CollaborativeOptimizer in the DHT
  55. :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
  56. :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
  57. :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
  58. :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
  59. :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
  60. :param expected_drift_peers: assume that this many new peers can join between steps
  61. :param expected_drift_rate: assumes that this fraction of current collaboration can join/leave between steps
  62. :note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
  63. refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
  64. :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
  65. :param step_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
  66. :param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
  67. :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
  68. :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
  69. :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
  70. :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
  71. :param scheduler: if specified, use this scheduler to update optimizer learning rate
  72. :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
  73. This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
  74. :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
  75. device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
  76. the cost of extra time per step. If reuse_gradient_accumulators is True, this parameter has no effect.
  77. :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
  78. :param kwargs: additional parameters forwarded to DecentralizedAverager
  79. :note: If you are using CollaborativeOptimizer with lr_scheduler, it is recommended to pass this scheduler
  80. explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
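
    :example: a minimal usage sketch; ``model``, ``train_loader``, and the single-node DHT bootstrap below are
      illustrative placeholders rather than part of this class (add initial_peers to join an existing run):

    >>> dht = DHT(start=True)
    >>> opt = CollaborativeOptimizer(
    ...     torch.optim.Adam(model.parameters()), dht=dht, prefix="my_experiment",
    ...     target_batch_size=4096, batch_size_per_step=32,
    ... )
    >>> for batch in train_loader:
    ...     model(batch).mean().backward()
    ...     opt.step()       # only updates parameters once the collaboration accumulates target_batch_size samples
    ...     opt.zero_grad()  # allowed here because reuse_grad_buffers is False (the default)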
  81. """
  82. def __init__(
  83. self,
  84. opt: torch.optim.Optimizer,
  85. *,
  86. dht: DHT,
  87. prefix: str,
  88. target_batch_size: int,
  89. batch_size_per_step: Optional[int] = None,
  90. scheduler: Optional[LRSchedulerBase] = None,
  91. min_refresh_period: float = 0.5,
  92. max_refresh_period: float = 30,
  93. default_refresh_period: float = 3,
  94. expected_drift_peers: float = 3,
  95. expected_drift_rate: float = 0.2,
  96. performance_ema_alpha: float = 0.1,
  97. metadata_expiration: float = 60.0,
  98. averaging_timeout: Optional[float] = None,
  99. load_state_timeout: float = 600.0,
  100. step_tolerance: int = 1,
  101. reuse_grad_buffers: bool = False,
  102. accumulate_grads_on: Optional[torch.device] = None,
  103. client_mode: bool = False,
  104. verbose: bool = False,
  105. **kwargs,
  106. ):
  107. super().__init__(opt, dht)
  108. signature_validator = RSASignatureValidator()
  109. self._local_public_key = signature_validator.local_public_key
  110. dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])
  111. if reuse_grad_buffers and accumulate_grads_on is not None:
  112. logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
  113. self.prefix, self.scheduler = prefix, scheduler
  114. self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
  115. self.min_refresh_period, self.max_refresh_period, self.default_refresh_period = (
  116. min_refresh_period,
  117. max_refresh_period,
  118. default_refresh_period,
  119. )
  120. self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
  121. self.averaging_timeout = averaging_timeout
  122. self.load_state_timeout = load_state_timeout
  123. self.metadata_expiration = metadata_expiration
  124. self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
  125. self.client_mode, self.step_tolerance = client_mode, step_tolerance
  126. self.status_loglevel = logging.INFO if verbose else logging.DEBUG
  127. self.averager = self._make_averager(**kwargs)
  128. self._step_supports_amp_scaling = self.reuse_grad_buffers # enable custom execution with torch GradScaler
  129. self.training_progress_key = f"{self.prefix}_progress"
  130. self.local_samples_accumulated = 0 # a number of local samples accumulated since last optimizer update
  131. self.local_updates_accumulated = 0 # a number of calls to step() since last optimizer update
  132. self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
  133. self.last_step_time = None
  134. self.collaboration_state = self._fetch_state()
  135. self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
  136. self.lock_local_progress, self.should_report_progress = Lock(), Event()
  137. self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
  138. self.progress_reporter.start()
  139. self.collaboration_state_updater = Thread(
  140. target=self.check_collaboration_state_periodically, daemon=True, name=f"{self}.collaboration_state_updater"
  141. )
  142. self.collaboration_state_updater.start()
  143. def _make_averager(self, **kwargs):
  144. return TrainingAverager(
  145. self.opt,
  146. dht=self.dht,
  147. average_parameters=True,
  148. average_gradients=True,
  149. prefix=f"{self.prefix}_averaging",
  150. allreduce_timeout=self.averaging_timeout,
  151. client_mode=self.client_mode,
  152. **kwargs,
  153. )
  154. @property
  155. def local_step(self) -> int:
  156. return self.averager.local_step
  157. @property
  158. def is_synchronized(self) -> bool:
  159. return self.local_step >= self.collaboration_state.optimizer_step
  160. @property
  161. def is_within_tolerance(self) -> bool:
  162. return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
  163. def is_alive(self) -> bool:
  164. return self.averager.is_alive()
  165. def load_state_from_peers(self, **kwargs):
  166. """Attempt to fetch the newest collaboration state from other peers"""
  167. with self.lock_collaboration_state:
  168. while True:
  169. try:
  170. self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
  171. break
  172. except KeyboardInterrupt:
  173. raise
  174. except BaseException as e:
  175. logger.exception(f"Failed to load state from peers: {e}, retrying ...")
  176. continue
  177. self.local_samples_accumulated = self.local_updates_accumulated = 0
  178. self.reset_accumulated_grads_()
  179. self.update_scheduler()
  180. def state_dict(self) -> dict:
  181. state_dict = super().state_dict()
  182. state_dict["state"]["collaborative_step"] = self.local_step
  183. return state_dict
  184. def load_state_dict(self, state_dict: dict):
  185. if "collaborative_step" in state_dict["state"]:
  186. self.averager.local_step = state_dict["state"].pop("collaborative_step")
  187. return super().load_state_dict(state_dict)
  188. def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
  189. """
  190. Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
  191. :param batch_size: optional override for batch_size_per_step from init
  192. :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
  193. :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
  194. """
        if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler).")
        if self.batch_size_per_step is None:
            if batch_size is None:
                raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
            logger.log(self.status_loglevel, f"Setting default batch_size_per_step to {batch_size}")
            self.batch_size_per_step = batch_size
        batch_size = batch_size if batch_size is not None else self.batch_size_per_step

        if not self.is_synchronized and not self.is_within_tolerance:
            logger.log(self.status_loglevel, "Peer is out of sync.")
            self.load_state_from_peers()
            return
        elif not self.is_synchronized and self.is_within_tolerance:
            self.averager.local_step = self.collaboration_state.optimizer_step
            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")

        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
            self.local_samples_accumulated = self.local_updates_accumulated = 0
            self.reset_accumulated_grads_()
            self.should_report_progress.set()
            return

        if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
            logger.warning(
                f"Training step took {get_dht_time() - self.last_step_time:.2f} s, "
                f"but peer metadata expires after {self.metadata_expiration} s."
            )
        self.accumulate_grads_(batch_size)

        with self.lock_local_progress:
            self.local_samples_accumulated += batch_size
            self.local_updates_accumulated += 1
            self.performance_ema.update(num_processed=batch_size)
            self.should_report_progress.set()

        if not self.collaboration_state.ready_for_step:
            return

        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
        with self.performance_ema.pause(), self.lock_collaboration_state:
            self.collaboration_state = self._fetch_state()
            self.collaboration_state_updated.set()

            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
            self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
            if grad_scaler is not None:
                with grad_scaler.running_global_step():
                    assert grad_scaler.unscale_(self)

            current_step, group_info = self.averager.local_step, None
            if self.collaboration_state.num_peers > 1:
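                # weigh this peer's contribution in the all-reduce by how many samples it accumulated locally,
                # relative to an equal per-peer share of the target batch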
                mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                weight = self.local_samples_accumulated / mean_samples_per_worker
                try:
                    group_info = self.averager.step(
                        weight=weight, gather=current_step, timeout=self.averaging_timeout, **kwargs
                    )
                    if group_info:
                        logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
                        # update our current step if we averaged with another peer that was at a more recent step
                        for peer, peer_step in group_info.items():
                            if isinstance(peer_step, int):
                                current_step = max(current_step, peer_step)
                            else:
                                logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
                except BaseException as e:
                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
            else:
                logger.log(
                    self.status_loglevel,
                    f"Skipped averaging: collaboration consists of {self.collaboration_state.num_peers} peer(s).",
                )
            if grad_scaler is not None:
                with grad_scaler.running_global_step():
                    assert grad_scaler.step(self)
            else:
                self.opt.step()

            self.reset_accumulated_grads_()
            self.local_samples_accumulated = self.local_updates_accumulated = 0
            self.collaboration_state.register_step(current_step + 1)
            self.averager.local_step = current_step + 1
            self.collaboration_state_updated.set()
            self.update_scheduler()

            if grad_scaler is not None:
                with grad_scaler.running_global_step():
                    assert grad_scaler.update()

        logger.log(self.status_loglevel, "Optimizer step: done!")
        return group_info
    def step_aux(self, **kwargs):
        """
        Find and assist other peers in averaging without sending local gradients.

        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
        """
        if not self.collaboration_state.ready_for_step:
            return

        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
        self.collaboration_state = self._fetch_state()
        self.collaboration_state_updated.set()

        with self.lock_collaboration_state:
            current_step, group_info = self.averager.local_step, None
            try:
                group_info = self.averager.step(timeout=self.averaging_timeout, gather=current_step, **kwargs)
                if group_info:
                    logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
                    # update our current step if we averaged with another peer that was at a more recent step
                    for peer, peer_step in group_info.items():
                        if isinstance(peer_step, int):
                            current_step = max(current_step, peer_step)
                        else:
                            logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
            except BaseException as e:
                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")

            self.collaboration_state.register_step(current_step + 1)
            self.averager.local_step = current_step + 1
            self.collaboration_state_updated.set()

        logger.log(self.status_loglevel, "Optimizer step: done!")
        return group_info
    def _grad_buffers(self) -> Iterator[torch.Tensor]:
        """pytorch-internal gradient buffers"""
        for param_group in self.opt.param_groups:
            for param in param_group["params"]:
                if param.grad is None:
                    yield torch.zeros_like(param)
                else:
                    yield param.grad

    @torch.no_grad()
    def accumulated_grads(self) -> Iterator[torch.Tensor]:
        """local gradient accumulators"""
        if self.reuse_grad_buffers:
            yield from self._grad_buffers()
            return
        if self._grads is None:
            self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
        yield from self._grads

    @torch.no_grad()
    def accumulate_grads_(self, batch_size: int):
        """add current gradients to grad accumulators (if any)"""
        if self.reuse_grad_buffers:
            # user is responsible for accumulating gradients in .grad buffers
            assert batch_size == self.batch_size_per_step, "Custom batch size is not supported if reuse_grad_buffers"
        else:
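            # scale the current gradients by batch_size relative to batch_size_per_step so that larger local
            # batches contribute proportionally more to the accumulators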
            alpha = float(batch_size) / self.batch_size_per_step
            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)

    @torch.no_grad()
    def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
        if not self.reuse_grad_buffers:
            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
                grad_buf.copy_(grad_acc.to(grad_buf.device), non_blocking=True)
        if scale_by is not None:
            for grad_buf in self._grad_buffers():
                grad_buf.mul_(scale_by)

    @torch.no_grad()
    def reset_accumulated_grads_(self):
        for grad_buf in self.accumulated_grads():
            grad_buf.zero_()

    def report_training_progress(self):
        """Periodically publish metadata and the current number of samples accumulated towards the next step"""
        while self.is_alive():
            self.should_report_progress.wait()
            self.should_report_progress.clear()
            with self.lock_local_progress:
                current_time = get_dht_time()
                local_state_info = TrainingState(
                    peer_id=self.averager.peer_id.to_bytes(),
                    step=self.local_step,
                    samples_accumulated=self.local_samples_accumulated,
                    samples_per_second=self.performance_ema.samples_per_second,
                    time=current_time,
                    client_mode=self.averager.client_mode,
                )

            self.dht.store(
                key=self.training_progress_key,
                subkey=self._local_public_key,
                value=local_state_info.dict(),
                expiration_time=current_time + self.metadata_expiration,
                return_future=True,
            )

    def check_collaboration_state_periodically(self):
        """
        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
        """
        while self.is_alive():
            time_to_next_update = max(0.0, self.collaboration_state.next_fetch_time - get_dht_time())
            if self.collaboration_state_updated.wait(time_to_next_update):
                self.collaboration_state_updated.clear()
                continue  # if state was updated externally, reset timer

            with self.lock_collaboration_state:
                self.collaboration_state = self._fetch_state()

    def _fetch_state(self) -> CollaborationState:
        """Read performance statistics reported by peers, estimate progress towards next batch"""
        response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
        current_time = get_dht_time()

        if not isinstance(response, dict) or len(response) == 0:
            logger.log(self.status_loglevel, f"Found no active peers: {response}")
            samples_left_to_target_batch_size = max(0, self.target_batch_size - self.local_samples_accumulated)
            local_eta_next_step = samples_left_to_target_batch_size / self.performance_ema.samples_per_second
            return CollaborationState(
                self.local_step,
                self.local_samples_accumulated,
                self.target_batch_size,
                num_peers=0,
                num_clients=0,
                eta_next_step=current_time + local_eta_next_step,
                next_fetch_time=current_time + self.default_refresh_period,
            )

        valid_peer_states = [
            TrainingState.parse_obj(peer_state.value)
            for peer_state in response.values()
            if peer_state.value is not None
        ]

        num_peers = len(valid_peer_states)
        num_clients = sum(state.client_mode for state in valid_peer_states)
        global_optimizer_step = self.local_step
        for state in valid_peer_states:
            if not state.client_mode:
                global_optimizer_step = max(global_optimizer_step, state.step)

        total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0
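        # extrapolate each peer's progress to the present using its reported throughput, so that the ETA also
        # accounts for samples processed since that peer last published its state to the DHT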
        for state in valid_peer_states:
            total_samples_per_second += state.samples_per_second
            if state.step == global_optimizer_step:
                total_samples_accumulated += state.samples_accumulated
                estimated_current_samples += (
                    state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
                )
            # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
            # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.

        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
        estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second

        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
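        # refresh sooner when the next step is close and when many new peers may join (expected drift), but never
        # more often than min_refresh_period or less often than max_refresh_period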
        time_to_next_fetch = float(
            np.clip(
                a=estimated_time_to_next_step * num_peers / expected_max_peers,
                a_min=self.min_refresh_period,
                a_max=self.max_refresh_period,
            )
        )
        logger.log(
            self.status_loglevel,
            f"{self.prefix} accumulated {total_samples_accumulated} samples from "
            f"{num_peers} peers for step #{global_optimizer_step}. "
            f"ETA {estimated_time_to_next_step:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
        )
        return CollaborationState(
            global_optimizer_step,
            total_samples_accumulated,
            target_batch_size=self.target_batch_size,
            num_peers=num_peers,
            num_clients=num_clients,
            eta_next_step=current_time + estimated_time_to_next_step,
            next_fetch_time=current_time + time_to_next_fetch,
        )

    def zero_grad(self, *args, **kwargs):
        if self.reuse_grad_buffers:
            raise ValueError(
                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
                f"call zero_grad manually. Gradients will be refreshed internally."
            )
        return self.opt.zero_grad(*args, **kwargs)

    def update_scheduler(self):
        if self.scheduler:
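            # fast-forward the scheduler until its internal step count matches the collaboration-wide optimizer step
            # (e.g. after loading state from peers or catching up with a newer global step)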
            while self.scheduler._step_count < self.local_step:
                self.scheduler.step()

    def shutdown(self):
        logger.debug("Shutting down averager...")
        self.averager.shutdown()
        logger.debug("Sending goodbye to peers...")
        self.dht.store(
            self.training_progress_key,
            subkey=self._local_public_key,
            value=None,
            expiration_time=get_dht_time() + self.metadata_expiration,
        )
        logger.debug(f"{self.__class__.__name__} is shut down.")

    def __del__(self):
        self.shutdown()