state_averager.py

""" An extension of averager that supports common optimization use cases. """
import logging
from asyncio import Future
from concurrent.futures import ThreadPoolExecutor
from itertools import chain
from threading import Event
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union

import torch

import hivemind
from hivemind.averaging import DecentralizedAverager
from hivemind.compression import CompressionInfo, TensorRole
from hivemind.utils import get_logger, nested_flatten, nested_pack

logger = get_logger(__name__)

Parameters = Iterable[torch.Tensor]
ParamGroups = Iterable[Dict[str, Any]]
TorchOptimizer = torch.optim.Optimizer
LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]


class TrainingStateAverager(DecentralizedAverager):
    """
    An auxiliary class that holds peer's training state, including model parameters, optimizer statistics, scheduler
    and any other variables that define the local training state (e.g. batchnorm moving averages).
    TrainingStateAverager is intended to keep these parameters weakly synchronized across the swarm.

    The intended use is to call .step(optimizer_step=..., averaging_round=...) periodically, e.g. after every batch.
    If a peer gets out of sync with the swarm, one should call state_averager.load_state_from_peers() to re-synchronize.

    Example:

    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, params=model.parameters(), ...)
    >>> # alternative interface: TrainingStateAverager(optimizer=torch.optim.Adam(model.parameters()), ...)
    >>> avgr.load_state_from_peers()
    >>> for i, batch in enumerate(training_dataloader):
    >>>     loss = compute_loss(model, batch)
    >>>     loss.backward()
    >>>     avgr.step(optimizer_step=i % 10 == 0, averaging_round=is_it_time_for_averaging(), delay_averaging=True)

    :note: when using delay_averaging or delay_optimizer_step, calling the optimizer directly is not recommended
      because it may overlap with delayed updates from a background thread with unpredictable results. Instead,
      please call TrainingStateAverager.step(..., optimizer_step=True)
    :param optimizer: PyTorch Optimizer or a callable that creates an optimizer from param groups
    :param params: optional, a list/tuple of parameters or structured param groups for the optimizer
    :param scheduler: optional learning rate scheduler or a callable that creates one from an optimizer instance
    :note: if provided, the scheduler will be updated based on averager.local_epoch, not the number of step cycles
    :param initialize_optimizer: if True, run a speculative optimizer step with zero gradients to initialize all
      state tensors. If False, the user must make sure that all tensors are pre-initialized at init.
      By default, initialize the optimizer unless it already has some state tensors to begin with.
    :param offload_optimizer: if True, create the optimizer on top of averaged parameters, which may save device memory.
    :param custom_gradients: if True, do *not* automatically load local gradients into the offloaded optimizer.
      This assumes that offloaded gradients will be populated externally, e.g. by the user or by hivemind.Optimizer.
    :param reuse_tensors: if True, reuse parameters and optimizer statistics as averaged_tensors for allreduce.
      For this to work, all parameters must be on CPU and have the appropriate dtype for use in DecentralizedAverager.
    :param sync_epoch_when_averaging: if True, update local epoch to the latest epoch among averaging peers
    :param parameter_names: optionally provide parameter names in the same order as in params
    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
    :note: you can use extra_tensors for any tensors not used by the optimizer (e.g. batchnorm statistics)
    :param kwargs: any additional parameters will be forwarded to DecentralizedAverager
    """
    def __init__(
        self,
        *,
        dht: hivemind.DHT,
        optimizer: Union[TorchOptimizer, OptimizerFactory],
        params: Optional[Union[Parameters, ParamGroups]] = None,
        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
        initialize_optimizer: Optional[bool] = None,
        offload_optimizer: bool = False,
        custom_gradients: bool = False,
        reuse_tensors: bool = False,
        sync_epoch_when_averaging: bool = False,
        parameter_names: Optional[Sequence[str]] = None,
        average_opt_statistics: Sequence[str] = (),
        extra_tensors: Sequence[torch.Tensor] = (),
        status_loglevel: int = logging.DEBUG,
        **kwargs,
    ):
        average_opt_statistics = tuple(average_opt_statistics)
        assert all(isinstance(key, str) for key in average_opt_statistics)
        if offload_optimizer and reuse_tensors:
            logger.warning("Setting offload_optimizer=True has no effect because reuse_tensors=True")
        if custom_gradients and not offload_optimizer:
            logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")

        param_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)

        self.status_loglevel = status_loglevel
        self.reuse_tensors = reuse_tensors
        self.offload_optimizer = offload_optimizer
        self.custom_gradients = custom_gradients

        self.main_parameters, self.parameter_names = main_parameters, parameter_names
        self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
        self.optimizer, self.scheduler = self._init_components(
            param_groups, optimizer, scheduler, initialize_optimizer
        )
        self.opt_keys_for_averaging, self.extra_tensors = average_opt_statistics, extra_tensors
        self.sync_epoch_when_averaging = sync_epoch_when_averaging
        self.local_epoch = 0

        self.step_executor = ThreadPoolExecutor(max_workers=1)
        self.finished_optimizer_step = Event()
        self.finished_averaging_round = Event()
        self.pending_update = Future()
        self.pending_update.set_result(None)

        super().__init__(
            dht=dht, averaged_tensors=self._init_averaged_tensors(), tensor_infos=self._init_tensor_infos(), **kwargs
        )
    @staticmethod
    def _check_params(
        optimizer: Union[TorchOptimizer, OptimizerFactory],
        param_groups: Optional[Union[Parameters, ParamGroups]],
        parameter_names: Optional[Sequence[str]],
    ) -> Tuple[ParamGroups, Sequence[torch.Tensor], Sequence[str]]:
        """Get and verify parameters, groups and names"""
        if param_groups is None:
            assert hasattr(optimizer, "param_groups"), "Must provide param_groups or an optimizer with .param_groups"
            param_groups = optimizer.param_groups
        param_groups = tuple(param_groups)
        if all(isinstance(p, torch.Tensor) for p in param_groups):
            param_groups = (dict(params=param_groups),)
        for group in param_groups:
            assert isinstance(group, dict) and group.get("params") is not None
            assert all(isinstance(p, torch.Tensor) for p in group["params"])
        parameters = tuple(chain(*(group["params"] for group in param_groups)))
        if parameter_names is None:
            parameter_names = tuple(i for i in range(len(parameters)))
        parameter_names = tuple(nested_flatten(parameter_names))
        assert len(parameters) == len(parameter_names), f"Expected {len(parameters)} names, got {len(parameter_names)}"
        assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
        return param_groups, parameters, parameter_names
    def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
        """Create a new tensor for averaging or reuse the existing one"""
        if self.reuse_tensors:
            assert source_tensor.device == torch.device("cpu") and source_tensor.dtype == torch.float32
            if not source_tensor.is_shared():
                source_tensor.share_memory_()
            return source_tensor
        else:
            averaged_tensor = source_tensor.detach().to(device="cpu", dtype=torch.float32, copy=True)
            return averaged_tensor.share_memory_().requires_grad_(source_tensor.requires_grad)
    def _init_components(
        self,
        param_groups: ParamGroups,
        optimizer_or_factory: Union[TorchOptimizer, OptimizerFactory],
        scheduler_or_factory: Optional[Union[LRSchedulerBase, SchedulerFactory]],
        initialize_optimizer: Optional[bool],
    ) -> Tuple[TorchOptimizer, Optional[LRSchedulerBase]]:
        """Get optimizer and scheduler by either instantiating user-provided factory or using pre-instantiated ones"""
        assert hasattr(self, "_averaged_parameters"), "Internal error: must initialize averaged parameters first"
        optimizer_is_factory = callable(optimizer_or_factory) and not isinstance(optimizer_or_factory, TorchOptimizer)
        scheduler_is_factory = callable(scheduler_or_factory) and not isinstance(scheduler_or_factory, LRSchedulerBase)
        if optimizer_is_factory and not scheduler_is_factory and scheduler_or_factory is not None:
            raise ValueError("If optimizer is created internally, scheduler must also be initialized internally")
        if self.offload_optimizer and not optimizer_is_factory:
            raise ValueError("Using offload_optimizer requires creating optimizer inside hivemind")

        # create optimizer
        if optimizer_is_factory:
            if self.offload_optimizer:
                for param in self._averaged_parameters:
                    if param.grad is None:
                        param.grad = torch.zeros_like(param)

                next_index = 0
                param_groups_for_optimizer = []
                for param_group in param_groups:
                    num_params = len(param_group["params"])
                    averaged_params_for_group = self._averaged_parameters[next_index : next_index + num_params]
                    param_groups_for_optimizer.append(dict(param_group, params=averaged_params_for_group))
                    next_index += num_params
                assert next_index == len(self._averaged_parameters)
            else:
                param_groups_for_optimizer = param_groups
            optimizer = optimizer_or_factory(param_groups_for_optimizer)
        else:
            optimizer = optimizer_or_factory

        # optionally initialize optimizer state dict
        if initialize_optimizer is None:
            initialize_optimizer = not any(isinstance(x, torch.Tensor) for x in nested_flatten(optimizer.state_dict()))
            logger.log(
                self.status_loglevel,
                "Initializing optimizer manually since it has no tensors in state dict. "
                "To override this, please provide initialize_optimizer=False",
            )
        if initialize_optimizer:
            initialize_optimizer_state_(optimizer)  # note: this will run one optimizer step!

        # create LR scheduler
        if scheduler_is_factory:
            assert callable(scheduler_or_factory)
            scheduler = scheduler_or_factory(optimizer)
        else:
            scheduler = scheduler_or_factory

        # verify optimizer and scheduler
        assert isinstance(optimizer, TorchOptimizer) and len(optimizer.param_groups) == len(list(param_groups))
        if self.offload_optimizer or self.reuse_tensors:
            for param_group in optimizer.param_groups:
                for param in param_group["params"]:
                    assert param.is_shared()
        assert isinstance(scheduler, (LRSchedulerBase, type(None)))
        if scheduler is not None:
            assert scheduler.optimizer == optimizer
        return optimizer, scheduler
    def _local_tensors(self) -> Iterator[torch.Tensor]:
        """Iterate local trainer's tensors that should be averaged with peers"""
        for param_group in self.optimizer.param_groups:
            yield from param_group["params"]
        for stats in self.opt_keys_for_averaging:
            for param_group in self.optimizer.param_groups:
                for param in param_group["params"]:
                    yield self.optimizer.state[param][stats]
        yield from self.extra_tensors
    @torch.no_grad()
    def _init_averaged_tensors(self) -> Sequence[torch.Tensor]:
        """Create or reuse a tuple of all averaged tensors, including parameters, optimizer statistics and extras"""
        assert hasattr(self, "optimizer"), "Optimizer should already be initialized by this point"
        assert hasattr(self, "_averaged_parameters"), "Should initialize _averaged_parameters first"
        assert not hasattr(self, "_averaged_tensors"), "Averager is already initialized"
        assert all(isinstance(key, str) for key in self.opt_keys_for_averaging)

        local_tensors = tuple(self._local_tensors())
        local_non_parameters = local_tensors[len(self._averaged_parameters) :]
        averaged_tensors = tuple(map(torch.Tensor.detach, self._averaged_parameters))
        averaged_non_parameters = tuple(map(self._make_host_tensor, local_non_parameters))
        averaged_tensors = tuple(chain(averaged_tensors, averaged_non_parameters))

        assert len(averaged_tensors) == len(local_tensors)
        for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
            assert local_tensor.shape == averaged_tensor.shape
            if averaged_tensor.grad is not None:
                logger.log(self.status_loglevel, "setting gradients for averaged tensor to None")

        return averaged_tensors
    def _init_tensor_infos(self) -> Sequence[CompressionInfo]:
        """Get CompressionInfo for each state tensor, accounting for its role and specification"""
        tensor_infos = []
        for param, param_name in zip(self.main_parameters, self.parameter_names):
            tensor_infos.append(CompressionInfo.from_tensor(param, key=param_name, role=TensorRole.PARAMETER))
        for stats_name in self.opt_keys_for_averaging:
            opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
            assert len(opt_parameters) == len(self.parameter_names)
            for param, param_name in zip(opt_parameters, self.parameter_names):
                tensor_infos.append(
                    CompressionInfo.from_tensor(
                        self.optimizer.state[param][stats_name],
                        key=(param_name, stats_name),
                        role=TensorRole.OPTIMIZER,
                    )
                )
        for i, extra_tensor in enumerate(self.extra_tensors):
            tensor_infos.append(CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED))
        return tuple(tensor_infos)
    def step(
        self,
        wait_for_delayed_update: Optional[bool] = None,
        apply_delayed_updates: bool = True,
        increment_epoch: bool = False,
        optimizer_step: bool = False,
        zero_grad: bool = False,
        delay_optimizer_step: bool = False,
        averaging_round: bool = False,
        delay_averaging: Optional[bool] = None,
        grad_scaler: Optional[hivemind.GradScaler] = None,
        averaging_opts: Optional[Dict[str, Any]] = None,
    ):
        """
        Perform one or several possible actions, depending on the specified keyword args.
        The actions will be performed in the same order as specified below:

        :param wait_for_delayed_update: if there are background averaging rounds, wait for them to finish.
          By default, wait for delayed updates when scheduling the next optimizer step, otherwise do not wait.
        :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
        :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
        :param optimizer_step: perform a single optimizer step and update local parameters (without changing scheduler)
        :param zero_grad: if True, reset local gradients after performing optimizer step
        :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
        :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
        :param delay_averaging: if True, perform averaging in background and apply results in a future step.
          By default, delay averaging if the optimizer step is also delayed. Set to True to delay only this phase.
        :param averaging_opts: a dict of keyword arguments forwarded into averaging round
        """
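        # Illustrative usage sketch (added comment, not from the original source). Typical flag combinations,
        # assuming the batch-loop example from the class docstring:
        #     state_averager.step(optimizer_step=True, zero_grad=True)        # local update only
        #     state_averager.step(optimizer_step=True, averaging_round=True,
        #                         delay_averaging=True)                       # update now, average in background
        #     state_averager.step(apply_delayed_updates=True)                 # only apply finished background work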
        if delay_averaging is None:
            delay_averaging = delay_optimizer_step
        if wait_for_delayed_update is None:
            wait_for_delayed_update = optimizer_step or zero_grad or averaging_round
        assert not delay_optimizer_step or delay_averaging, "Delayed optimizer step requires delayed averaging"
        if optimizer_step or averaging_round or zero_grad:
            assert wait_for_delayed_update, "Must wait for background updates to finish before scheduling new ones"
        if delay_optimizer_step:
            assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
            assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
        if averaging_opts and not averaging_round:
            logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
        output = None

        if wait_for_delayed_update:
            if not self.pending_update.done():
                logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
                output = self.pending_update.result()

        if self.pending_update.done() and self.pending_update.exception():
            logger.warning(f"Background update failed with {self.pending_update.exception()} and will be ignored")

        if apply_delayed_updates:
            if self.finished_averaging_round.is_set():
                if not self.reuse_tensors:
                    self._apply_averaging_results_()
                logger.log(self.status_loglevel, "Received parameters from background averaging round")
                self.finished_averaging_round.clear()

            if self.finished_optimizer_step.is_set():
                if self.offload_optimizer:
                    self._apply_optimizer_results_()
                logger.log(self.status_loglevel, "Received parameters from background optimizer step")
                self.finished_optimizer_step.clear()

        if increment_epoch:
            self.local_epoch += 1

        if optimizer_step or zero_grad or averaging_round:
            assert self.pending_update.done(), "Tried to perform a new update but previous update is still running"

            if self.offload_optimizer and not self.custom_gradients:
                self._load_local_grads_into_optimizer_()

            self.pending_update = self.step_executor.submit(
                self._do,
                optimizer_step,
                zero_grad,
                averaging_round,
                grad_scaler,
                **averaging_opts or {},
            )

            if (optimizer_step or zero_grad) and not delay_optimizer_step:
                self.finished_optimizer_step.wait()
                self.finished_optimizer_step.clear()
                if self.offload_optimizer:
                    self._apply_optimizer_results_()
                logger.log(self.status_loglevel, "Finished optimizer step")

            if averaging_round and not delay_averaging:
                self.finished_averaging_round.wait()
                self.finished_averaging_round.clear()
                if not self.reuse_tensors:
                    self._apply_averaging_results_()
                logger.log(self.status_loglevel, "Finished averaging round")

            if not delay_averaging:
                try:
                    output = self.pending_update.result()
                finally:
                    self.finished_averaging_round.clear()
                    self.finished_optimizer_step.clear()
        return output
    def _do(
        self,
        optimizer_step: bool,
        zero_grad: bool,
        averaging_round: bool,
        grad_scaler: Optional[hivemind.GradScaler],
        **kwargs,
    ):
        """
        Run the optimizer step, followed by a scheduler step and an averaging round; each stage is optional.
        This method is meant to be called in the background executor.
        """
        try:
            if optimizer_step:
                logger.log(self.status_loglevel, "Running optimizer step")
                if grad_scaler is None:
                    self.optimizer.step()
                else:
                    with grad_scaler.running_global_step():
                        assert grad_scaler.step(self.optimizer)

            if grad_scaler is not None:
                with grad_scaler.running_global_step():
                    assert grad_scaler.update()

            self._update_scheduler()

            if zero_grad:
                logger.log(self.status_loglevel, "Running zero grad")
                self.optimizer.zero_grad()
                if self.offload_optimizer:
                    for parameter in self.main_parameters:
                        if parameter.grad is not None:
                            parameter.grad.zero_()

            self.finished_optimizer_step.set()

            if averaging_round:
                if not self.reuse_tensors:
                    self._load_local_tensors_into_averager_()
                try:
                    gathered = super().step(gather=self.local_epoch, **kwargs)
                    logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
                except BaseException as e:
                    logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
                    self.finished_averaging_round.set()
                    gathered = {}

                self.finished_averaging_round.set()

                if self.sync_epoch_when_averaging:
                    old_epoch = self.local_epoch
                    for peer_epoch in gathered.values():
                        self.local_epoch = max(self.local_epoch, peer_epoch)
                    if self.local_epoch != old_epoch:
                        logger.log(self.status_loglevel, f"Found peer with newer epoch ({self.local_epoch})")
                        self._update_scheduler()

        except Exception as e:
            logger.exception(e)
            self.finished_optimizer_step.set()
            self.finished_averaging_round.set()
    @torch.no_grad()
    def _load_local_grads_into_optimizer_(self):
        """Copy local gradients into the gradient buffers of the offloaded optimizer"""
        assert self.offload_optimizer, "Loading into offloaded optimizer requires using offloaded optimizer"
        opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
        for main_param, opt_param in zip(self.main_parameters, opt_parameters):
            if main_param.grad is not None:
                opt_param.grad.copy_(main_param.grad, non_blocking=True)
    @torch.no_grad()
    def _apply_optimizer_results_(self):
        """Copy parameters from offloaded optimizer to the main model"""
        assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
        with self.lock_averaged_tensors:
            offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
            assert len(offloaded_parameters) == len(
                self.main_parameters
            ), "Optimizer parameters changed during training"
            for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
                main_param.copy_(offloaded_param, non_blocking=True)
    @torch.no_grad()
    def _load_local_tensors_into_averager_(self):
        """Copy local tensors into the averaging buffers"""
        assert not self.reuse_tensors, "No need to load tensors into averager: both tensors share the same memory"
        with self.get_tensors() as averaged_tensors:
            for local_tensor, averaged_tensor in zip(self._local_tensors(), averaged_tensors):
                averaged_tensor.copy_(local_tensor, non_blocking=True)
    @torch.no_grad()
    def _apply_averaging_results_(self):
        """Copy averaged tensors into their respective local tensors"""
        assert not self.reuse_tensors, "No need to update averaged tensors since they reuse the same memory"
        with self.get_tensors() as averaged_tensors:
            local_tensors = list(self._local_tensors())
            assert len(local_tensors) == len(averaged_tensors), "Tensor structure changed during training"
            for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
                local_tensor.copy_(averaged_tensor, non_blocking=True)
    def get_current_state(self):
        """
        Get current model/optimizer state when requested by a newbie peer. Executed in the host process.

        :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
        """
        with torch.no_grad():
            optimized_parameters = tuple(
                param.detach().cpu() for param_group in self.optimizer.param_groups for param in param_group["params"]
            )
            parameter_infos = [
                CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
                for param, key in zip(optimized_parameters, self.parameter_names)
            ]
            extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
            extra_infos = [
                CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED)
                for i, extra_tensor in enumerate(extra_tensors)
            ]
            optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.optimizer)
            optimizer_infos = [
                CompressionInfo.from_tensor(opt_tensor, key=i, role=TensorRole.OPTIMIZER)
                for i, opt_tensor in enumerate(optimizer_tensors)
            ]

        metadata = dict(
            epoch=self.local_epoch, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata
        )
        all_tensors = list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
        all_tensor_infos = list(chain(parameter_infos, extra_infos, optimizer_infos))
        return metadata, all_tensors, all_tensor_infos
    def load_state_from_peers(self, **kwargs):
        """
        Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.

        :returns: whether or not the averager succeeded in loading parameters
        """
        parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
        num_parameters_and_extras = len(parameters_and_extras)

        loaded_state = super().load_state_from_peers(**kwargs)
        if loaded_state is None:
            return

        metadata, flat_tensors = loaded_state
        if (not isinstance(metadata.get("epoch"), int)) or metadata["epoch"] < self.local_epoch:
            logger.warning("Cowardly refusing to load state from peer: peer's epoch is behind our local epoch")
            return

        loaded_parameters_and_extras = flat_tensors[:num_parameters_and_extras]
        loaded_opt_tensors = flat_tensors[num_parameters_and_extras:]
        if num_parameters_and_extras != len(loaded_parameters_and_extras):
            logger.error("Failed to load state from peer: unexpected number of parameters or extra tensors")
            return

        try:
            load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
        except StopIteration:
            logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
            return

        with torch.no_grad():
            for local_param, loaded_param in zip(parameters_and_extras, loaded_parameters_and_extras):
                local_param.copy_(loaded_param, non_blocking=True)
        self.local_epoch = metadata["epoch"]
        self._update_scheduler()
    def _update_scheduler(self):
        """Increase the scheduler state until it becomes synchronized with local epoch"""
        if self.scheduler:
            while self.scheduler._step_count <= self.local_epoch:
                self.scheduler.step()


def initialize_optimizer_state_(opt: torch.optim.Optimizer):
    """Initialize optimizer statistics by running a virtual optimizer step with zero gradients"""
    flat_params = tuple(param for group in opt.param_groups for param in group["params"])
    old_grads = []
    for param in flat_params:
        old_grads.append(param.grad)
        param.grad = torch.zeros_like(param)
    opt.step()
    for param, old_grad in zip(flat_params, old_grads):
        param.grad = old_grad
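

# Illustrative note (added comment, not from the original source): for an Adam-style optimizer, the virtual step
# above creates state tensors such as "exp_avg" and "exp_avg_sq" for every parameter, then restores each .grad
# (here, model stands for any torch.nn.Module):
#
#     opt = torch.optim.Adam(model.parameters())
#     initialize_optimizer_state_(opt)
#     assert all("exp_avg" in opt.state[p] for g in opt.param_groups for p in g["params"])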


def dump_optimizer_state(opt: torch.optim.Optimizer):
    """Convert optimizer state into a format of DecentralizedAverager's get_current_state/load_state_from_peers"""
    with torch.no_grad():
        flat_metadata, flat_tensors = [], []
        for elem in nested_flatten(opt.state_dict()):
            if isinstance(elem, torch.Tensor):
                flat_metadata.append(dict(type="tensor", index=len(flat_tensors)))
                flat_tensors.append(elem.cpu())
            else:
                flat_metadata.append(dict(type="value", value=elem))
        return flat_metadata, flat_tensors


def load_optimizer_state(optimizer: torch.optim.Optimizer, flat_metadata: Dict, flat_tensors: Sequence[torch.Tensor]):
    """Load a state obtained by dump_optimizer_state back into the optimizer"""
    flat_optimizer_state = []
    for elem in flat_metadata:
        if elem.get("type") == "tensor" and isinstance(elem.get("index"), int):
            flat_optimizer_state.append(flat_tensors[elem["index"]])
        elif elem.get("type") == "value" and "value" in elem:
            flat_optimizer_state.append(elem["value"])
    return optimizer.load_state_dict(nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
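

# Illustrative round-trip sketch (added comment, not from the original source), assuming a single-parameter Adam
# optimizer whose state_dict structure is unchanged between dump and load:
#
#     param = torch.nn.Parameter(torch.zeros(3))
#     opt = torch.optim.Adam([param], lr=1e-3)
#     initialize_optimizer_state_(opt)               # populate optimizer statistics
#     metadata, tensors = dump_optimizer_state(opt)  # serializable metadata plus a flat list of CPU tensors
#     load_optimizer_state(opt, metadata, tensors)   # rebuilds the state_dict with the same nested structure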