offload.py

import contextlib
import multiprocessing as mp
from typing import Dict, Iterable, Optional, Type, Union

import torch

from .wrapper import OptimizerWrapper


class OffloadOptimizer(OptimizerWrapper):
    r"""A wrapper that stores optimizer statistics and performs updates on the offloaded device (e.g. CPU RAM).
    A usage sketch is provided at the end of this file."""

    def __init__(
            self, param_groups: Union[Iterable[torch.nn.Parameter], Iterable[Dict]],
            optim_cls: Type[torch.optim.Optimizer], *args, full_sync: bool = True,
            offload_device=torch.device('cpu'), offload_dtype: Optional[torch.dtype] = None, **kwargs):
        param_groups = list(param_groups)
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]
        super().__init__(optim_cls(param_groups, *args, **kwargs))
        self.full_sync = full_sync
        self.lock = mp.Lock()

        # Allocate an offloaded replica of every parameter (plus a gradient buffer)
        # on the offload device and copy the current values into it.
        with torch.no_grad():
            self.offload_params_by_group = tuple(
                [torch.nn.Parameter(torch.empty_like(param, device=offload_device, dtype=offload_dtype),
                                    requires_grad=param.requires_grad)
                 for param in group["params"]] for group in param_groups)

            for group, offload_params in zip(param_groups, self.offload_params_by_group):
                for param, offload_param in zip(group['params'], offload_params):
                    offload_param.copy_(param, non_blocking=True)
                    if offload_param.grad is None:
                        offload_param.grad = torch.zeros_like(offload_param)
                    if param.grad is not None:
                        offload_param.grad.copy_(param.grad, non_blocking=True)

    @contextlib.contextmanager
    def _use_offloaded_params(self, *,
                              sync_params_before: bool, sync_grads_before: bool,
                              sync_params_after: bool, sync_grads_after: bool):
        """Temporarily swap the wrapped optimizer's param_groups for their offloaded replicas."""
        assert len(self.param_groups) == len(self.offload_params_by_group)
        original_params_per_group = [group["params"] for group in self.param_groups]
        with self.lock:
            try:
                # Optionally push the latest parameters and/or gradients into the offloaded replicas.
                with torch.no_grad():
                    for original_params, replacement_params in zip(original_params_per_group, self.offload_params_by_group):
                        for original_param, replacement_param in zip(original_params, replacement_params):
                            if sync_params_before:
                                replacement_param.copy_(original_param, non_blocking=True)
                            if sync_grads_before and original_param.grad is not None:
                                replacement_param.grad.copy_(original_param.grad, non_blocking=True)

                for group, replacement_params in zip(self.param_groups, self.offload_params_by_group):
                    group["params"] = replacement_params
                yield self.param_groups
            finally:
                # Restore the original parameters and, if requested, pull the updated
                # values and/or gradients back from the offloaded replicas.
                for group, original_params in zip(self.param_groups, original_params_per_group):
                    group["params"] = original_params

                with torch.no_grad():
                    for original_params, replacement_params in zip(original_params_per_group, self.offload_params_by_group):
                        for original_param, replacement_param in zip(original_params, replacement_params):
                            if sync_params_after:
                                original_param.copy_(replacement_param, non_blocking=True)
                            if sync_grads_after and original_param.grad is not None:
                                original_param.grad.copy_(replacement_param.grad)

    def add_param_group(self, param_group: dict) -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not support add_param_group.")

    def step(self, closure=None, *args, **kwargs):
        assert closure is None, "closure not supported in cpu offload mode"
        # Gradients are always pushed to the offloaded replicas before the inner step,
        # and the updated parameters are always copied back afterwards.
        with self._use_offloaded_params(sync_params_before=self.full_sync, sync_grads_before=True,
                                        sync_params_after=True, sync_grads_after=self.full_sync):
            return self.optim.step(*args, **kwargs)

    def zero_grad(self, set_to_none: bool = False, *args, **kwargs):
        # Zero the gradients of the offloaded replicas; without full_sync, the on-device
        # gradients are also zeroed directly.
        if not self.full_sync:
            torch.optim.Optimizer.zero_grad(self, set_to_none)
        with self._use_offloaded_params(sync_params_before=self.full_sync, sync_grads_before=self.full_sync,
                                        sync_params_after=self.full_sync, sync_grads_after=self.full_sync):
            return super().zero_grad(*args, set_to_none=False, **kwargs)

    def state_dict(self):
        # Report the optimizer state as seen by the offloaded copy.
        with self._use_offloaded_params(sync_params_before=self.full_sync, sync_grads_before=self.full_sync,
                                        sync_params_after=False, sync_grads_after=False):
            return self.optim.state_dict()

    def load_state_dict(self, state_dict: dict) -> None:
        # Load state into the offloaded optimizer, then push the restored parameters back to the originals.
        with self._use_offloaded_params(sync_params_before=False, sync_grads_before=False,
                                        sync_params_after=True, sync_grads_after=self.full_sync):
            return self.optim.load_state_dict(state_dict)
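

# ----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It shows the
# intended call pattern: wrap a regular torch optimizer class so that its
# statistics live in CPU RAM while the model's parameters stay on their own
# device. The toy model, batch size and learning rate below are arbitrary
# placeholders; only the OffloadOptimizer interface comes from the class above.
# Because this file uses a relative import, run the snippet from code that
# imports OffloadOptimizer from its package rather than executing this module
# directly.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(1024, 1024).to(device)

    # Adam's moment buffers end up on the CPU replicas instead of `device`.
    opt = OffloadOptimizer(model.parameters(), optim_cls=torch.optim.Adam, lr=1e-3,
                           full_sync=True, offload_device=torch.device('cpu'))

    for _ in range(3):
        opt.zero_grad()
        loss = model(torch.randn(16, 1024, device=device)).pow(2).mean()
        loss.backward()
        # step() pushes the gradients to the offloaded replicas, runs Adam there,
        # and copies the updated parameters back into `model`.
        opt.step()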