import contextlib
import multiprocessing as mp
from typing import Dict, Iterable, Optional, Type, Union

import torch

from .wrapper import OptimizerWrapper


class OffloadOptimizer(OptimizerWrapper):
- r""" A wrapper that stores optimizer statistics and performs updates on the offloaded device (e.g. CPU RAM). """

    def __init__(
        self, param_groups: Union[Iterable[torch.nn.Parameter], Iterable[Dict]],
        optim_cls: Type[torch.optim.Optimizer], *args, full_sync: bool = True,
        offload_device=torch.device("cpu"), offload_dtype: Optional[torch.dtype] = None, **kwargs,
    ):
        param_groups = list(param_groups)
        if not isinstance(param_groups[0], dict):
            param_groups = [{"params": param_groups}]
        super().__init__(optim_cls(param_groups, *args, **kwargs))
        self.full_sync = full_sync
        self.lock = mp.Lock()

        with torch.no_grad():
            # allocate offloaded copies of every parameter on the offload device
            self.offload_params_by_group = tuple(
                [
                    torch.nn.Parameter(torch.empty_like(param, device=offload_device, dtype=offload_dtype),
                                       requires_grad=param.requires_grad)
                    for param in group["params"]
                ]
                for group in param_groups
            )

            # initialize the offloaded parameters and gradients from the originals
            for group, offload_params in zip(param_groups, self.offload_params_by_group):
                for param, offload_param in zip(group["params"], offload_params):
                    offload_param.copy_(param, non_blocking=True)
                    if offload_param.grad is None:
                        offload_param.grad = torch.zeros_like(offload_param)
                    if param.grad is not None:
                        offload_param.grad.copy_(param.grad, non_blocking=True)

    @contextlib.contextmanager
    def _use_offloaded_params(
        self, *, sync_params_before: bool, sync_grads_before: bool, sync_params_after: bool, sync_grads_after: bool
    ):
        """Temporarily swap the optimizer's parameters for their offloaded copies, restoring the originals on exit."""
        assert len(self.param_groups) == len(self.offload_params_by_group)
        original_params_per_group = [group["params"] for group in self.param_groups]
        with self.lock:
            try:
                with torch.no_grad():
                    for original_params, replacement_params in zip(original_params_per_group, self.offload_params_by_group):
                        for original_param, replacement_param in zip(original_params, replacement_params):
                            if sync_params_before:
                                replacement_param.copy_(original_param, non_blocking=True)
                            if sync_grads_before and original_param.grad is not None:
                                replacement_param.grad.copy_(original_param.grad, non_blocking=True)

                # point each param group at the offloaded copies while the caller runs
                for group, replacement_params in zip(self.param_groups, self.offload_params_by_group):
                    group["params"] = replacement_params
                yield self.param_groups
            finally:
                # restore the original parameters and optionally copy the results back to them
                for group, original_params in zip(self.param_groups, original_params_per_group):
                    group["params"] = original_params

                with torch.no_grad():
                    for original_params, replacement_params in zip(original_params_per_group, self.offload_params_by_group):
                        for original_param, replacement_param in zip(original_params, replacement_params):
                            if sync_params_after:
                                original_param.copy_(replacement_param, non_blocking=True)
                            if sync_grads_after and original_param.grad is not None:
                                original_param.grad.copy_(replacement_param.grad)

    def add_param_group(self, param_group: dict) -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not support add_param_group.")

    def step(self, closure=None, *args, **kwargs):
        assert closure is None, "closure is not supported in CPU offload mode"
        # copy gradients to the offloaded copies, run the update there, then copy updated parameters back
        with self._use_offloaded_params(sync_params_before=self.full_sync, sync_grads_before=True,
                                        sync_params_after=True, sync_grads_after=self.full_sync):
            return self.optim.step(*args, **kwargs)

    def zero_grad(self, set_to_none: bool = False, *args, **kwargs):
        if not self.full_sync:
            # without full_sync, zeroed offloaded gradients are not copied back, so clear the originals explicitly
            torch.optim.Optimizer.zero_grad(self, set_to_none)
        with self._use_offloaded_params(sync_params_before=self.full_sync, sync_grads_before=self.full_sync,
                                        sync_params_after=self.full_sync, sync_grads_after=self.full_sync):
            return super().zero_grad(*args, set_to_none=False, **kwargs)

    def state_dict(self):
        with self._use_offloaded_params(sync_params_before=self.full_sync, sync_grads_before=self.full_sync,
                                        sync_params_after=False, sync_grads_after=False):
            return self.optim.state_dict()

    def load_state_dict(self, state_dict: dict) -> None:
        with self._use_offloaded_params(sync_params_before=False, sync_grads_before=False,
                                        sync_params_after=True, sync_grads_after=self.full_sync):
            return self.optim.load_state_dict(state_dict)
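

# Usage sketch (illustrative only, not part of the module API): wrap a regular torch
# optimizer so that its statistics and parameter updates live on the offload device
# while the model trains as usual. The toy model and hyperparameters below are
# placeholders chosen for the example. Note that this module uses a relative import,
# so the demo only runs when executed through the parent package.
if __name__ == "__main__":
    model = torch.nn.Linear(16, 4)
    opt = OffloadOptimizer(model.parameters(), torch.optim.Adam, lr=1e-3, full_sync=True)

    loss = model(torch.randn(8, 16)).square().mean()
    loss.backward()
    # step() copies gradients to the offloaded copies, runs Adam there, and copies the
    # updated parameters back into the model; zero_grad() then clears the gradients
    opt.step()
    opt.zero_grad()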