hf_trainer.py

  1. """A catch-all module for the dirty hacks required to make HF Trainer work with collaborative training"""
  2. import torch
  3. from torch import nn
  4. from torch.utils.data import DataLoader
  5. from transformers.trainer import Trainer
  6. from hivemind import CollaborativeOptimizer
  7. from hivemind.optim import HivemindGradScaler
  8. from hivemind.utils.logging import get_logger, use_hivemind_log_handler
  9. use_hivemind_log_handler("in_root_logger")
  10. logger = get_logger()
  11. LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)


class CollaborativeHFTrainer(Trainer):
    """
    A version of HuggingFace trainer that shuffles the dataset using a separate random seed.
    Used to ensure that peers don't process batches in the same order.
    """

    def __init__(self, *, data_seed: int, collaborative_optimizer: CollaborativeOptimizer, **kwargs):
        self.data_seed = data_seed
        self.collaborative_optimizer = collaborative_optimizer
        super().__init__(optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)), **kwargs)

        if self.fp16_backend is not None:
            assert self.use_amp
            self.scaler = HivemindGradScaler()

    def get_train_dataloader(self) -> DataLoader:
        """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
        torch.manual_seed(self.data_seed)
        return super().get_train_dataloader()

    def _wrap_model(self, model, training=True):
        # if reuse_grad_buffers is True, we should accumulate gradients in .grad without zeroing them after each step
        return IgnoreGradManipulations(
            super()._wrap_model(model, training=training),
            override_zero_grad=self.collaborative_optimizer.reuse_grad_buffers,
        )
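

# Hypothetical usage sketch (not part of the original module). It assumes a model,
# TrainingArguments, a tokenized dataset and a fully configured hivemind
# CollaborativeOptimizer are prepared elsewhere; only `data_seed` and
# `collaborative_optimizer` are specific to CollaborativeHFTrainer, the remaining
# keyword arguments are standard transformers.Trainer arguments.
def _example_collaborative_training(model, training_args, train_dataset, collaborative_optimizer, peer_index=0):
    trainer = CollaborativeHFTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_seed=42 + peer_index,  # placeholder: any seed works as long as it differs across peers
        collaborative_optimizer=collaborative_optimizer,
    )
    trainer.train()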


class NoOpScheduler(LRSchedulerBase):
    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""

    def get_lr(self):
        return [group['lr'] for group in self.optimizer.param_groups]

    def print_lr(self, *args, **kwargs):
        if self.optimizer.scheduler:
            return self.optimizer.scheduler.print_lr(*args, **kwargs)

    def step(self):
        logger.debug("Called NoOpScheduler.step")
        self._last_lr = self.get_lr()

    def state_dict(self):
        return {}

    def load_state_dict(self, *args, **kwargs):
        logger.debug("Called NoOpScheduler.load_state_dict")
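

# Note (illustration, not in the original file): transformers.Trainer steps its LR scheduler
# after every optimizer step, so NoOpScheduler only has to satisfy that contract while the
# real schedule lives inside CollaborativeOptimizer. A hypothetical interaction:
#   scheduler = NoOpScheduler(collaborative_optimizer)
#   scheduler.step()      # logs a debug message and refreshes _last_lr, nothing else
#   scheduler.get_lr()    # current per-group learning rates of the wrapped optimizer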


class IgnoreGradManipulations(nn.Module):
    """Wrapper for model that blocks gradient manipulations in huggingface Trainer (e.g. zero_grad, clip_grad)"""

    def __init__(self, module, override_clipping: bool = True, override_zero_grad: bool = True):
        super().__init__()
        self.module = module
        self.override_clipping = override_clipping
        self.override_zero_grad = override_zero_grad

    def forward(self, *args, **kwargs):
        return self.module.forward(*args, **kwargs)

    def zero_grad(self, set_to_none: bool = False) -> None:
        if self.override_zero_grad and all(
            param.grad.isfinite().all() for param in self.parameters() if param.requires_grad
        ):
            logger.debug("Successfully bypassed zero_grad")
        else:
            self.module.zero_grad(set_to_none=set_to_none)

    def clip_grad_norm_(self, max_norm: float, norm_type: int = 2):
        """ignore clip_grad_norm on each step, clip in optimizer instead"""
        if self.override_clipping:
            logger.debug("Successfully bypassed clip_grad_norm_")
        else:
            return torch.nn.utils.clip_grad_norm_(self.module.parameters(), max_norm, norm_type=norm_type)
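

if __name__ == "__main__":
    # Minimal self-check of IgnoreGradManipulations (illustration only, not in the original
    # file): with finite gradients, zero_grad() and clip_grad_norm_() are bypassed, so the
    # accumulated .grad buffers survive, which is what CollaborativeOptimizer with
    # reuse_grad_buffers=True relies on.
    demo = IgnoreGradManipulations(nn.Linear(4, 2))
    demo(torch.randn(3, 4)).sum().backward()
    grad_before = demo.module.weight.grad.clone()
    demo.zero_grad()                      # bypassed: gradients are left intact
    demo.clip_grad_norm_(max_norm=1.0)    # bypassed as well
    assert torch.equal(demo.module.weight.grad, grad_before)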