import ctypes
import threading
from functools import partial
from contextlib import nullcontext
from copy import deepcopy
import multiprocessing as mp
from itertools import zip_longest
from typing import Callable, Iterable, Optional
import torch
import torch.nn as nn
import torch.utils.data
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


class TPUManager(mp.Process):
    """Auxiliary class that manages model training over an array of TPU cores"""

    def __init__(self,
                 model,
                 dataset,
                 *,
                 collate_fn: Optional[Callable] = None,
                 nprocs: int = 8,
                 prefetch: int = 16,
                 batch_size_per_device: int = 1,
                 grad_accumulation_steps: int = 1,
                 seed_base: int = 42,
                 start: bool):
        super().__init__()
        self.lock = mp.Lock()
        self.nprocs, self.prefetch, self.seed_base = nprocs, prefetch, seed_base
        self.batch_size_per_device, self.grad_accumulation_steps = batch_size_per_device, grad_accumulation_steps
        self.collate_fn = collate_fn
        self.step_triggered, self.step_finished = mp.Event(), mp.Event()
        self._synchronizer = TPUSynchronizer(model)
        self._data_manager = TPUDataManager(dataset, nprocs, prefetch)

        # shared fields for communicating statistics after each step
        self.should_load_parameters = mp.Value(ctypes.c_bool, False)
        self.gradients_accumulated = mp.Value(ctypes.c_long, 0)
        self.loss_accumulated = mp.Value(ctypes.c_double, 0)

        if start:
            self.start()

    def run(self):
        thread = threading.Thread(
            target=partial(xmp.spawn, self.runner, nprocs=self.nprocs, start_method='fork'),
            daemon=True)
        thread.start()
        thread.join()

    def update_model_parameters(self, new_host_parameters):
        """Schedule TPUs to load new model parameters at the beginning of the next step"""
        with self.lock, torch.no_grad():
            self._synchronizer.set_host_parameters(new_host_parameters)
            self.should_load_parameters.value = True

    def get_aggregated_gradients(self):
        """Get current accumulated gradients from the master model"""
        with self.lock, torch.no_grad():
            return self._synchronizer.get_aggregated_gradients()

    def zero_grad(self):
        """Reset the master model's accumulated gradients to zero"""
        with self.lock, torch.no_grad():
            for param in self._synchronizer.master_model.parameters():
                param.grad.zero_()

    def step(self):
        """Run a forward/backward step on all TPU cores and collect the resulting gradients"""
        self.loss_accumulated.value = self.gradients_accumulated.value = 0
        self.step_finished.clear()
        self.step_triggered.set()
        self.step_finished.wait()
        return self.loss_accumulated.value, self.gradients_accumulated.value

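    # How step() and runner() interact: step() runs in the host process; it resets the
    # shared counters, sets step_triggered and blocks on step_finished. Each TPU worker
    # in runner() wakes up on step_triggered, synchronizes via xm.rendezvous, accumulates
    # gradients over grad_accumulation_steps batches, all-reduces the loss and gradients,
    # and ordinal 0 writes the totals into shared memory and sets step_finished.
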
    def runner(self, tpu_index):
        """Run training steps from the perspective of a single TPU core"""
        # acquire the (unique) Cloud TPU core corresponding to this process's index
        device = xm.xla_device()
        logger.info(f"Process {tpu_index} is using {xm.xla_real_devices([str(device)])[0]}")

        # use a different random seed for each TPU process
        torch.manual_seed(self.seed_base + tpu_index)

        # use staged init to minimize peak RAM usage
        for init_index in range(xm.xrt_world_size()):
            xm.rendezvous(f'init_{init_index}')
            if tpu_index == init_index:
                model = self._synchronizer.get_device_model_replica(device)
                data_loader = self._data_manager.get_device_dataloader(
                    batch_size=self.batch_size_per_device, num_workers=0, collate_fn=self.collate_fn, pin_memory=False)
                data_loader_iter = iter(data_loader)
                logger.info(f"Process {tpu_index} initialized.")
        xm.rendezvous('init_finished')

        while True:
            self.step_triggered.wait()
            xm.rendezvous('before_step')
            if xm.is_master_ordinal():
                self.step_triggered.clear()

            if bool(self.should_load_parameters.value):
                with self.lock if xm.is_master_ordinal() else nullcontext():
                    self._synchronizer.send_params_to_device(model)
                    self.should_load_parameters.value = False

            ### compute loss and gradients
            loss = 0.0
            for i in range(self.grad_accumulation_steps):
                inputs = next(data_loader_iter)
                outputs = model(**inputs)
                loss_i = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
                loss_i = loss_i / (self.grad_accumulation_steps * self.nprocs)
                loss_i.backward()
                loss += loss_i
                del inputs, outputs, loss_i

            ### aggregate gradients from TPUs
            with self.lock if xm.is_master_ordinal() else nullcontext():
                self._synchronizer.aggregate_grads_on_host(model, add=True)

            # clear aggregated gradients from all devices
            model.zero_grad()

            ### accumulate statistics to host
            loss = xm.all_reduce(xm.REDUCE_SUM, loss, scale=1.0)
            xm.do_on_ordinals(self._mark_step_finished, data=(loss,), ordinals=(0,))

    def _mark_step_finished(self, loss):
        self.gradients_accumulated.value = self.batch_size_per_device * self.nprocs * self.grad_accumulation_steps
        self.loss_accumulated.value = float(loss)
        self.step_finished.set()


class TPUSynchronizer:
    """An auxiliary class for manipulating parameters and gradients without producing a ton of XLA graphs"""

    def __init__(self, model: nn.Module):
        self.master_model = model.share_memory()
        for param in self.master_model.parameters():
            if param.grad is None:
                param.grad = torch.zeros_like(param)
            param.grad = param.grad.share_memory_()

    def get_device_model_replica(self, device: torch.device, tie_weights: bool = True):
        replica = deepcopy(self.master_model).to(device)
        if tie_weights:
            replica.tie_weights()
        for param in replica.parameters():
            param.grad = torch.zeros_like(param, device=device)
        return replica

    def set_host_parameters(self, new_host_parameters):
        """Copy new_host_parameters into the shared master model"""
        return self._assign(source=new_host_parameters, target=self.master_model.parameters(), add=False, strict=True)

    def get_aggregated_gradients(self):
        return [param.grad for param in self.master_model.parameters()]

    def send_params_to_device(self, replica: nn.Module):
        """Copy params from master_model to this device_model replica"""
        with torch.no_grad():
            replica_params = list(replica.parameters())
            master_params = list(self.master_model.parameters())
            master_params = xm.send_cpu_data_to_device(master_params, xm.xla_device())
            self._assign(source=master_params, target=replica_params, add=False)
            xm.rendezvous("params_replicated")

    def aggregate_grads_on_host(self, replica: nn.Module, *, add: bool):
        """Aggregate gradients from all TPU devices and move them to the host"""
        with torch.no_grad():
            replica_grads = [param.grad for param in replica.parameters()]
            replica_grads = xm.all_reduce(xm.REDUCE_SUM, replica_grads, scale=1.0)
            master_grads = [hp.grad for hp in self.master_model.parameters()]
            xm.do_on_ordinals(lambda *replica_grads: self._assign(source=replica_grads, target=master_grads, add=add),
                              data=tuple(replica_grads), ordinals=(0,))
            # ^-- do_on_ordinals already runs rendezvous at the end

    def _assign(self, source: Iterable[torch.Tensor], target: Iterable[torch.Tensor], add: bool, strict: bool = False):
        for source_tensor, target_tensor in zip_longest(source, target):
            assert source_tensor is not None and target_tensor is not None, "Source and target length must match exactly"
            if strict:
                assert source_tensor.shape == target_tensor.shape
                assert source_tensor.device == target_tensor.device
                assert source_tensor.dtype == target_tensor.dtype
            if add:
                target_tensor.add_(source_tensor)
            else:
                target_tensor.copy_(source_tensor)


class TPUDataManager:
    """An auxiliary class that loads a centralized dataset from the master process into multiple TPU devices"""

    def __init__(self, dataset: torch.utils.data.Dataset, nprocs: int, master_prefetch: int = 16):
        self.dataset, self.nprocs = dataset, nprocs
        self.device_queues = [mp.Queue(master_prefetch) for _ in range(nprocs)]
        self._loader_thread = threading.Thread(target=self._load_data_into_queues)
        self._loader_thread.start()

    def _load_data_into_queues(self):
        try:
            for i, batch in enumerate(self.dataset):
                self.device_queues[i % self.nprocs].put(batch)
        finally:
            logger.warning("Minibatch generator finished.")

    def get_device_dataloader(self, **kwargs):
        data_loader = torch.utils.data.DataLoader(QueueDataset(self.device_queues[xm.get_ordinal()]), **kwargs)
        return pl.ParallelLoader(data_loader, [xm.xla_device()]).per_device_loader(xm.xla_device())


class QueueDataset(torch.utils.data.IterableDataset):
    """A dataset that ceaselessly iterates over a queue"""

    def __init__(self, queue: mp.Queue):
        super().__init__()
        self.queue = queue

    def __iter__(self):
        while True:
            yield self.queue.get()

    def __len__(self):
        return 10 ** 12  # TODO deprecate this when the issue is resolved: https://github.com/googlecolab/colabtools/issues/2237
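

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original module): drives a few TPUManager
    # steps from the host process. It assumes an available TPU runtime with at least
    # `nprocs` cores; ToyModel and ToyDataset are hypothetical stand-ins for a real model
    # (e.g. a HuggingFace transformers model whose forward() returns a dict with a "loss"
    # entry) and a tokenized dataset. The exact host-side loop in a real project may differ.

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(16, 1)

        def tie_weights(self):
            pass  # no-op; real transformers models tie input/output embeddings here

        def forward(self, x, target):
            loss = torch.nn.functional.mse_loss(self.linear(x).squeeze(-1), target)
            return {"loss": loss}

    class ToyDataset(torch.utils.data.IterableDataset):
        def __iter__(self):
            while True:
                yield {"x": torch.randn(16), "target": torch.randn(())}

    model = ToyModel()
    manager = TPUManager(model, ToyDataset(), nprocs=8, batch_size_per_device=4,
                         grad_accumulation_steps=1, start=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for _ in range(3):
        loss, num_samples = manager.step()
        # the aggregated gradients now live in the shared master model (which is `model`
        # itself, since TPUSynchronizer shares its memory), so the optimizer can step on it
        optimizer.step()
        manager.zero_grad()
        # push the updated weights to every TPU replica at the start of the next step
        manager.update_model_parameters(model.parameters())
        print(f"loss: {loss:.3f}, samples accumulated: {num_samples}")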