- """A catch-all module for the dirty hacks required to make HF Trainer work with collaborative training"""
- from typing import Optional
- import torch
- from torch import nn
- from torch.utils.data import DataLoader
- from transformers.trainer import Trainer
- from hivemind.utils.logging import get_logger, use_hivemind_log_handler
- from lib.staging.collaborative import CollaborativeOptimizer
- from lib.staging.scaler import HivemindGradScaler
- use_hivemind_log_handler("in_root_logger")
- logger = get_logger()
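
# note: _LRScheduler is a private torch class, hence the defensive getattr below; NoOpScheduler
# subclasses it so that it looks like a regular LR scheduler to transformers.Trainer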
LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)


class CollaborativeHFTrainer(Trainer):
    """
    A version of the HuggingFace Trainer that shuffles the dataset using a peer-specific random seed.
    Used to ensure that peers don't process batches in the same order.
    """

    def __init__(self, *, data_seed: int, collaborative_optimizer: CollaborativeOptimizer, **kwargs):
        self.data_seed = data_seed
        self.collaborative_optimizer = collaborative_optimizer
        super().__init__(optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)), **kwargs)
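
        # mixed precision: replace the Trainer's default AMP grad scaler with the hivemind-aware one,
        # which is expected to defer unscaling/updates until the collaborative optimizer actually steps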
        if self.fp16_backend is not None:
            assert self.use_amp
            self.scaler = HivemindGradScaler()

    def get_train_dataloader(self) -> DataLoader:
        """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
        torch.manual_seed(self.data_seed)
        return super().get_train_dataloader()

    def _wrap_model(self, model, training=True):
        # if reuse_grad_buffers is True, we should accumulate gradients in .grad without zeroing them after each step
        return IgnoreGradManipulations(
            super()._wrap_model(model, training=training),
            override_zero_grad=self.collaborative_optimizer.reuse_grad_buffers,
        )


class NoOpScheduler(LRSchedulerBase):
    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""

    def get_lr(self):
        return [group['lr'] for group in self.optimizer.param_groups]

    def print_lr(self, *args, **kwargs):
        if self.optimizer.scheduler:
            return self.optimizer.scheduler.print_lr(*args, **kwargs)

    def step(self):
        logger.debug("Called NoOpScheduler.step")
        self._last_lr = self.get_lr()

    def state_dict(self):
        return {}

    def load_state_dict(self, *args, **kwargs):
        logger.debug("Called NoOpScheduler.load_state_dict")


class IgnoreGradManipulations(nn.Module):
    """Wrapper for a model that blocks gradient manipulations by the HuggingFace Trainer (e.g. zero_grad, clip_grad)"""

    def __init__(self, module, override_clipping: bool = True, override_zero_grad: bool = True):
        super().__init__()
        self.module = module
        self.override_clipping = override_clipping
        self.override_zero_grad = override_zero_grad

    def forward(self, *args, **kwargs):
        return self.module.forward(*args, **kwargs)
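
    # when reuse_grad_buffers is enabled, CollaborativeOptimizer accumulates gradients in .grad across
    # local steps, so the Trainer's per-step zero_grad must be skipped; gradients are still zeroed if any
    # of them turned non-finite, so that a bad step does not poison the accumulated buffers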
    def zero_grad(self, set_to_none: bool = False) -> None:
        if self.override_zero_grad and all(
            param.grad is None or param.grad.isfinite().all() for param in self.parameters()
        ):
            logger.debug("Successfully bypassed zero_grad")
        else:
            self.module.zero_grad(set_to_none=set_to_none)

    def clip_grad_norm_(self, max_norm: float, norm_type: float = 2.0):
        """Ignore clip_grad_norm_ on each step; clipping is done inside the optimizer instead"""
        if self.override_clipping:
            logger.debug("Successfully bypassed clip_grad_norm_")
        else:
            return torch.nn.utils.clip_grad_norm_(self.module.parameters(), max_norm, norm_type=norm_type)
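

# Minimal usage sketch (illustrative only; `model`, `training_args`, `train_dataset` and the optimizer
# configuration below are placeholders rather than names defined in this repository):
#
#   collaborative_optimizer = CollaborativeOptimizer(...)  # configured to match the rest of the swarm
#   trainer = CollaborativeHFTrainer(
#       model=model,
#       args=training_args,                        # a transformers.TrainingArguments instance
#       train_dataset=train_dataset,
#       data_seed=peer_specific_seed,              # any integer that differs between peers
#       collaborative_optimizer=collaborative_optimizer,
#   )
#   trainer.train()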