# tpu.py

import ctypes
import threading
from functools import partial
from contextlib import nullcontext
from copy import deepcopy
import multiprocessing as mp
from itertools import zip_longest
from typing import Iterable

import torch
import torch.nn as nn
import torch.utils.data
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


class TPUManager(mp.Process):
    """Auxiliary class that manages model training over an array of TPU cores"""

    def __init__(self,
                 model,
                 dataset,
                 *,
                 collate_fn: callable = None,
                 nprocs: int = 8,
                 prefetch: int = 16,
                 batch_size_per_device: int = 1,
                 grad_accumulation_steps: int = 1,
                 seed_base: int = 42,
                 start: bool):
        super().__init__()
        self.lock = mp.Lock()
        self.nprocs, self.prefetch, self.seed_base = nprocs, prefetch, seed_base
        self.batch_size_per_device, self.grad_accumulation_steps = batch_size_per_device, grad_accumulation_steps
        self.collate_fn = collate_fn
        self.step_triggered, self.step_finished = mp.Event(), mp.Event()
        self._synchronizer = TPUSynchronizer(model)
        self._data_manager = TPUDataManager(dataset, nprocs, prefetch)

        # shared fields for communicating statistics after each step
        self.should_load_parameters = mp.Value(ctypes.c_bool, False)
        self.gradients_accumulated = mp.Value(ctypes.c_long, 0)
        self.loss_accumulated = mp.Value(ctypes.c_double, 0)

        if start:
            self.start()

    def run(self):
        thread = threading.Thread(
            target=partial(xmp.spawn, self.runner, nprocs=self.nprocs, start_method='fork'),
            daemon=True)
        thread.start()
        thread.join()
    def update_model_parameters(self, new_host_parameters):
        """Schedule TPUs to update model parameters at the beginning of the next step"""
        with self.lock, torch.no_grad():
            self._synchronizer.set_host_parameters(new_host_parameters)
            self.should_load_parameters.value = True

    def get_aggregated_gradients(self):
        """Get current accumulated gradients from the master model"""
        with self.lock, torch.no_grad():
            return self._synchronizer.get_aggregated_gradients()

    def zero_grad(self):
        """Reset master accumulated gradients to zeros"""
        with self.lock, torch.no_grad():
            for param in self._synchronizer.master_model.parameters():
                param.grad.zero_()

    def step(self):
        """Run a forward/backward step on all TPUs and collect the aggregated gradients"""
        self.loss_accumulated.value = self.gradients_accumulated.value = 0
        self.step_finished.clear()
        self.step_triggered.set()
        self.step_finished.wait()
        return self.loss_accumulated.value, self.gradients_accumulated.value
    def runner(self, tpu_index):
        """Run training steps from the perspective of a single TPU core"""
        # acquire the (unique) Cloud TPU core corresponding to this process's index
        device = xm.xla_device()
        logger.info(f"Process {tpu_index} is using {xm.xla_real_devices([str(device)])[0]}")

        # give each TPU core a distinct random seed
        torch.manual_seed(self.seed_base + tpu_index)

        # use staged init to minimize peak RAM usage
        for init_index in range(xm.xrt_world_size()):
            xm.rendezvous(f'init_{init_index}')
            if tpu_index == init_index:
                model = self._synchronizer.get_device_model_replica(device)
                data_loader = self._data_manager.get_device_dataloader(
                    batch_size=self.batch_size_per_device, num_workers=0, collate_fn=self.collate_fn, pin_memory=False)
                data_loader_iter = iter(data_loader)
                logger.info(f"Process {tpu_index} initialized.")
        xm.rendezvous('init_finished')

        while True:
            self.step_triggered.wait()
            xm.rendezvous('before_step')
            if xm.is_master_ordinal():
                self.step_triggered.clear()

            if bool(self.should_load_parameters.value):
                with self.lock if xm.is_master_ordinal() else nullcontext():
                    self._synchronizer.send_params_to_device(model)
                    self.should_load_parameters.value = False

            ### compute loss and gradients
            loss = 0.0
            for i in range(self.grad_accumulation_steps):
                inputs = next(data_loader_iter)
                outputs = model(**inputs)
                loss_i = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
                loss_i = loss_i / (self.grad_accumulation_steps * self.nprocs)
                loss_i.backward()
                loss += loss_i
                del inputs, outputs, loss_i

            ### aggregate gradients from TPUs
            with self.lock if xm.is_master_ordinal() else nullcontext():
                self._synchronizer.aggregate_grads_on_host(model, add=True)

            # clear aggregated gradients from all devices
            model.zero_grad()

            ### accumulate statistics to host
            loss = xm.all_reduce(xm.REDUCE_SUM, loss, scale=1.0)
            xm.do_on_ordinals(self._mark_step_finished, data=(loss,), ordinals=(0,))

    def _mark_step_finished(self, loss):
        self.gradients_accumulated.value = self.batch_size_per_device * self.nprocs * self.grad_accumulation_steps
        self.loss_accumulated.value = float(loss)
        self.step_finished.set()
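
# Usage sketch (illustrative, not part of the original module): a host-side training
# loop could drive TPUManager roughly like this. The names `model`, `dataset`,
# `collate_fn`, `optimizer` and `num_steps` are assumed to exist in the caller's code,
# with `optimizer` built over `model.parameters()`; since TPUSynchronizer shares the
# model's parameters and gradients via share_memory(), the aggregated gradients land
# directly in `model.parameters()`.
#
#     tpu_manager = TPUManager(model, dataset, collate_fn=collate_fn,
#                              nprocs=8, grad_accumulation_steps=4, start=True)
#     for _ in range(num_steps):
#         loss, samples = tpu_manager.step()    # forward/backward on all cores
#         optimizer.step()                      # consumes the shared, aggregated grads
#         tpu_manager.zero_grad()               # reset the shared gradient accumulators
#         tpu_manager.update_model_parameters(model.parameters())  # schedule TPUs to reload weights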


class TPUSynchronizer:
    """An auxiliary class for manipulating parameters and gradients without producing a ton of XLA graphs"""

    def __init__(self, model: nn.Module):
        self.master_model = model.share_memory()
        for param in self.master_model.parameters():
            if param.grad is None:
                param.grad = torch.zeros_like(param)
            param.grad = param.grad.share_memory_()

    def get_device_model_replica(self, device: torch.device, tie_weights: bool = True):
        replica = deepcopy(self.master_model).to(device)
        if tie_weights:
            replica.tie_weights()
        for param in replica.parameters():
            param.grad = torch.zeros_like(param, device=device)
        return replica
    def set_host_parameters(self, new_host_parameters):
        return self._assign(source=new_host_parameters, target=self.master_model.parameters(), add=False, strict=True)
    def get_aggregated_gradients(self):
        return [param.grad for param in self.master_model.parameters()]

    def send_params_to_device(self, replica: nn.Module):
        """Copy params from master_model to this device_model replica"""
        with torch.no_grad():
            replica_params = list(replica.parameters())
            master_params = list(self.master_model.parameters())
            master_params = xm.send_cpu_data_to_device(master_params, xm.xla_device())
            self._assign(source=master_params, target=replica_params, add=False)
            xm.rendezvous("params_replicated")

    def aggregate_grads_on_host(self, replica: nn.Module, *, add: bool):
        """Aggregate grads from all TPU devices and move them to the host"""
        with torch.no_grad():
            replica_grads = [param.grad for param in replica.parameters()]
            replica_grads = xm.all_reduce(xm.REDUCE_SUM, replica_grads, scale=1.0)
            master_grads = [hp.grad for hp in self.master_model.parameters()]
            xm.do_on_ordinals(lambda *replica_grads: self._assign(source=replica_grads, target=master_grads, add=add),
                              data=tuple(replica_grads), ordinals=(0,))
            # ^-- do_on_ordinals already runs rendezvous at the end
    def _assign(self, source: Iterable[torch.Tensor], target: Iterable[torch.Tensor], add: bool, strict: bool = False):
        for source_tensor, target_tensor in zip_longest(source, target):
            assert source_tensor is not None and target_tensor is not None, "Source and target length must match exactly"
            if strict:
                assert source_tensor.shape == target_tensor.shape
                assert source_tensor.device == target_tensor.device
                assert source_tensor.dtype == target_tensor.dtype
            if add:
                target_tensor.add_(source_tensor)
            else:
                target_tensor.copy_(source_tensor)
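
# Illustrative sketch of using TPUSynchronizer on its own (assumptions: the code runs
# inside an XLA process, and `model` exposes a `tie_weights()` method in the style of
# HuggingFace transformers, which get_device_model_replica calls by default):
#
#     sync = TPUSynchronizer(model)                       # host-side, shared-memory master copy
#     replica = sync.get_device_model_replica(xm.xla_device())
#     # ... run forward/backward on `replica` to populate its .grad fields ...
#     sync.aggregate_grads_on_host(replica, add=True)     # summed grads land in model.parameters()
#     sync.send_params_to_device(replica)                 # push updated host weights back to the device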


class TPUDataManager:
    """An auxiliary class that loads a centralized dataset from the master process into multiple TPU devices"""

    def __init__(self, dataset: torch.utils.data.Dataset, nprocs: int, master_prefetch: int = 16):
        self.dataset, self.nprocs = dataset, nprocs
        self.device_queues = [mp.Queue(master_prefetch) for _ in range(nprocs)]
        self._loader_thread = threading.Thread(target=self._load_data_into_queues)
        self._loader_thread.start()

    def _load_data_into_queues(self):
        try:
            for i, batch in enumerate(self.dataset):
                self.device_queues[i % self.nprocs].put(batch)
        finally:
            logger.warning("Minibatch generator finished.")

    def get_device_dataloader(self, **kwargs):
        data_loader = torch.utils.data.DataLoader(QueueDataset(self.device_queues[xm.get_ordinal()]), **kwargs)
        return pl.ParallelLoader(data_loader, [xm.xla_device()]).per_device_loader(xm.xla_device())


class QueueDataset(torch.utils.data.IterableDataset):
    """A dataset that ceaselessly iterates over a queue"""

    def __init__(self, queue: mp.Queue):
        super().__init__()
        self.queue = queue

    def __iter__(self):
        while True:
            yield self.queue.get()

    def __len__(self):
        return 10 ** 12  # TODO deprecate this when the issue is resolved: https://github.com/googlecolab/colabtools/issues/2237
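
# Illustrative sketch (not executed on import): QueueDataset simply re-yields whatever
# the host puts into its queue, so it can be exercised without any TPUs:
#
#     queue = mp.Queue()
#     for i in range(3):
#         queue.put({"x": torch.tensor([i])})
#     batches = iter(QueueDataset(queue))
#     print(next(batches))   # {'x': tensor([0])}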