training.py

  1. """ An extension of averager that supports common optimization use cases. """
  2. from concurrent.futures import ThreadPoolExecutor
  3. from contextlib import nullcontext
  4. from itertools import chain
  5. from threading import Lock
  6. from typing import Sequence, Dict, Iterator, Optional
  7. import torch
  8. from hivemind.averaging import DecentralizedAverager
  9. from hivemind.utils import nested_flatten, nested_pack, get_logger
  10. from hivemind.proto.runtime_pb2 import CompressionType
  11. logger = get_logger(__name__)


class TrainingAverager(DecentralizedAverager):
    """
    A high-level interface to DecentralizedAverager that averages trainable params or gradients for an optimizer.

    This averager implements a number of typical use cases that arise in collaborative optimization:
    - averaging parameters, gradients, or both (in the future, this will also support averaging learning rates)
    - this peer's weight (e.g. based on its batch size) can be specified via averager.step(weight=...)
    - when out of sync, the averager will load the entire optimizer state from an up-to-date peer

    :param opt: a pytorch optimizer to be averaged between peers (complete with model parameters)
    :param average_parameters: whether or not to average model parameters in self.step(...)
    :param average_gradients: whether or not to average model gradients in self.step(...)
    :param average_opt_statistics: if specified, average optimizer statistics with corresponding names in its state_dict
    :param initialize_optimizer: if True, run a speculative optimizer step with zero gradients to initialize all
      tensors. If False, please initialize the optimizer state manually.
    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
    :note: you can use extra_tensors for averaging tensors that are updated outside of opt.step (e.g. batchnorm stats)
    :param kwargs: any additional parameters will be forwarded to DecentralizedAverager
    """

    def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
                 average_opt_statistics: Sequence[str] = (), extra_tensors: Sequence[torch.Tensor] = (),
                 initialize_optimizer: bool = True, compression_type=None, **kwargs):
        self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
        self.opt_statistics = tuple(average_opt_statistics)
        self.average_parameters, self.average_gradients = average_parameters, average_gradients
        self.step_executor = ThreadPoolExecutor(max_workers=1)
        self.lock_averager_step = Lock()
        if initialize_optimizer:
            initialize_optimizer_state(opt)  # note: this will run one optimizer step!

        with torch.no_grad():
            averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()]

        if compression_type is None:
            # default heuristic: compress small tensors to float16 and large tensors to uniform 8-bit quantization;
            # this heuristic assumes that we average exactly the parameters and gradients (no optimizer statistics)
            assert average_parameters and average_gradients and not average_opt_statistics
            compression_type = [CompressionType.FLOAT16 if tensor.numel() <= 2 ** 16 else CompressionType.UNIFORM_8BIT
                                for tensor in averaged_tensors]
            for tensor in averaged_tensors:
                logger.debug(f"Compression for {tuple(tensor.shape)}: "
                             f"{'FLOAT16' if tensor.numel() <= 2 ** 16 else 'UNIFORM_8BIT'}")

        super().__init__(averaged_tensors=averaged_tensors, compression_type=compression_type, **kwargs)

    def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs):
        """
        Average optimizer weights and gradients with peers.

        :param data_lock: a lock that the averager acquires whenever it reads or writes model tensors;
          if not provided, it is assumed that no model modifications occur during the averaging step
        """
        if not wait:
            return self.step_executor.submit(self.step, data_lock, wait=True, **kwargs)

        # if data_lock is supplied, tensors might change during averaging, so we need to copy them
        use_old_local_tensors = data_lock is not None
        if data_lock is None:
            data_lock = nullcontext()

        local_tensors = list(self.local_tensors())
        with self.lock_averager_step, torch.no_grad():
            # fill the averager's tensors with the current local tensors
            with data_lock, self.get_tensors() as averaged_tensors:
                if use_old_local_tensors:
                    old_local_tensors = tuple(x.cpu().float().clone() for x in local_tensors)
                assert len(local_tensors) == len(averaged_tensors), \
                    "The number of optimized parameters should not change."
                for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
                    averaged_tensor[...] = local_tensor.cpu().float()

            # find a group and (hopefully) average tensors with peers, using peer weights (e.g. batch sizes)
            gathered = super().step(**kwargs)
            if gathered is not None:
                # load averaged tensors back into the model
                with data_lock, self.get_tensors() as averaged_tensors:
                    if len(averaged_tensors) != len(local_tensors):
                        raise RuntimeError("The number of optimized parameters should not change.")
                    if use_old_local_tensors:
                        # since local tensors might have changed during averaging, we subtract the old local copy
                        # and add the averaged result; this prevents losing local updates that might have
                        # occurred while the averaging round was in flight
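                        # (worked example, for illustration only: if a weight was 1.0 when copied into the
                        # averager, drifted to 1.2 locally while averaging ran, and the group average of the
                        # old values came back as 0.8, the new value is 1.2 + (0.8 - 1.0) = 1.0, i.e. the
                        # local drift of +0.2 is kept on top of the average 0.8)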
                        for averaged_tensor, local_tensor, old_local_tensor in zip(
                                averaged_tensors, local_tensors, old_local_tensors):
                            averaged_tensor = averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
                            old_local_tensor = old_local_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
                            local_tensor[...] += averaged_tensor - old_local_tensor
                    else:
                        for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
                            local_tensor[...] = averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)

            self.local_step += 1
            return gathered

    def local_tensors(self, replace_none: bool = True) -> Iterator[torch.Tensor]:
        """
        Iterate over the local trainer's tensors that should be averaged with peers.

        :param replace_none: if True and average_gradients is True, None gradients will be replaced with zero tensors.
          Otherwise, such gradients will be skipped. (this may cause inconsistencies with averaged_tensors)
        """
        if self.average_parameters:
            for param_group in self.opt.param_groups:
                yield from param_group['params']
        if self.average_gradients:
            for param_group in self.opt.param_groups:
                for param in param_group['params']:
                    if param.grad is not None:
                        yield param.grad
                    elif replace_none:
                        yield torch.zeros_like(param)
        for stats in self.opt_statistics:
            for param_group in self.opt.param_groups:
                for param in param_group['params']:
                    yield self.opt.state[param][stats]
        yield from iter(self.extra_tensors)
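
    # note: the order in which local_tensors() yields tensors defines the layout of the averaged_tensors created
    # in __init__ and copied in step(); it must stay consistent across peers for averaging to be meaningful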

    def get_current_state(self):
        """
        Get the current model/optimizer state when it is requested by a newly joined peer. Executed in the host process.

        :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
        """
        with torch.no_grad():
            optimized_parameters = tuple(param.detach().cpu() for param_group in self.opt.param_groups
                                         for param in param_group['params'])
            extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
            optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.opt)

        metadata = dict(step=self.local_step, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata)
        return metadata, list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
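
    # note: load_state_from_peers() below relies on the exact layout produced above:
    # optimized parameters first, then extra tensors, then the flattened optimizer tensors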

    def load_state_from_peers(self, **kwargs):
        """
        Attempt to download the latest optimizer state from peers and update the trainer's parameters/statistics.

        :returns: whether or not the averager succeeded in loading the parameters
        """
        parameters_and_extras = [param for param_group in self.opt.param_groups for param in param_group['params']]
        parameters_and_extras.extend(self.extra_tensors)
        num_local_tensors = len(parameters_and_extras)

        loaded_state = super().load_state_from_peers(**kwargs)
        if loaded_state is None:
            return

        metadata, flat_tensors = loaded_state
        loaded_parameters_and_extras = flat_tensors[:num_local_tensors]
        loaded_opt_tensors = flat_tensors[num_local_tensors:]

        with torch.no_grad():
            for local_param, loaded_param in zip(parameters_and_extras, loaded_parameters_and_extras):
                local_param[...] = loaded_param
            load_optimizer_state(self.opt, metadata['optimizer_metadata'], loaded_opt_tensors)
        self.local_step = max(self.local_step, metadata['step'])


def initialize_optimizer_state(opt: torch.optim.Optimizer):
    """ Initialize optimizer statistics by running a speculative optimizer step with zero gradients """
    for param_group in opt.param_groups:
        for param in param_group['params']:
            if param.grad is None:
                (0 * param.sum()).backward()  # accumulate a zero gradient so that opt.step() creates its statistics
    opt.step()


def dump_optimizer_state(opt: torch.optim.Optimizer):
    """ Convert optimizer state into the format used by DecentralizedAverager's get_current_state/load_state_from_peers """
    with torch.no_grad():
        flat_metadata, flat_tensors = [], []
        for elem in nested_flatten(opt.state_dict()):
            if isinstance(elem, torch.Tensor):
                flat_metadata.append(dict(type='tensor', index=len(flat_tensors)))
                flat_tensors.append(elem.cpu())
            else:
                flat_metadata.append(dict(type='value', value=elem))
        return flat_metadata, flat_tensors


def load_optimizer_state(optimizer: torch.optim.Optimizer, flat_metadata: Dict, flat_tensors: Sequence[torch.Tensor]):
    """ Restore optimizer state from the (flat_metadata, flat_tensors) format produced by dump_optimizer_state """
    flat_optimizer_state = []
    for elem in flat_metadata:
        if elem.get('type') == 'tensor' and isinstance(elem.get('index'), int):
            flat_optimizer_state.append(flat_tensors[elem['index']])
        elif elem.get('type') == 'value' and 'value' in elem:
            flat_optimizer_state.append(elem['value'])
    with torch.no_grad():
        try:
            return optimizer.load_state_dict(nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
        except StopIteration:
            # if the packed state does not match the local optimizer's structure, keep the local state as-is
            return optimizer
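

# A minimal, self-contained round-trip sketch for the serialization helpers above (illustrative only, not part
# of the library API): it builds a toy model and optimizer, initializes the optimizer state with a zero-gradient
# step, and checks that dump_optimizer_state / load_optimizer_state can serialize and restore that state.
if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    initialize_optimizer_state(opt)  # populates optimizer statistics (e.g. Adam's exp_avg) via a zero-gradient step

    metadata, tensors = dump_optimizer_state(opt)
    load_optimizer_state(opt, metadata, tensors)
    logger.info(f"Round-tripped optimizer state: {len(tensors)} tensors, {len(metadata)} metadata entries")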