hf_trainer.py

"""A catch-all module for the dirty hacks required to make HF Trainer work with collaborative training"""
import json
import urllib.request

import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers.trainer import Trainer

from hivemind import CollaborativeOptimizer
from hivemind.optim import HivemindGradScaler
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

use_hivemind_log_handler("in_root_logger")
logger = get_logger()

LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)

URL_IP_INFO = "http://ipinfo.io/json"
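# Note (added commentary): the endpoint above returns a small JSON document describing the
# caller's public IP; only its "country" field is read in get_country_info() below, and it
# is used purely as optional run metadata.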


class CollaborativeHFTrainer(Trainer):
    """
    A version of the HuggingFace Trainer that shuffles the dataset using a separate random seed.
    Used to ensure that peers don't process batches in the same order.
    """

    def __init__(self, *, data_seed: int, collaborative_optimizer: CollaborativeOptimizer, **kwargs):
        self.data_seed = data_seed
        self.collaborative_optimizer = collaborative_optimizer

        args = kwargs["args"]
        setattr(args, "run_country", self.get_country_info())

        super().__init__(optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)), **kwargs)

        if self.fp16_backend is not None:
            assert self.use_amp
            self.scaler = HivemindGradScaler()

    def get_train_dataloader(self) -> DataLoader:
        """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
        torch.manual_seed(self.data_seed)
        return super().get_train_dataloader()

    def _wrap_model(self, model, training=True):
        # if reuse_grad_buffers is True, we should accumulate gradients in .grad without zeroing them after each step
        return IgnoreGradManipulations(
            super()._wrap_model(model, training=training),
            override_zero_grad=self.collaborative_optimizer.grad_averager.reuse_grad_buffers,
        )

    def get_country_info(self):
        # This is only a nice-to-have: if the request fails for any reason, fall back to an empty string
        try:
            response = urllib.request.urlopen(URL_IP_INFO)
            data = json.load(response)
            country = data["country"]
        except Exception:
            country = ""
        return country
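
# Usage sketch (added; not part of the original module). All names and keyword arguments
# below (model, training_args, train_dataset, dht, and the CollaborativeOptimizer
# parameters) are illustrative assumptions rather than values taken from this repo:
#
#   collaborative_optimizer = CollaborativeOptimizer(
#       opt=torch.optim.Adam(model.parameters(), lr=1e-3),
#       dht=dht,
#       prefix="my_collaborative_run",
#       target_batch_size=4096,
#       batch_size_per_step=training_args.per_device_train_batch_size,
#   )
#   trainer = CollaborativeHFTrainer(
#       data_seed=42,
#       collaborative_optimizer=collaborative_optimizer,
#       model=model,
#       args=training_args,
#       train_dataset=train_dataset,
#   )
#   trainer.train()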


class NoOpScheduler(LRSchedulerBase):
    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""

    def get_lr(self):
        return [group['lr'] for group in self.optimizer.param_groups]

    def print_lr(self, *args, **kwargs):
        if self.optimizer.scheduler:
            return self.optimizer.scheduler.print_lr(*args, **kwargs)

    def step(self):
        logger.debug("Called NoOpScheduler.step")
        self._last_lr = self.get_lr()

    def state_dict(self):
        return {}

    def load_state_dict(self, *args, **kwargs):
        logger.debug("Called NoOpScheduler.load_state_dict")
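
# Note (added commentary): transformers.Trainer calls lr_scheduler.step() after every optimizer
# step and stores lr_scheduler.state_dict() in checkpoints. Since the real learning-rate schedule
# lives inside CollaborativeOptimizer (exposed here as self.optimizer.scheduler), NoOpScheduler
# only records the current LR and otherwise does nothing, so the schedule is never advanced twice.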


class IgnoreGradManipulations(nn.Module):
    """Wrapper for model that blocks gradient manipulations in huggingface Trainer (e.g. zero_grad, clip_grad)"""

    def __init__(self, module, override_clipping: bool = True, override_zero_grad: bool = True):
        super().__init__()
        self.module = module
        self.override_clipping = override_clipping
        self.override_zero_grad = override_zero_grad

    def forward(self, *args, **kwargs):
        return self.module.forward(*args, **kwargs)

    def zero_grad(self, set_to_none: bool = False) -> None:
        if self.override_zero_grad and all(
            param.grad.isfinite().all() for param in self.parameters() if param.requires_grad
        ):
            logger.debug("Successfully bypassed zero_grad")
        else:
            self.module.zero_grad(set_to_none=set_to_none)

    def clip_grad_norm_(self, max_norm: float, norm_type: int = 2):
        """Ignore clip_grad_norm_ on each step, clip in optimizer instead"""
        if self.override_clipping:
            logger.debug("Successfully bypassed clip_grad_norm_")
        else:
            return torch.nn.utils.clip_grad_norm_(self.module.parameters(), max_norm, norm_type=norm_type)
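
# Sanity-check sketch (added; not part of the original module), showing the bypass behaviour
# of the wrapper above on a toy module. The names are illustrative:
#
#   toy = nn.Linear(4, 2)
#   wrapped = IgnoreGradManipulations(toy)
#   wrapped(torch.randn(3, 4)).sum().backward()
#   wrapped.zero_grad()                    # bypassed: finite .grad buffers are left intact
#   assert toy.weight.grad is not None
#   wrapped.clip_grad_norm_(max_norm=1.0)  # bypassed: clipping is deferred to the optimizer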