# training.py

  1. """ An extension of averager that supports common optimization use cases. """
  2. from itertools import chain
  3. from threading import Lock, Event
  4. from typing import Sequence, Dict, Iterator, Optional
  5. from contextlib import nullcontext
  6. import torch
  7. from hivemind.client.averaging import DecentralizedAverager
  8. from hivemind.utils import nested_flatten, nested_pack, get_logger, run_in_background
  9. logger = get_logger(__name__)


class TrainingAverager(DecentralizedAverager):
    """
    A high-level interface to DecentralizedAverager that averages trainable params or gradients for an optimizer.

    This averager implements a number of typical use cases that arise in collaborative optimization:
    - averaging parameters or gradients or both (in the future, this will support averaging learning rates as well)
    - this peer's weight (e.g. based on its batch size) can be specified via averager.step(weight=...)
    - when out of sync, the averager will load the entire optimizer state from an up-to-date peer

    :param opt: a pytorch optimizer to be averaged between peers (complete with model parameters)
    :param average_parameters: whether or not to average model parameters in self.step(...)
    :param average_gradients: whether or not to average model gradients in self.step(...)
    :param average_opt_statistics: if specified, average optimizer statistics with corresponding names in state_dict
    :param scheduler: if specified, the averager stores the scheduler state and shares it in load_state_from_peers
    :param initialize_optimizer: if True, this will run a speculative optimizer step with
      zero gradients to initialize all tensors. If False, please initialize the optimizer state manually.
    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
    :note: you can use extra_tensors for averaging tensors that are updated outside of opt.step (e.g. batchnorm stats)
    :param kwargs: any additional parameters will be forwarded to DecentralizedAverager
    """

    def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
                 average_opt_statistics: Sequence[str] = (),
                 scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
                 extra_tensors: Sequence[torch.Tensor] = (), initialize_optimizer: bool = True, **kwargs):
        self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
        self.opt_statistics = tuple(average_opt_statistics)
        self.average_parameters, self.average_gradients = average_parameters, average_gradients
        self.lock_averager_step = Lock()
        self.averaging_ready_event = Event()
        self.update = None
        self.scheduler = scheduler

        if initialize_optimizer:
            initialize_optimizer_state(opt)  # note: this will run one optimizer step!

        with torch.no_grad():
            averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()]
        super().__init__(averaged_tensors=averaged_tensors, **kwargs)

    @torch.no_grad()
    def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs):
        """ Average optimizer weights and gradients with peers.

        :param data_lock: averager locks it when model parameters are modified. Otherwise it is assumed that no model
          modifications occur during the averaging step
        :param wait: if True, waits for averaging to finish and returns the result; otherwise returns a Future
        """
        if not wait:
            return run_in_background(self.step, data_lock, wait=True, **kwargs)

        # if data_lock is supplied, tensors might change during averaging, so we need to copy them
        use_old_local_tensors = data_lock is not None
        if data_lock is None:
            data_lock = nullcontext()

        local_tensors = list(self.local_tensors())
        with self.lock_averager_step:
            # fill averager's tensors with current local tensors
            with data_lock, self.get_tensors() as averaged_tensors:
                if use_old_local_tensors:
                    old_local_tensors = tuple(x.cpu().float().clone() for x in local_tensors)
                assert len(local_tensors) == len(averaged_tensors), \
                    "The number of optimized parameters should not change."
                for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
                    averaged_tensor[...] = local_tensor.cpu().float()

            # find a group and hopefully average tensors with peers, scaled by peer's weight
            gathered = super().step(**kwargs)

            if gathered is not None:
                # collect averaged results (or deltas) into self.update so they can be applied to the model
                with self.get_tensors() as averaged_tensors:
                    if len(averaged_tensors) != len(local_tensors):
                        raise RuntimeError("The number of optimized parameters should not change.")

                    self.update = []
                    if use_old_local_tensors:
                        # since tensors might have changed, we subtract old_local_tensor and add averaged.
                        # This prevents losing local updates that might have occurred during averaging
                        for averaged_tensor, local_tensor, old_local_tensor in zip(
                                averaged_tensors, local_tensors, old_local_tensors):
                            self.update.append(
                                averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
                                - old_local_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device))
                    else:
                        for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
                            self.update.append(
                                averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device))

            self.local_step += 1
            self.averaging_ready_event.set()
            return gathered

    def local_tensors(self, replace_none: bool = True) -> Iterator[torch.Tensor]:
        """
        Iterate over the local trainer's tensors that should be averaged with peers.

        :param replace_none: if True and average_gradients is True, None gradients will be replaced with zero tensors.
          Otherwise, such gradients will be skipped (this may cause inconsistencies with averaged_tensors).
        """
        if self.average_parameters:
            for param_group in self.opt.param_groups:
                yield from param_group['params']
        if self.average_gradients:
            for param_group in self.opt.param_groups:
                for param in param_group['params']:
                    if param.grad is not None:
                        yield param.grad
                    elif replace_none:
                        yield torch.zeros_like(param)
        for stats in self.opt_statistics:
            for param_group in self.opt.param_groups:
                for param in param_group['params']:
                    yield self.opt.state[param][stats]
        yield from iter(self.extra_tensors)

    def get_current_state(self):
        """
        Get the current model/optimizer state to be sent when requested by a newbie peer. Executed in the host process.

        :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
        """
        with torch.no_grad():
            optimized_parameters = tuple(param.detach().cpu() for param_group in self.opt.param_groups
                                         for param in param_group['params'])
            extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
            optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.opt)
            scheduler_state = self.scheduler.state_dict() if self.scheduler is not None else None

        metadata = dict(step=self.local_step, group_bits=self.get_group_bits(),
                        optimizer_metadata=optimizer_metadata, scheduler_state=scheduler_state)
        return metadata, list(chain(optimized_parameters, extra_tensors, optimizer_tensors))

    def load_state_from_peers(self, **kwargs):
        """
        Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.

        :returns: whether or not the averager succeeded in loading parameters
        """
        parameters_and_extras = [param for param_group in self.opt.param_groups for param in param_group['params']]
        parameters_and_extras.extend(self.extra_tensors)
        num_local_tensors = len(parameters_and_extras)

        loaded_state = super().load_state_from_peers(**kwargs)
        if loaded_state is None:
            return

        metadata, flat_tensors = loaded_state
        loaded_parameters_and_extras = flat_tensors[:num_local_tensors]
        loaded_opt_tensors = flat_tensors[num_local_tensors:]

        with torch.no_grad():
            for local_param, loaded_param in zip(parameters_and_extras, loaded_parameters_and_extras):
                local_param[...] = loaded_param
            load_optimizer_state(self.opt, metadata['optimizer_metadata'], loaded_opt_tensors)

        self.local_step = max(self.local_step, metadata['step'])

        if self.scheduler is not None:
            if 'scheduler_state' not in metadata:
                logger.warning("Scheduler was passed, but there is no key 'scheduler_state' found in state")
            else:
                self.scheduler.load_state_dict(metadata['scheduler_state'])
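

# --- Illustrative usage (not part of the original module) -----------------------------------------------------------
# A minimal sketch of how TrainingAverager is typically wired into a training loop. The function name, its parameters,
# and the hivemind.DHT / DecentralizedAverager keyword arguments used below (initial_peers, prefix, target_group_size,
# start) are illustrative assumptions about the surrounding hivemind API and should be adapted to your deployment.
# This function is defined for illustration only and is never called by the module.
def _example_training_loop(initial_peers: Sequence[str], num_batches: int = 100):
    import hivemind  # assumed importable, since this module ships as part of hivemind

    model = torch.nn.Linear(16, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    dht = hivemind.DHT(initial_peers=initial_peers, start=True)  # assumption: a reachable DHT with these peers
    averager = TrainingAverager(opt, average_parameters=True, average_gradients=False,
                                dht=dht, start=True, prefix='demo_run', target_group_size=2)

    for batch_index in range(num_batches):
        inputs, targets = torch.randn(32, 16), torch.randint(0, 2, (32,))
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # every few local steps, try to average parameters with whichever peers join the same group
        if batch_index % 10 == 9:
            averager.step(weight=1.0)  # weight may reflect this peer's batch size, as noted in the class docstring

    averager.shutdown()
    dht.shutdown()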


def initialize_optimizer_state(opt: torch.optim.Optimizer):
    """ Initialize optimizer statistics by running a speculative optimizer step (parameters without grads get zeros) """
    for param_group in opt.param_groups:
        for param in param_group['params']:
            if param.grad is None:
                (0 * param.sum()).backward()
    opt.step()


def dump_optimizer_state(opt: torch.optim.Optimizer):
    """ Convert optimizer state into a format of DecentralizedAverager's get_current_state/load_state_from_peers """
    with torch.no_grad():
        flat_metadata, flat_tensors = [], []
        for elem in nested_flatten(opt.state_dict()):
            if isinstance(elem, torch.Tensor):
                flat_metadata.append(dict(type='tensor', index=len(flat_tensors)))
                flat_tensors.append(elem.cpu())
            else:
                flat_metadata.append(dict(type='value', value=elem))
        return flat_metadata, flat_tensors


def load_optimizer_state(optimizer: torch.optim.Optimizer, flat_metadata: Dict, flat_tensors: Sequence[torch.Tensor]):
    """ Restore an optimizer state produced by dump_optimizer_state into the given optimizer """
    flat_optimizer_state = []
    for elem in flat_metadata:
        if elem.get('type') == 'tensor' and isinstance(elem.get('index'), int):
            flat_optimizer_state.append(flat_tensors[elem['index']])
        elif elem.get('type') == 'value' and 'value' in elem:
            flat_optimizer_state.append(elem['value'])
    with torch.no_grad():
        try:
            return optimizer.load_state_dict(nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
        except StopIteration:
            return optimizer
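

# --- Illustrative round trip (not part of the original module) ------------------------------------------------------
# A small, self-contained sketch of how dump_optimizer_state pairs with load_optimizer_state: dump one optimizer's
# state and load it into a second optimizer over the same parameters. Only torch is required; the helper name and the
# SGD/momentum setup are chosen purely for illustration.
def _example_optimizer_state_roundtrip():
    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # take a few real steps so the optimizer accumulates non-trivial statistics (momentum buffers)
    for _ in range(3):
        opt.zero_grad()
        model(torch.randn(8, 4)).pow(2).sum().backward()
        opt.step()

    flat_metadata, flat_tensors = dump_optimizer_state(opt)
    # flat_metadata is small and serializable; flat_tensors is what get_current_state actually ships to peers

    # the receiving optimizer must have a state_dict of the same structure before loading;
    # initialize_optimizer_state (defined above) is one way to ensure that
    fresh_opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    initialize_optimizer_state(fresh_opt)
    load_optimizer_state(fresh_opt, flat_metadata, flat_tensors)

    original_state, restored_state = opt.state_dict()['state'], fresh_opt.state_dict()['state']
    assert all(torch.allclose(original_state[key]['momentum_buffer'], restored_state[key]['momentum_buffer'])
               for key in original_state)


if __name__ == '__main__':
    _example_optimizer_state_roundtrip()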