
Initial commit (ru-max branch without private code)

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: justheuristic <justheuristic@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Michael Diskin <yhn1124@gmail.com>
Max Ryabinin 4 years ago
commit 72fc0bcdb7
20 changed files with 1708 additions and 0 deletions
  1. .gitignore (+83, -0)
  2. README.md (+5, -0)
  3. arguments.py (+187, -0)
  4. callback.py (+138, -0)
  5. huggingface_auth.py (+171, -0)
  6. lib/__init__.py (+2, -0)
  7. lib/training/__init__.py (+0, -0)
  8. lib/training/clipped_lamb.py (+14, -0)
  9. lib/training/hf_trainer.py (+87, -0)
  10. lib/training/offload.py (+93, -0)
  11. lib/training/tpu.py (+231, -0)
  12. lib/training/wrapper.py (+47, -0)
  13. requirements.txt (+9, -0)
  14. run_aux_peer.py (+140, -0)
  15. run_trainer.py (+59, -0)
  16. run_trainer_tpu.py (+91, -0)
  17. task.py (+126, -0)
  18. tests/test_ffn.py (+83, -0)
  19. tests/test_rotary.py (+70, -0)
  20. utils.py (+72, -0)

+ 83 - 0
.gitignore

@@ -0,0 +1,83 @@
+# node and NPM
+npm-debug.log
+node_modules
+
+# swap files
+*~
+*.swp
+
+examples/data/*
+examples/runs/*
+examples/.ipynb_checkpoints/*
+
+env.sh
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+bin/
+build/
+develop-eggs/
+dist/
+eggs/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg/
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.cache
+nosetests.xml
+coverage.xml
+
+# Translations
+*.mo
+
+# Mr Developer
+.mr.developer.cfg
+.project
+.pydevproject
+.idea
+.ipynb_checkpoints
+
+# Rope
+.ropeproject
+
+# Django stuff:
+*.log
+*.pot
+
+# Sphinx documentation
+docs/_build/
+docs/tmp*
+
+# OS X garbage
+.DS_Store
+
+# Debian things
+debian/reproducible-experiment-platform
+debian/files
+*.substvars
+*.debhelper.log
+
+# protobuf stuff
+hivemind/proto/*_pb2*
+
+# libp2p-daemon binary
+hivemind/hivemind_cli/p2pd

+ 5 - 0
README.md

@@ -0,0 +1,5 @@
+## TODO write a setup instruction
+
+Note: you might want to have several initial peers so that if one dies,
+    new workers can still join the collaboration via the addresses of the remaining initial peers.
+    Specify the `initial_peers` argument for that purpose.
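
For illustration, a minimal sketch (not part of the committed files; the multiaddrs below are the placeholder examples from the help text in arguments.py) of how the initial_peers argument mentioned above is consumed, mirroring what run_trainer.py does with HfArgumentParser:

    # Sketch: parse several --initial_peers multiaddrs into the argument dataclasses.
    from transformers import HfArgumentParser
    from arguments import TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments

    argv = [
        "--initial_peers",
        "/ip4/203.0.113.1/tcp/31337/p2p/XXXX",      # placeholder multiaddr
        "/ip4/203.0.113.2/udp/7777/quic/p2p/YYYY",  # placeholder multiaddr
        "--output_dir", "outputs",
    ]
    parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments))
    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses(args=argv)
    print(peer_args.initial_peers)  # both multiaddrs; list several peers for redundancy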

+ 187 - 0
arguments.py

@@ -0,0 +1,187 @@
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+import torch
+from transformers import TrainingArguments
+
+
+@dataclass
+class HFTrainerArguments(TrainingArguments):
+    """Arguments for huggingface/transformers.Trainer"""
+    dataloader_num_workers: int = 1
+    per_device_train_batch_size: int = 1
+    per_device_eval_batch_size: int = 1
+    gradient_accumulation_steps: int = 1
+    seq_length: int = 512
+    pad_to_multiple_of: int = 8
+
+    learning_rate: float = 0.0025
+    total_steps: int = 31250  # total number of collaborative SGD updates, used for learning rate schedule
+    warmup_steps: int = 3125
+    adam_epsilon: float = 1e-6
+    weight_decay: float = 0.01
+    max_grad_norm: float = 1.0
+    clamp_value: float = 10000.0
+
+    fp16: bool = False
+    fp16_opt_level: str = "O2"
+    do_train: bool = True
+
+    logging_steps: int = 100
+    max_steps: int = 10 ** 20
+    save_steps: int = 10 ** 20
+    save_total_limit: int = 2
+
+    output_dir: str = "outputs"
+
+    @property
+    def batch_size_per_step(self):
+        """Compute the number of training sequences contributed by each .step() from this peer"""
+        total_batch_size_per_step = self.per_device_train_batch_size * self.gradient_accumulation_steps
+        if torch.cuda.device_count() > 0:
+            total_batch_size_per_step *= torch.cuda.device_count()
+        return total_batch_size_per_step
+
+
+@dataclass
+class TPUTrainerArguments(HFTrainerArguments):
+    num_tpus: int = 8  # the total number of TPU cores in use
+    wandb_project: str = "huggingface"
+
+    @property
+    def batch_size_per_step(self):
+        """Compute the number of training sequences contributed by each .step() from this peer"""
+        return self.per_device_train_batch_size * self.gradient_accumulation_steps * self.num_tpus
+
+
+@dataclass
+class CollaborativeArguments:
+    """Configuration for CollaborativeOptimzier and its internals"""
+    target_batch_size: int = field(
+        default=16384,
+        metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
+    )
+    target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
+    bandwidth: float = field(
+        default=100.0,
+        metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"},
+    )
+    averaging_expiration: float = field(
+        default=15.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
+    )
+    averaging_timeout: float = field(
+        default=120.0, metadata={"help": "Give up on averaging step after this many seconds"}
+    )
+    min_refresh_period: float = field(
+        default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
+    )
+    max_refresh_period: float = field(
+        default=30, metadata={"help": "Wait for at most this many seconds before fetching new collaboration state"}
+    )
+    default_refresh_period: float = field(
+        default=3, metadata={"help": "Attempt to fetch collaboration state every this often until successful"}
+    )
+    expected_drift_peers: float = field(
+        default=3, metadata={"help": "Trainer assumes that this many new peers can join per step"}
+    )
+    expected_drift_rate: float = field(
+        default=0.2, metadata={"help": "Trainer assumes that this fraction of current size can join per step"}
+    )
+    performance_ema_alpha: float = field(
+        default=0.1, metadata={"help": "Uses this alpha for moving average estimate of samples per second"}
+    )
+    metadata_expiration: float = field(
+        default=120, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
+    )
+    reuse_grad_buffers: bool = field(default=True, metadata={
+        "help": "Whether or not to use model's .grad buffers for accumulating gradients across local steps. This "
+                "optimization reduces GPU memory consumption but may result in incorrect gradients when using some "
+                "advanced techniques (e.g. applying custom loss scaler)"})
+
+
+@dataclass
+class BasePeerArguments:
+    """Base arguments that are used for both trainers and for auxiliary peers such as training monitor"""
+    experiment_prefix: str = field(default="my-model", metadata={"help": "A unique experiment name, used as prefix for all DHT keys"})
+    model_config_path: Optional[str] = field(default="./model.json", metadata={"help": "Path to the model config"})
+    tokenizer_path: Optional[str] = field(default="./tokenizer", metadata={"help": "Path to the tokenizer"})
+    cache_dir: Optional[str] = field(default="./cache", metadata={"help": "Path to the cache"})
+
+    authorize: bool = field(default=False, metadata={"help": "Whether or not to use HF authorizer"})
+    client_mode: bool = field(
+        default=False,
+        metadata={"help": "Of True, runs training without incoming connections, in a firewall-compatible mode"},
+    )
+    initial_peers: List[str] = field(
+        default_factory=list,
+        metadata={
+            "help": "Multiaddrs of the peers that will welcome you into the existing collaboration. "
+            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"
+        },
+    )
+    use_ipfs: bool = field(
+        default=False,
+        metadata={
+            "help": "Use IPFS to find initial_peers. If enabled, you only need to provide /p2p/XXXX part of multiaddrs "
+            "for the initial_peers (no need to specify a particular IPv4/IPv6 address and port)"
+        },
+    )
+    host_maddrs: List[str] = field(
+        default_factory=lambda: ["/ip4/0.0.0.0/tcp/0"],
+        metadata={
+            "help": "Multiaddrs to listen for external connections from other p2p instances. "
+            "Defaults to all IPv4 interfaces with TCP protocol: /ip4/0.0.0.0/tcp/0"
+        },
+    )
+    announce_maddrs: List[str] = field(
+        default_factory=list,
+        metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"},
+    )
+    identity_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. "
+            "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``."
+        },
+    )
+
+
+@dataclass
+class TrainingPeerArguments(BasePeerArguments):
+    statistics_expiration: float = field(
+        default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
+    )
+    backup_every_steps: Optional[int] = field(
+        default=None, metadata={"help": "Update the training state backup on disk once every this many global steps "
+                                        "(default = do not back up the local state)"}
+    )
+    state_path: str = field(
+        default="state.zip", metadata={"help": "Load this state upon init and when recovering from NaN parameters"})
+
+
+@dataclass
+class AuxiliaryPeerArguments(BasePeerArguments):
+    """
+    Arguments for run_aux_peer.py that is responsible for connecting peers to one another, tracking
+    learning curves, assisting in all-reduce and uploading checkpoints to the hub
+    """
+    refresh_period: float = field(default=10, metadata={"help": "Period (in seconds) for fetching the keys from DHT"})
+    wandb_project: Optional[str] = field(
+        default=None, metadata={"help": "Name of Weights & Biases project to report the training progress to"}
+    )
+    save_checkpoint_step_interval: int = field(
+        default=5, metadata={"help": "Frequency (in steps) of fetching and saving state from peers"}
+    )
+    repo_url: Optional[str] = field(
+        default=None, metadata={"help": "URL of Hugging Face Hub repository to upload the model and optimizer states"}
+    )
+    local_path: Optional[str] = field(
+        default="Repo", metadata={"help": "Path to local repository to store the model and optimizer states"}
+    )
+    upload_interval: Optional[float] = field(
+        default=None, metadata={"help": "Frequency (in seconds) of uploading the model to Hub"}
+    )
+    store_checkpoints: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
+    assist_in_averaging: bool = field(
+        default=False, metadata={"help": "If True, this peer will facilitate averaging for other (training) peers"})
+    assist_refresh: float = field(default=1.0, metadata={"help": "Period (in seconds) for trying to assist averaging"})
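
As a worked example (hypothetical batch settings, not part of the commit), batch_size_per_step and target_batch_size together determine how many local steps a single peer would need per collaborative update:

    # Sketch: with per_device_train_batch_size=4 and gradient_accumulation_steps=2 on a
    # single device, each local .step() contributes 4 * 2 = 8 samples, so a lone peer
    # would need about 16384 / 8 = 2048 local steps per collaborative optimizer update.
    from arguments import HFTrainerArguments, CollaborativeArguments

    trainer_args = HFTrainerArguments(output_dir="outputs",
                                      per_device_train_batch_size=4,
                                      gradient_accumulation_steps=2)
    collab_args = CollaborativeArguments()
    print("samples per local step:", trainer_args.batch_size_per_step)
    print("local steps per global update (single peer):",
          collab_args.target_batch_size // trainer_args.batch_size_per_step)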

+ 138 - 0
callback.py

@@ -0,0 +1,138 @@
+import os.path
+from typing import Any
+
+import hivemind
+import torch
+import transformers
+from transformers import TrainingArguments
+
+from arguments import TrainingPeerArguments
+from task import TrainingTask
+from utils import LocalMetrics, logger
+
+
+class CollaborativeCallback(transformers.TrainerCallback):
+    """
+    This callback monitors and reports collaborative training progress,
+    In case of a catastrophic failure, it can also revert training to a backup
+    """
+
+    def __init__(self, task: TrainingTask, args: TrainingPeerArguments):
+        super().__init__()
+        self.task = task
+        self.dht, self.collaborative_optimizer = task.dht, task.collaborative_optimizer
+        self.statistics_expiration = args.statistics_expiration
+        self.last_reported_collaboration_step = -1
+        self.samples = 0
+        self.steps = 0
+        self.loss = 0
+        self.total_samples_processed = 0
+        self.backup_every_steps = args.backup_every_steps
+        self.state_path = args.state_path
+
+    def on_train_begin(
+        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
+    ):
+        if os.path.isfile(self.state_path):
+            self.restore_from_backup(self.state_path)
+            logger.info("Loaded state")
+
+        logger.info("Loading state from peers")
+        self.collaborative_optimizer.load_state_from_peers()
+
+        if os.path.isfile(self.state_path):
+            self.restore_from_backup(self.state_path, check_step=True)
+
+    def on_step_end(
+        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
+    ):
+        control.should_log = True
+        if not self.params_are_finite():
+            if not os.path.exists(self.state_path):
+                raise RuntimeError("Encountered broken parameters, but there is no backup to fall back to.")
+            logger.warning("Parameters are invalid, reloading model from earlier state")
+            self.restore_from_backup(self.state_path)
+            return control
+
+        if state.log_history:
+            self.loss += state.log_history[-1]["loss"]
+            self.steps += 1
+            if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
+                self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
+                self.total_samples_processed += self.samples
+                samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
+                statistics = LocalMetrics(
+                    step=self.collaborative_optimizer.local_step,
+                    samples_per_second=samples_per_second,
+                    samples_accumulated=self.samples,
+                    loss=self.loss,
+                    mini_steps=self.steps,
+                )
+                logger.info(f"Step {self.collaborative_optimizer.local_step}")
+                logger.info(f"Your current contribution: {self.total_samples_processed} samples")
+                logger.info(f"Performance: {samples_per_second} samples per second.")
+                if self.steps:
+                    logger.info(f"Local loss: {self.loss / self.steps}")
+
+                self.loss = 0
+                self.steps = 0
+                if self.collaborative_optimizer.is_synchronized:
+                    self.dht.store(
+                        key=self.collaborative_optimizer.prefix + "_metrics",
+                        subkey=self.task.local_public_key,
+                        value=statistics.dict(),
+                        expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                        return_future=True,
+                    )
+                if self.backup_every_steps is not None and \
+                        self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
+                    self.backup_state()
+
+        self.samples = self.collaborative_optimizer.local_samples_accumulated
+
+        return control
+
+    @torch.no_grad()
+    def params_are_finite(self):
+        for param in self.task.model.parameters():
+            if not torch.all(torch.isfinite(param)):
+                return False
+        return True
+
+    @torch.no_grad()
+    def backup_state(self) -> Any:
+        logger.info("Saving backup")
+        return torch.save(
+            {
+                "model": self.task.model.state_dict(),
+                "training": self.collaborative_optimizer.state_dict(),
+                "scheduler": self.collaborative_optimizer.scheduler.state_dict(),
+                "local_step": self.collaborative_optimizer.local_step,
+            },
+            self.state_path,
+        )
+
+    @torch.no_grad()
+    def restore_from_backup(self, path, check_step=False):
+        state = torch.load(path)
+        current_step = self.collaborative_optimizer.local_step
+        backup_step = state['training']['state'][0]['step'] #TODO FIX THIS, use state['local_step']
+        if not check_step or backup_step >= current_step:
+            if (
+                "albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention_core.rotary_emb.cos"
+                in state["model"]
+            ):
+                del state["model"][
+                    "albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention_core.rotary_emb.cos"
+                ]
+                del state["model"][
+                    "albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention_core.rotary_emb.sin"
+                ]
+            if "scheduler" in state:
+                self.collaborative_optimizer.scheduler.load_state_dict(state["scheduler"])
+            self.collaborative_optimizer.load_state_dict(state["training"])
+            self.collaborative_optimizer.averager.local_step = backup_step
+            self.task.model.load_state_dict(state["model"], strict=False)
+            logger.info("Restored from a backup")
+        else:
+            logger.info("Bypassed restoring state from local backup: backup state is too old.")

+ 171 - 0
huggingface_auth.py

@@ -0,0 +1,171 @@
+import base64
+import os
+import time
+from datetime import datetime, timedelta
+from getpass import getpass
+
+import requests
+from huggingface_hub import HfApi
+
+from hivemind.proto.auth_pb2 import AccessToken
+from hivemind.utils.auth import TokenAuthorizerBase
+from hivemind.utils.crypto import RSAPublicKey
+from hivemind.utils.logging import get_logger
+
+
+logger = get_logger("root." + __name__)
+
+
+class NonRetriableError(Exception):
+    pass
+
+
+def call_with_retries(func, n_retries=10, initial_delay=1.0):
+    for i in range(n_retries):
+        try:
+            return func()
+        except NonRetriableError:
+            raise
+        except Exception as e:
+            if i == n_retries - 1:
+                raise
+
+            delay = initial_delay * (2 ** i)
+            logger.warning(f'Failed to call `{func.__name__}` with exception: {e}. Retrying in {delay:.1f} sec')
+            time.sleep(delay)
+
+
+class InvalidCredentialsError(NonRetriableError):
+    pass
+
+
+class NotInAllowlistError(NonRetriableError):
+    pass
+
+
+class HuggingFaceAuthorizer(TokenAuthorizerBase):
+    _AUTH_SERVER_URL = 'https://collaborative-training-auth.huggingface.co'
+
+    def __init__(self, experiment_id: int, username: str, password: str):
+        super().__init__()
+
+        self.experiment_id = experiment_id
+        self.username = username
+        self.password = password
+
+        self._authority_public_key = None
+        self.coordinator_ip = None
+        self.coordinator_port = None
+
+        self._hf_api = HfApi()
+
+    async def get_token(self) -> AccessToken:
+        """
+        Hivemind calls this method to refresh the token when necessary.
+        """
+
+        self.join_experiment()
+        return self._local_access_token
+
+    def join_experiment(self) -> None:
+        call_with_retries(self._join_experiment)
+
+    def _join_experiment(self) -> None:
+        try:
+            token = self._hf_api.login(self.username, self.password)
+        except requests.exceptions.HTTPError as e:
+            if e.response.status_code == 401:  # Unauthorized
+                raise InvalidCredentialsError()
+            raise
+
+        try:
+            url = f'{self._AUTH_SERVER_URL}/api/experiments/join/{self.experiment_id}/'
+            headers = {'Authorization': f'Bearer {token}'}
+            response = requests.put(url, headers=headers, json={
+                'experiment_join_input': {
+                    'peer_public_key': self.local_public_key.to_bytes().decode(),
+                },
+            })
+
+            response.raise_for_status()
+            response = response.json()
+
+            self._authority_public_key = RSAPublicKey.from_bytes(response['auth_server_public_key'].encode())
+            self.coordinator_ip = response['coordinator_ip']
+            self.coordinator_port = response['coordinator_port']
+
+            token_dict = response['hivemind_access']
+            access_token = AccessToken()
+            access_token.username = token_dict['username']
+            access_token.public_key = token_dict['peer_public_key'].encode()
+            access_token.expiration_time = str(datetime.fromisoformat(token_dict['expiration_time']))
+            access_token.signature = token_dict['signature'].encode()
+            self._local_access_token = access_token
+            logger.info(f'Access for user {access_token.username} '
+                        f'has been granted until {access_token.expiration_time} UTC')
+        except requests.exceptions.HTTPError as e:
+            if e.response.status_code == 401:  # Unauthorized
+                raise NotInAllowlistError()
+            raise
+        finally:
+            self._hf_api.logout(token)
+
+    def is_token_valid(self, access_token: AccessToken) -> bool:
+        data = self._token_to_bytes(access_token)
+        if not self._authority_public_key.verify(data, access_token.signature):
+            logger.exception('Access token has invalid signature')
+            return False
+
+        try:
+            expiration_time = datetime.fromisoformat(access_token.expiration_time)
+        except ValueError:
+            logger.exception(
+                f'datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}')
+            return False
+        if expiration_time.tzinfo is not None:
+            logger.exception(f'Expected to have no timezone for expiration time: {access_token.expiration_time}')
+            return False
+        if expiration_time < datetime.utcnow():
+            logger.exception('Access token has expired')
+            return False
+
+        return True
+
+    _MAX_LATENCY = timedelta(minutes=1)
+
+    def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
+        expiration_time = datetime.fromisoformat(access_token.expiration_time)
+        return expiration_time < datetime.utcnow() + self._MAX_LATENCY
+
+    @staticmethod
+    def _token_to_bytes(access_token: AccessToken) -> bytes:
+        return f'{access_token.username} {access_token.public_key} {access_token.expiration_time}'.encode()
+
+
+def authorize_with_huggingface() -> HuggingFaceAuthorizer:
+    while True:
+        experiment_id = os.getenv('HF_EXPERIMENT_ID')
+        if experiment_id is None:
+            experiment_id = input('HuggingFace experiment ID: ')
+
+        username = os.getenv('HF_USERNAME')
+        if username is None:
+            while True:
+                username = input('HuggingFace username: ')
+                if '@' not in username:
+                    break
+                print('Please enter your Huggingface _username_ instead of the email address!')
+
+        password = os.getenv('HF_PASSWORD')
+        if password is None:
+            password = getpass('HuggingFace password: ')
+
+        authorizer = HuggingFaceAuthorizer(experiment_id, username, password)
+        try:
+            authorizer.join_experiment()
+            return authorizer
+        except InvalidCredentialsError:
+            print('Invalid username or password, please try again')
+        except NotInAllowlistError:
+            print('This account is not specified in the allowlist. '
+                  'Please ask a moderator to add you to the allowlist and try again')
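
A usage sketch (hypothetical credentials; requires the Hugging Face auth server to be reachable): authorize_with_huggingface() reads HF_EXPERIMENT_ID, HF_USERNAME and HF_PASSWORD from the environment instead of prompting interactively:

    # Sketch: supply credentials via environment variables, then join the experiment.
    import os
    from huggingface_auth import authorize_with_huggingface

    os.environ["HF_EXPERIMENT_ID"] = "1"           # hypothetical experiment id
    os.environ["HF_USERNAME"] = "my-username"      # hypothetical account name
    os.environ["HF_PASSWORD"] = "my-password"      # do not hardcode real credentials

    authorizer = authorize_with_huggingface()      # joins the experiment, retrying transient errors
    print("coordinator:", authorizer.coordinator_ip, authorizer.coordinator_port)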

+ 2 - 0
lib/__init__.py

@@ -0,0 +1,2 @@
+from .modules import *
+from .models import *

+ 0 - 0
lib/training/__init__.py


+ 14 - 0
lib/training/clipped_lamb.py

@@ -0,0 +1,14 @@
+import torch
+from torch_optimizer import Lamb
+
+
+class LambWithGradientClipping(Lamb):
+    """ A version of LAMB that clips gradients based on their norm. """
+    def __init__(self, *args, max_grad_norm: float, **kwargs):
+        self.max_grad_norm = max_grad_norm
+        super().__init__(*args, **kwargs)
+
+    def step(self, *args, **kwargs):
+        iter_params = (param for group in self.param_groups for param in group['params'])
+        torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)
+        return super().step(*args, **kwargs)
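
A usage sketch with toy hyperparameters (it assumes the private lib.modules / lib.models packages referenced by lib/__init__.py, stripped from this commit, are importable):

    # Sketch: LambWithGradientClipping clips the global grad norm inside .step(),
    # so no external torch.nn.utils.clip_grad_norm_ call is needed in the training loop.
    import torch
    from lib.training.clipped_lamb import LambWithGradientClipping

    model = torch.nn.Linear(16, 16)
    opt = LambWithGradientClipping(model.parameters(), lr=1e-3, weight_decay=0.01,
                                   max_grad_norm=1.0)
    model(torch.randn(4, 16)).pow(2).mean().backward()
    opt.step()   # clips gradients to max_grad_norm, then performs the LAMB update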

+ 87 - 0
lib/training/hf_trainer.py

@@ -0,0 +1,87 @@
+"""A catch-all module for the dirty hacks required to make HF Trainer work with collaborative training"""
+from typing import Optional
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+from transformers.trainer import Trainer
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+
+from lib.staging.collaborative import CollaborativeOptimizer
+from lib.staging.scaler import HivemindGradScaler
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
+
+
+class CollaborativeHFTrainer(Trainer):
+    """
+    A version of HuggingFace trainer that shuffles the dataset using a separate random seed.
+    Used to ensure that peers don't process batches in the same order.
+    """
+
+    def __init__(self, *, data_seed: int, collaborative_optimizer: CollaborativeOptimizer, **kwargs):
+        self.data_seed = data_seed
+        self.collaborative_optimizer = collaborative_optimizer
+        super().__init__(optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)), **kwargs)
+
+        if self.fp16_backend is not None:
+            assert self.use_amp
+            self.scaler = HivemindGradScaler()
+
+    def get_train_dataloader(self) -> DataLoader:
+        """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
+        torch.manual_seed(self.data_seed)
+        return super().get_train_dataloader()
+
+    def _wrap_model(self, model, training=True):
+        # if reuse_grad_buffers is True, we should accumulate gradients in .grad without zeroing them after each step
+        return IgnoreGradManipulations(super()._wrap_model(model, training=training),
+                                       override_zero_grad=self.collaborative_optimizer.reuse_grad_buffers)
+
+
+class NoOpScheduler(LRSchedulerBase):
+    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
+
+    def get_lr(self):
+        return [group['lr'] for group in self.optimizer.param_groups]
+
+    def print_lr(self, *args, **kwargs):
+        if self.optimizer.scheduler:
+            return self.optimizer.scheduler.print_lr(*args, **kwargs)
+
+    def step(self):
+        logger.debug("Called NoOpScheduler.step")
+        self._last_lr = self.get_lr()
+
+    def state_dict(self):
+        return {}
+
+    def load_state_dict(self, *args, **kwargs):
+        logger.debug("Called NoOpScheduler.load_state_dict")
+
+
+class IgnoreGradManipulations(nn.Module):
+    """ Wrapper for model that blocks gradient manipulations in huggingface Trainer (e.g. zero_grad, clip_grad) """
+    def __init__(self, module, override_clipping: bool = True, override_zero_grad: bool = True):
+        super().__init__()
+        self.module = module
+        self.override_clipping = override_clipping
+        self.override_zero_grad = override_zero_grad
+
+    def forward(self, *args, **kwargs):
+        return self.module.forward(*args, **kwargs)
+
+    def zero_grad(self, set_to_none: bool = False) -> None:
+        if self.override_zero_grad and all(param.grad.isfinite().all() for param in self.parameters()):
+            logger.debug("Successfully bypassed zero_grad")
+        else:
+            self.module.zero_grad(set_to_none=set_to_none)
+
+    def clip_grad_norm_(self, max_norm: float, norm_type: int = 2):
+        """ ignore clip_grad_norm on each step, clip in optimizer instead """
+        if self.override_clipping:
+            logger.debug("Successfully bypassed clip_grad_norm_")
+        else:
+            return torch.nn.utils.clip_grad_norm_(self.module.parameters(), max_norm, norm_type=norm_type)
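
A plain-PyTorch sketch (no project imports, toy numbers) of why bypassing zero_grad() amounts to in-place gradient accumulation, which is what IgnoreGradManipulations enables when reuse_grad_buffers=True:

    # Sketch: gradients keep adding up in .grad across micro-steps when they are never zeroed.
    import torch

    layer = torch.nn.Linear(4, 1)
    for _ in range(3):                              # three "local" micro-steps
        layer(torch.ones(2, 4)).sum().backward()    # .grad accumulates, never zeroed
    print(layer.bias.grad)                          # tensor([6.]): 3 micro-steps x batch of 2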

+ 93 - 0
lib/training/offload.py

@@ -0,0 +1,93 @@
+import contextlib
+from typing import Type, Iterable, Dict, Union, Optional
+import multiprocessing as mp
+
+import torch
+
+from .wrapper import OptimizerWrapper
+
+
+class OffloadOptimizer(OptimizerWrapper):
+    r""" A wrapper that stores optimizer statistics and performs updates on the offloaded device (e.g. CPU RAM). """
+
+    def __init__(
+            self, param_groups: Union[Iterable[torch.nn.Parameter], Iterable[Dict]],
+            optim_cls: Type[torch.optim.Optimizer],  *args, full_sync: bool = True,
+            offload_device=torch.device('cpu'), offload_dtype: Optional[torch.dtype] = None, **kwargs):
+        param_groups = list(param_groups)
+        if not isinstance(param_groups[0], dict):
+            param_groups = [{'params': param_groups}]
+        super().__init__(optim_cls(param_groups, *args, **kwargs))
+        self.full_sync = full_sync
+        self.lock = mp.Lock()
+
+        with torch.no_grad():
+            self.offload_params_by_group = tuple(
+                [torch.nn.Parameter(torch.empty_like(param, device=offload_device, dtype=offload_dtype),
+                                    requires_grad=param.requires_grad)
+                 for param in group["params"]] for group in param_groups)
+
+            for group, offload_params in zip(param_groups, self.offload_params_by_group):
+                for param, offload_param in zip(group['params'], offload_params):
+                    offload_param.copy_(param, non_blocking=True)
+                    if offload_param.grad is None:
+                        offload_param.grad = torch.zeros_like(offload_param)
+                    if param.grad is not None:
+                        offload_param.grad.copy_(param.grad, non_blocking=True)
+
+    @contextlib.contextmanager
+    def _use_offloaded_params(self, *,
+                              sync_params_before: bool, sync_grads_before: bool,
+                              sync_params_after: bool, sync_grads_after: bool):
+        assert len(self.param_groups) == len(self.offload_params_by_group)
+        original_params_per_group = [group["params"] for group in self.param_groups]
+        with self.lock:
+            try:
+                with torch.no_grad():
+                    for original_params, replacement_params in zip(original_params_per_group, self.offload_params_by_group):
+                        for original_param, replacement_param in zip(original_params, replacement_params):
+                            if sync_params_before:
+                                replacement_param.copy_(original_param, non_blocking=True)
+                            if sync_grads_before and original_param.grad is not None:
+                                replacement_param.grad.copy_(original_param.grad, non_blocking=True)
+
+                for group, replacement_params in zip(self.param_groups, self.offload_params_by_group):
+                    group["params"] = replacement_params
+                yield self.param_groups
+            finally:
+                for group, original_params in zip(self.param_groups, original_params_per_group):
+                    group["params"] = original_params
+
+                with torch.no_grad():
+                    for original_params, replacement_params in zip(original_params_per_group, self.offload_params_by_group):
+                        for original_param, replacement_param in zip(original_params, replacement_params):
+                            if sync_params_after:
+                                original_param.copy_(replacement_param, non_blocking=True)
+                            if sync_grads_after and original_param.grad is not None:
+                                original_param.grad.copy_(replacement_param.grad)
+
+    def add_param_group(self, param_group: dict) -> None:
+        raise NotImplementedError(f"{self.__class__.__name__} does not support add_param_group.")
+
+    def step(self, closure=None, *args, **kwargs):
+        assert closure is None, "closure not supported in cpu offload mode"
+        with self._use_offloaded_params(sync_params_before=self.full_sync, sync_grads_before=True,
+                                        sync_params_after=True, sync_grads_after=self.full_sync):
+            return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, set_to_none: bool = False, *args, **kwargs):
+        if not self.full_sync:
+            torch.optim.Optimizer.zero_grad(self, set_to_none)
+        with self._use_offloaded_params(sync_params_before=self.full_sync, sync_grads_before=self.full_sync,
+                                        sync_params_after=self.full_sync, sync_grads_after=self.full_sync):
+            return super().zero_grad(*args, set_to_none=False, **kwargs)
+
+    def state_dict(self):
+        with self._use_offloaded_params(sync_params_before=self.full_sync, sync_grads_before=self.full_sync,
+                                        sync_params_after=False, sync_grads_after=False):
+            return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        with self._use_offloaded_params(sync_params_before=False, sync_grads_before=False,
+                                        sync_params_after=True, sync_grads_after=self.full_sync):
+            return self.optim.load_state_dict(state_dict)
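
A usage sketch with a toy model (same caveat about importing the stripped lib package): optimizer statistics live on the offload device, while the training parameters stay where they are:

    # Sketch: Adam state is kept on CPU; step() syncs grads to the CPU replicas, runs the
    # update there, and copies the updated parameters back into the original model.
    import torch
    from lib.training.offload import OffloadOptimizer

    model = torch.nn.Linear(32, 32)
    opt = OffloadOptimizer(model.parameters(), optim_cls=torch.optim.Adam, lr=1e-3,
                           offload_device=torch.device("cpu"))
    model(torch.randn(8, 32)).pow(2).mean().backward()
    opt.step()
    opt.zero_grad()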

+ 231 - 0
lib/training/tpu.py

@@ -0,0 +1,231 @@
+import ctypes
+import threading
+from functools import partial
+from contextlib import nullcontext
+from copy import deepcopy
+import multiprocessing as mp
+from itertools import zip_longest
+from typing import Iterable
+
+import torch
+import torch.nn as nn
+import torch.utils.data
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.distributed.parallel_loader as pl
+
+from hivemind.utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class TPUManager(mp.Process):
+    """Auxiliary class that manages model training over an array of TPU cores"""
+
+    def __init__(self,
+                 model,
+                 dataset,
+                 *,
+                 collate_fn: callable = None,
+                 nprocs: int = 8,
+                 prefetch: int = 16,
+                 batch_size_per_device: int = 1,
+                 grad_accumulation_steps: int = 1,
+                 seed_base: int = 42,
+                 start: bool):
+        super().__init__()
+        self.lock = mp.Lock()
+        self.nprocs, self.prefetch, self.seed_base = nprocs, prefetch, seed_base
+        self.batch_size_per_device, self.grad_accumulation_steps = batch_size_per_device, grad_accumulation_steps
+        self.collate_fn = collate_fn
+        self.step_triggered, self.step_finished = mp.Event(), mp.Event()
+        self._synchronizer = TPUSynchronizer(model)
+        self._data_manager = TPUDataManager(dataset, nprocs, prefetch)
+
+        # shared fields for communicating statistics after each step
+        self.should_load_parameters = mp.Value(ctypes.c_bool, False)
+        self.gradients_accumulated = mp.Value(ctypes.c_long, 0)
+        self.loss_accumulated = mp.Value(ctypes.c_double, 0)
+        if start:
+            self.start()
+
+    def run(self):
+        thread = threading.Thread(
+            target=partial(xmp.spawn, self.runner, nprocs=self.nprocs, start_method='fork'),
+            daemon=True)
+        thread.start()
+        thread.join()
+
+    def update_model_parameters(self, new_host_parameters):
+        """Schedule TPUs to update model parameters during at the beginning of the next step"""
+        with self.lock, torch.no_grad():
+            self._synchronizer.set_host_parameters(new_host_parameters)
+            self.should_load_parameters.value = True
+
+    def get_aggregated_gradients(self):
+        """Get current accumulated gradients from the master model"""
+        with self.lock, torch.no_grad():
+            return self._synchronizer.get_aggregated_gradients()
+
+    def zero_grad(self):
+        """Reset master accumulated gradients to zeros"""
+        with self.lock, torch.no_grad():
+            for param in self._synchronizer.master_model.parameters():
+                param.grad.zero_()
+
+    def step(self):
+        """run forward/backward step with all TPUs, collect gradients"""
+        self.loss_accumulated.value = self.gradients_accumulated.value = 0
+        self.step_finished.clear()
+        self.step_triggered.set()
+        self.step_finished.wait()
+        return self.loss_accumulated.value, self.gradients_accumulated.value
+
+    def runner(self, tpu_index):
+        """Run training steps from the perspective of a single TPU core"""
+        # acquire the (unique) Cloud TPU core corresponding to this process's index
+        device = xm.xla_device()
+        logger.info(f"Process {tpu_index} is using {xm.xla_real_devices([str(device)])[0]}")
+
+        # set a different random seed for each TPU process
+        torch.manual_seed(self.seed_base + tpu_index)
+
+        # use staged init to minimize peak RAM usage
+        for init_index in range(xm.xrt_world_size()):
+            xm.rendezvous(f'init_{init_index}')
+            if tpu_index == init_index:
+                model = self._synchronizer.get_device_model_replica(device)
+                data_loader = self._data_manager.get_device_dataloader(
+                    batch_size=self.batch_size_per_device, num_workers=0, collate_fn=self.collate_fn, pin_memory=False)
+                data_loader_iter = iter(data_loader)
+                logger.info(f"Process {tpu_index} initialized.")
+
+        xm.rendezvous('init_finished')
+
+        while True:
+            self.step_triggered.wait()
+            xm.rendezvous('before_step')
+            if xm.is_master_ordinal():
+                self.step_triggered.clear()
+
+            if bool(self.should_load_parameters.value):
+                with self.lock if xm.is_master_ordinal() else nullcontext():
+                    self._synchronizer.send_params_to_device(model)
+                    self.should_load_parameters.value = False
+
+            ### compute loss and gradients
+            loss = 0.0
+            for i in range(self.grad_accumulation_steps):
+                inputs = next(data_loader_iter)
+                outputs = model(**inputs)
+                loss_i = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+                loss_i = loss_i / (self.grad_accumulation_steps * self.nprocs)
+                loss_i.backward()
+                loss += loss_i
+                del inputs, outputs, loss_i
+
+            ### aggregate gradients from TPUs
+            with self.lock if xm.is_master_ordinal() else nullcontext():
+                self._synchronizer.aggregate_grads_on_host(model, add=True)
+            # clear aggregated gradients from all devices
+            model.zero_grad()
+
+            ### accumulate statistics to host
+            loss = xm.all_reduce(xm.REDUCE_SUM, loss, scale=1.0)
+            xm.do_on_ordinals(self._mark_step_finished, data=(loss,), ordinals=(0,))
+
+    def _mark_step_finished(self, loss):
+        self.gradients_accumulated.value = self.batch_size_per_device * self.nprocs * self.grad_accumulation_steps
+        self.loss_accumulated.value = float(loss)
+        self.step_finished.set()
+
+
+class TPUSynchronizer:
+    """An auxiliary class for manipulating parameters and gradients without producing a ton of XLA graphs"""
+
+    def __init__(self, model: nn.Module):
+        self.master_model = model.share_memory()
+        for param in self.master_model.parameters():
+            if param.grad is None:
+                param.grad = torch.zeros_like(param)
+            param.grad = param.grad.share_memory_()
+
+    def get_device_model_replica(self, device: torch.device, tie_weights: bool = True):
+        replica = deepcopy(self.master_model).to(device)
+        if tie_weights:
+            replica.tie_weights()
+        for param in replica.parameters():
+            param.grad = torch.zeros_like(param, device=device)
+        return replica
+
+    def set_host_parameters(self, new_host_parameters):
+        return self._assign(source=self.master_model.parameters(), target=new_host_parameters, add=False, strict=True)
+
+    def get_aggregated_gradients(self):
+        return [param.grad for param in self.master_model.parameters()]
+
+    def send_params_to_device(self, replica: nn.Module):
+        """Copy params from master_model to this device_model replica"""
+        with torch.no_grad():
+            replica_params = list(replica.parameters())
+            master_params = list(self.master_model.parameters())
+            master_params = xm.send_cpu_data_to_device(master_params, xm.xla_device())
+            self._assign(source=master_params, target=replica_params, add=False)
+            xm.rendezvous("params_replicated")
+
+    def aggregate_grads_on_host(self, replica: nn.Module, *, add: bool):
+        """Aggregate grads from all tpu devices and move them to host"""
+        with torch.no_grad():
+            replica_grads = [param.grad for param in replica.parameters()]
+            replica_grads = xm.all_reduce(xm.REDUCE_SUM, replica_grads, scale=1.0)
+            master_grads = [hp.grad for hp in self.master_model.parameters()]
+            xm.do_on_ordinals(lambda *replica_grads: self._assign(source=replica_grads, target=master_grads, add=add),
+                              data=tuple(replica_grads), ordinals=(0,))
+            # ^-- do_on_ordinals already runs rendezvous at the end
+
+    def _assign(self, source: Iterable[torch.Tensor], target: Iterable[torch.Tensor], add: bool, strict: bool = False):
+        for source_tensor, target_tensor in zip_longest(source, target):
+            assert source_tensor is not None and target_tensor is not None, "Source and target length must match exactly"
+            if strict:
+                assert source_tensor.shape == target_tensor.shape
+                assert source_tensor.device == target_tensor.device
+                assert source_tensor.dtype == target_tensor.dtype
+            if add:
+                target_tensor.add_(source_tensor)
+            else:
+                target_tensor.copy_(source_tensor)
+
+
+class TPUDataManager:
+    """An auxiliary class that loads centralized dataset from master into multiple TPU devices"""
+    def __init__(self, dataset: torch.utils.data.Dataset, nprocs: int, master_prefetch: int = 16):
+        self.dataset, self.nprocs = dataset, nprocs
+        self.device_queues = [mp.Queue(master_prefetch) for _ in range(nprocs)]
+        self._loader_thread = threading.Thread(target=self._load_data_into_queues)
+        self._loader_thread.start()
+
+    def _load_data_into_queues(self):
+        try:
+            for i, batch in enumerate(self.dataset):
+                self.device_queues[i % self.nprocs].put(batch)
+        finally:
+            logger.warning("Minibatch generator finished.")
+
+    def get_device_dataloader(self, **kwargs):
+        data_loader = torch.utils.data.DataLoader(QueueDataset(self.device_queues[xm.get_ordinal()]), **kwargs)
+        return pl.ParallelLoader(data_loader, [xm.xla_device()]).per_device_loader(xm.xla_device())
+
+
+class QueueDataset(torch.utils.data.IterableDataset):
+    """A dataset that ceaselessly iterates over a queue"""
+    def __init__(self, queue: mp.Queue):
+        super().__init__()
+        self.queue = queue
+
+    def __iter__(self):
+        while True:
+            yield self.queue.get()
+
+    def __len__(self):
+        return 10 ** 12  # TODO deprecate this when the issue is resolved: https://github.com/googlecolab/colabtools/issues/2237
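
A CPU-only sketch (toy data, no torch_xla) of the queue-feeding pattern used by TPUDataManager and QueueDataset above: a producer round-robins batches into per-device queues, and each consumer iterates its queue endlessly through an IterableDataset:

    import multiprocessing as mp
    import threading
    import torch
    from torch.utils.data import DataLoader, IterableDataset

    class EndlessQueueDataset(IterableDataset):
        """Mirrors QueueDataset above: yields whatever the queue produces, forever."""
        def __init__(self, queue: mp.Queue):
            self.queue = queue

        def __iter__(self):
            while True:
                yield self.queue.get()

    nprocs = 2
    queues = [mp.Queue(maxsize=4) for _ in range(nprocs)]

    def produce():
        for i in range(8):                          # toy "dataset" of 8 batches
            queues[i % nprocs].put(torch.full((2, 3), float(i)))

    threading.Thread(target=produce, daemon=True).start()
    loader = DataLoader(EndlessQueueDataset(queues[0]), batch_size=None)
    for step, batch in zip(range(4), loader):       # "device 0" sees batches 0, 2, 4, 6
        print("step", step, "batch mean:", batch.mean().item())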

+ 47 - 0
lib/training/wrapper.py

@@ -0,0 +1,47 @@
+import torch
+
+
+class OptimizerWrapper(torch.optim.Optimizer):
+    r"""
+    A wrapper for a PyTorch optimizer that forwards all methods to the wrapped optimizer
+    """
+
+    def __init__(self, optim: torch.optim.Optimizer):
+        object.__init__(self)
+        self.optim = optim
+
+    @property
+    def defaults(self):
+        return self.optim.defaults
+
+    @property
+    def state(self):
+        return self.optim.state
+
+    def __getstate__(self):
+        return self.optim.__getstate__()
+
+    def __setstate__(self, state):
+        self.optim.__setstate__(state)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({repr(self.optim)})"
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        return self.optim.load_state_dict(state_dict)
+
+    def step(self, *args, **kwargs):
+        return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, *args, **kwargs):
+        return self.optim.zero_grad(*args, **kwargs)
+
+    @property
+    def param_groups(self):
+        return self.optim.param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        return self.optim.add_param_group(param_group)

+ 9 - 0
requirements.txt

@@ -0,0 +1,9 @@
+transformers>=4.9.2
+tokenizers>=0.10.2
+datasets>=1.11.0
+torch_optimizer>=0.1.0
+wandb>=0.10.33
+nltk>=3.6.2
+sentencepiece
+aiohttp
+requests>=2.24.0

+ 140 - 0
run_aux_peer.py

@@ -0,0 +1,140 @@
+#!/usr/bin/env python
+import threading
+import time
+
+import torch
+import wandb
+from transformers import HfArgumentParser
+from huggingface_hub import HfFolder, Repository
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+
+import utils
+from arguments import AuxiliaryPeerArguments, CollaborativeArguments, HFTrainerArguments
+from task import TrainingTask
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()
+
+
+class CheckpointHandler:
+    def __init__(self, task: TrainingTask, peer_args: AuxiliaryPeerArguments):
+        self.task, self.peer_args = task, peer_args
+        self.save_checkpoint_step_interval = peer_args.save_checkpoint_step_interval
+        self.prefix = peer_args.experiment_prefix
+        self.local_path = peer_args.local_path
+        self.upload_interval = peer_args.upload_interval
+        if self.upload_interval is not None:
+            self.token = HfFolder.get_token()
+            self.repo = Repository(
+                local_dir=self.local_path,
+                clone_from=peer_args.repo_url,
+                use_auth_token=self.token,
+            )
+        self.previous_step = -1
+        self.previous_timestamp = time.time()
+
+    def should_save_state(self, cur_step):
+        if self.save_checkpoint_step_interval is None:
+            return False
+        elif cur_step - self.previous_step >= self.save_checkpoint_step_interval:
+            return True
+        else:
+            return False
+
+    def save_state(self, cur_step):
+        logger.info("Saving state from peers")
+        self.task.collaborative_optimizer.load_state_from_peers()
+        self.previous_step = cur_step
+
+    def is_time_to_upload(self):
+        if self.upload_interval is None:
+            return False
+        elif time.time() - self.previous_timestamp >= self.upload_interval:
+            return True
+        else:
+            return False
+
+    def upload_checkpoint(self, current_loss):
+        logger.info("Saving model")
+        self.task.model.save_pretrained(self.local_path)
+        logger.info("Saving optimizer")
+        torch.save(self.task.collaborative_optimizer.opt.state_dict(), f"{self.local_path}/optimizer_state.pt")
+        self.previous_timestamp = time.time()
+        logger.info("Started uploading to Model Hub")
+        self.repo.push_to_hub(commit_message=f"Step {self.task.collaborative_optimizer.local_step}, loss {current_loss:.3f}")
+        logger.info("Finished uploading to Model Hub")
+
+
+def assist_averaging_in_background(task: TrainingTask, peer_args: AuxiliaryPeerArguments):
+    while True:
+        time.sleep(peer_args.assist_refresh)
+        task.collaborative_optimizer.step_aux()
+
+
+if __name__ == "__main__":
+    parser = HfArgumentParser((AuxiliaryPeerArguments, HFTrainerArguments, CollaborativeArguments))
+    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()
+
+    task = TrainingTask(peer_args, trainer_args, collab_args)
+    dht, collaborative_optimizer = task.dht, task.collaborative_optimizer
+
+    if peer_args.wandb_project is not None:
+        wandb.init(project=peer_args.wandb_project)
+
+    current_step = 0
+    if peer_args.store_checkpoints:
+        checkpoint_handler = CheckpointHandler(task, peer_args)
+
+    if peer_args.assist_in_averaging:
+        assert not peer_args.client_mode, "client-mode peers cannot assist in averaging"
+        averaging_thread = threading.Thread(
+            name="AveragingAuxThread", target=assist_averaging_in_background, args=[task, peer_args], daemon=True)
+        averaging_thread.start()
+
+    while True:
+        metrics_entry = dht.get(peer_args.experiment_prefix + "_metrics", latest=True)
+        if metrics_entry is not None and len(metrics_entry.value) > 0:
+            metrics_dict = metrics_entry.value
+            metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
+            latest_step = max(item.step for item in metrics)
+
+            if latest_step != current_step:
+                logger.debug(f"Got metrics from {len(metrics)} peers")
+
+                for i, metrics_for_peer in enumerate(metrics):
+                    logger.debug(f"{i} peer {metrics_for_peer}")
+
+                current_step = latest_step
+                alive_peers = 0
+                sum_loss = 0
+                num_samples = 0
+                sum_perf = 0
+                sum_mini_steps = 0
+
+                for item in metrics:
+                    sum_loss += item.loss
+                    alive_peers += 1
+                    sum_perf += item.samples_per_second
+                    num_samples += item.samples_accumulated
+                    sum_mini_steps += item.mini_steps
+                current_loss = sum_loss / sum_mini_steps
+                logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
+
+                if peer_args.wandb_project is not None:
+                    wandb.log(
+                        {
+                            "loss": current_loss,
+                            "alive peers": alive_peers,
+                            "samples": num_samples,
+                            "performance": sum_perf,
+                            "step": latest_step,
+                        }
+                    )
+
+                if peer_args.store_checkpoints:
+                    if checkpoint_handler.should_save_state(current_step):
+                        checkpoint_handler.save_state(current_step)
+                        if checkpoint_handler.is_time_to_upload():
+                            checkpoint_handler.upload_checkpoint(current_loss)
+        logger.debug("Peer is still alive...")
+        time.sleep(peer_args.refresh_period)

+ 59 - 0
run_trainer.py

@@ -0,0 +1,59 @@
+#!/usr/bin/env python
+
+import os
+from pathlib import Path
+
+
+import transformers
+from transformers import HfArgumentParser
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+
+from lib.training.hf_trainer import CollaborativeHFTrainer
+
+import callback
+import utils
+from arguments import TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments
+from task import TrainingTask
+
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()
+
+
+def main():
+    parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments))
+    training_peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()
+
+    logger.info(f"Trying {len(training_peer_args.initial_peers)} initial peers: {training_peer_args.initial_peers}")
+    if len(training_peer_args.initial_peers) == 0:
+        logger.warning("Please specify at least one network endpoint in initial peers.")
+
+    utils.setup_logging(trainer_args)
+    task = TrainingTask(training_peer_args, trainer_args, collab_args)
+    model = task.model.to(trainer_args.device)
+
+    collaborative_callback = callback.CollaborativeCallback(task, training_peer_args)
+    assert trainer_args.do_train and not trainer_args.do_eval
+
+    # Note: the code below creates the trainer with dummy scheduler and removes some callbacks.
+    # This is done because collaborative training has its own callbacks that take other peers into account.
+    trainer = CollaborativeHFTrainer(
+        model=model,
+        args=trainer_args,
+        tokenizer=task.tokenizer,
+        data_collator=task.data_collator,
+        data_seed=hash(task.local_public_key),
+        train_dataset=task.training_dataset,
+        eval_dataset=None,
+        collaborative_optimizer=task.collaborative_optimizer,
+        callbacks=[collaborative_callback],
+    )
+    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
+    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
+
+    latest_checkpoint_dir = max(Path(trainer_args.output_dir).glob("checkpoint*"), key=os.path.getctime, default=None)
+    trainer.train(model_path=latest_checkpoint_dir)
+
+
+if __name__ == "__main__":
+    main()

+ 91 - 0
run_trainer_tpu.py

@@ -0,0 +1,91 @@
+#!/usr/bin/env python
+import time
+
+import wandb
+import torch
+import transformers.training_args
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from transformers import HfArgumentParser
+
+import utils
+from arguments import TrainingPeerArguments, TPUTrainerArguments, CollaborativeArguments
+from lib.training.tpu import TPUManager
+from callback import CollaborativeCallback
+from task import TrainingTask
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()
+
+transformers.training_args.is_torch_tpu_available = lambda: False  # disable builtin TPU support to use custom code
+
+
+def main():
+    parser = HfArgumentParser((TrainingPeerArguments, TPUTrainerArguments, CollaborativeArguments))
+    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()
+
+    logger.info(f"Found {len(peer_args.initial_peers)} initial peers: {peer_args.initial_peers}")
+    if len(peer_args.initial_peers) == 0:
+        logger.warning("Please specify at least one network endpoint in initial peers.")
+
+    utils.setup_logging(trainer_args)
+    task = TrainingTask(peer_args, trainer_args, collab_args)
+    model = task.model
+
+    # BEGIN init TPU
+    assert trainer_args.do_train and not trainer_args.do_eval
+    tpu_manager = TPUManager(model, dataset=task.training_dataset, collate_fn=task.data_collator,
+                             grad_accumulation_steps=trainer_args.gradient_accumulation_steps,
+                             batch_size_per_device=trainer_args.per_device_train_batch_size,
+                             nprocs=trainer_args.num_tpus, start=True)
+
+    model = task.model = tpu_manager._synchronizer.master_model
+
+    # warmup tpus
+    logger.info("Waiting for TPUs to warm up, this may take a minute...")
+    tpu_manager.step()
+    logger.info("Warmup step 1 / 3 done.")
+    tpu_manager.update_model_parameters(model.parameters())
+    tpu_manager.step()
+    logger.info("Warmup step 2 / 3 done.")
+    tpu_manager.step()
+    tpu_manager.get_aggregated_gradients()
+    tpu_manager.zero_grad()
+    logger.info("Warmup step 3 / 3 done.")
+    # END init TPU
+
+    def push_params_onto_tpu():
+        logger.info("Pushing new params onto TPU.")
+        tpu_manager.update_model_parameters(model.parameters())
+        tpu_manager.zero_grad()
+
+    collaborative_optimizer = task.collaborative_optimizer
+    collaborative_optimizer.callbacks.on_after_global_step.add(push_params_onto_tpu)
+    collaborative_optimizer.callbacks.on_load_state_from_peers.add(push_params_onto_tpu)
+
+    collaborative_training_callback = CollaborativeCallback(task, peer_args)
+
+    state = transformers.TrainerState()
+    control = transformers.TrainerControl()
+    collaborative_training_callback.on_train_begin(trainer_args, state, control)
+    tpu_manager.update_model_parameters(model.parameters())
+
+    wandb.init(project=trainer_args.wandb_project, name=trainer_args.run_name)
+
+    while True:
+        start_time = time.perf_counter()
+        loss, num_accumulated = tpu_manager.step()
+        time_delta = time.perf_counter() - start_time
+        logger.info(f"Accumulated {num_accumulated} gradients at {num_accumulated / time_delta:.3f} samples/second.")
+        wandb.log({"train/loss": loss, "train/learning_rate": collaborative_optimizer.scheduler.get_lr()[0]})
+
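+        # Hand the TPU-aggregated gradients to the host-side master model, then let the collaborative
+        # optimizer average them with other peers and (once enough samples are gathered) take a global step;
+        # the on_after_global_step hook registered above pushes the updated weights back onto the TPU cores.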
+        with torch.no_grad():
+            for param, grad_from_tpu in zip(model.parameters(), tpu_manager.get_aggregated_gradients()):
+                param.grad[...] = grad_from_tpu
+            collaborative_optimizer.step()
+
+        state.log_history.append(dict(loss=loss))
+        collaborative_training_callback.on_step_end(trainer_args, state, control)
+
+
+if __name__ == "__main__":
+    main()

+ 126 - 0
task.py

@@ -0,0 +1,126 @@
+import os
+from dataclasses import asdict
+from pathlib import Path
+
+import hivemind
+import transformers
+from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
+from transformers import AlbertTokenizerFast, get_linear_schedule_with_warmup, DataCollatorForLanguageModeling
+
+import utils
+from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
+from data import make_dataset
+from huggingface_auth import authorize_with_huggingface
+from lib import LeanAlbertConfig, LeanAlbertForPreTraining
+from lib.staging.collaborative import CollaborativeOptimizer
+from lib.training.clipped_lamb import LambWithGradientClipping
+from lib.training.offload import OffloadOptimizer
+
+hivemind.use_hivemind_log_handler("in_root_logger")
+logger = hivemind.get_logger()
+
+
+class TrainingTask:
+    """A container that defines the training config, model, tokenizer, optimizer and other local training utilities"""
+    _dht = _collaborative_optimizer = _training_dataset = None
+
+    def __init__(
+            self, peer_args: BasePeerArguments, trainer_args: HFTrainerArguments, collab_args: CollaborativeArguments):
+        self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args
+        self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
+        transformers.set_seed(trainer_args.seed)  # seed used for initialization
+        self.config = LeanAlbertConfig.from_pretrained(peer_args.model_config_path)
+        self.tokenizer = AlbertTokenizerFast.from_pretrained(peer_args.tokenizer_path, cache_dir=peer_args.cache_dir)
+
+        output_dir = Path(trainer_args.output_dir)
+        logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
+        latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)
+
+        if latest_checkpoint_dir is None:
+            logger.info(f"Creating model")
+            self.model = LeanAlbertForPreTraining(self.config)
+            self.model.resize_token_embeddings(len(self.tokenizer))
+        else:
+            logger.info(f"Loading model from {latest_checkpoint_dir}")
+            self.model = LeanAlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
+
+    @property
+    def dht(self):
+        if self._dht is None:
+            self._dht = hivemind.DHT(
+                start=True,
+                initial_peers=self.peer_args.initial_peers,
+                client_mode=self.peer_args.client_mode,
+                host_maddrs=self.peer_args.host_maddrs,
+                announce_maddrs=self.peer_args.announce_maddrs,
+                use_ipfs=self.peer_args.use_ipfs,
+                record_validators=self.validators,
+                identity_path=self.peer_args.identity_path,
+                authorizer=authorize_with_huggingface() if self.peer_args.authorize else None,
+            )
+            if self.peer_args.client_mode:
+                logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}")
+            else:
+                utils.log_visible_maddrs(self._dht.get_visible_maddrs(), only_p2p=self.peer_args.use_ipfs)
+        return self._dht
+
+    @property
+    def collaborative_optimizer(self):
+        if self._collaborative_optimizer is None:
+            opt, scheduler = self._get_local_optimizer_and_scheduler(self.trainer_args)
+            averaging_compression = SizeAdaptiveCompression(
+                threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization())
+            state_compression = hivemind.Float16Compression()
+            self._collaborative_optimizer = CollaborativeOptimizer(
+                dht=self.dht, opt=opt, scheduler=scheduler, prefix=self.peer_args.experiment_prefix,
+                batch_size_per_step=self.trainer_args.batch_size_per_step,
+                compression=averaging_compression, state_compression=state_compression,
+                client_mode=self.peer_args.client_mode, verbose=True, start=True, **asdict(self.collab_args))
+        return self._collaborative_optimizer
+
+    def _get_local_optimizer_and_scheduler(self, training_args: HFTrainerArguments):
+        no_decay = ["bias", "LayerNorm.weight"]
+        optimizer_grouped_parameters = [
+            {
+                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "weight_decay": training_args.weight_decay,
+            },
+            {
+                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+                "weight_decay": 0.0,
+            },
+        ]
+
+        opt = OffloadOptimizer(
+            optimizer_grouped_parameters,
+            optim_cls=LambWithGradientClipping,
+            lr=training_args.learning_rate,
+            betas=(training_args.adam_beta1, training_args.adam_beta2),
+            eps=training_args.adam_epsilon,
+            weight_decay=training_args.weight_decay,
+            max_grad_norm=training_args.max_grad_norm,
+            clamp_value=training_args.clamp_value,
+            debias=True,
+        )
+
+        scheduler = get_linear_schedule_with_warmup(
+            opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
+        )
+
+        return opt, scheduler
+
+    @property
+    def training_dataset(self):
+        if self._training_dataset is None:
+            self._training_dataset = make_dataset(
+                self.tokenizer, shuffle_seed=hash(self.local_public_key) % 2 ** 31,
+                max_sequence_length=self.trainer_args.seq_length
+            )
+        return self._training_dataset
+
+    @property
+    def data_collator(self):
+        return DataCollatorForLanguageModeling(
+            tokenizer=self.tokenizer, pad_to_multiple_of=self.trainer_args.pad_to_multiple_of
+        )
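
Note: the DHT, collaborative optimizer and training dataset above are created lazily on first property access, while the model itself is built eagerly in __init__ from the newest checkpoint* directory, if any. A minimal usage sketch (the entry-point scripts run_trainer.py / run_trainer_tpu.py are the real consumers and pass their own argument subclasses; this is only an illustration with the classes task.py imports):

    from transformers import HfArgumentParser
    from arguments import BasePeerArguments, HFTrainerArguments, CollaborativeArguments
    from task import TrainingTask

    parser = HfArgumentParser((BasePeerArguments, HFTrainerArguments, CollaborativeArguments))
    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()

    task = TrainingTask(peer_args, trainer_args, collab_args)
    model = task.model                         # restored from the latest checkpoint if one exists
    optimizer = task.collaborative_optimizer   # first access also spins up the hivemind DHT
    collator = task.data_collator              # masked-LM collator, padded to pad_to_multiple_of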

+ 83 - 0
tests/test_ffn.py

@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from lib.modules.ffn import LeanFFN
+
+
+class ReferenceFFN(nn.Module):
+
+    def __init__(self,
+                 hidden_size: int,
+                 intermediate_size: int,
+                 activation=F.gelu,
+                 layer_norm_eps=1e-12,
+                 dropout: float = 0.0):
+        super().__init__()
+        self.dense_i2h = nn.Linear(hidden_size, intermediate_size)
+        self.dense_h2o = nn.Linear(intermediate_size, hidden_size)
+        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.activation = activation
+        self.dropout = dropout
+
+    def forward(self, input):
+        output = self.dense_i2h(self.layer_norm(input))
+        output = self.activation(output)
+        output = self.dense_h2o(output)
+        output = F.dropout(output, self.dropout, training=self.training)
+        return output + input
+
+
+def test_ffn_exact_match():
+    torch.use_deterministic_algorithms(True)
+
+    batch_size = 4
+    seq_len = 128
+    dim = 32
+    num_layers = 4
+
+    baseline_ffn = ReferenceFFN(dim, 4 * dim)
+    our_ffn = LeanFFN(dim, 4 * dim)
+
+    our_ffn.load_state_dict(baseline_ffn.state_dict())  # strict=True raises if parameter names mismatch
+
+    x = torch.rand(batch_size, seq_len, dim, device='cpu', requires_grad=True)
+
+    # test outputs
+    out_ref = x
+    for i in range(num_layers):
+        out_ref = baseline_ffn.forward(out_ref)
+
+    out_our = x
+    for i in range(num_layers):
+        out_our = our_ffn(out_our)
+
+    assert torch.allclose(out_our, out_ref)
+
+    # test grad inputs
+    obj = (out_ref * (out_ref + 1)).square().mean()
+    grad_ref, = torch.autograd.grad(obj, x)
+
+    obj = (out_our * (out_our + 1)).square().mean()
+    grad_our, = torch.autograd.grad(obj, x)
+    assert torch.allclose(grad_ref, grad_our)
+
+    # test grad params
+    x = torch.rand(batch_size, seq_len, dim, device='cpu', requires_grad=True)
+
+    out_ref = x
+    for i in range(num_layers):
+        out_ref = baseline_ffn.forward(out_ref)
+
+    out_our = x
+    for i in range(num_layers):
+        out_our = our_ffn(out_our)
+
+    obj = (out_ref * (out_ref + 1)).square().mean()
+    grad_params_ref = torch.autograd.grad(obj, list(baseline_ffn.parameters()))
+
+    obj = (out_our * (out_our + 1)).square().mean()
+    grad_params_our = torch.autograd.grad(obj, list(our_ffn.parameters()))
+
+    for grad_ref, grad_our in zip(grad_params_ref, grad_params_our):
+        assert torch.allclose(grad_ref, grad_our)
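
For reference, the block both modules are expected to match (with dropout at its default of 0.0) is the standard pre-LayerNorm residual FFN spelled out by ReferenceFFN above:

    output = input + dense_h2o(gelu(dense_i2h(layer_norm(input))))

and the test checks three things against that baseline for a 4-layer stack: forward outputs, gradients w.r.t. the input, and gradients w.r.t. all parameters.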

+ 70 - 0
tests/test_rotary.py

@@ -0,0 +1,70 @@
+import torch
+
+from lib.modules.rotary import get_auxiliary_tensors, RotaryEmbeddings
+
+
+def test_rotary_embeddings():
+    batch_size = 11
+    seq_len = 50
+    nhead = 4
+    dim = 1024
+    dtype = torch.float32
+    device = torch.device('cpu')
+    base = 10 ** 4
+
+    torch.use_deterministic_algorithms(True)
+
+    # auxiliary tensors
+    a, b = get_auxiliary_tensors(seq_len, dim, dtype, device, base)
+    x, y = Rotary(dim, base).forward(torch.randn(1, seq_len, dim, device=device))
+    assert torch.allclose(a.view_as(x), x, atol=1e-4, rtol=0)
+    assert torch.allclose(b.view_as(y), y, atol=1e-4, rtol=0)
+
+    # full layer outputs
+    ref_layer = Rotary(dim, base)
+    our_layer = RotaryEmbeddings(dim, base)
+    q = torch.randn(batch_size, seq_len, nhead, dim, dtype=dtype, device=device)
+    k = torch.randn(batch_size, seq_len, nhead, dim, dtype=dtype, device=device)
+
+    q_ref, k_ref = apply_rotary_pos_emb(q.permute(1, 0, 2, 3), k.permute(1, 0, 2, 3), *ref_layer(k))
+    q_our, k_our = our_layer(q), our_layer(k)
+    assert torch.allclose(q_ref.permute(1, 0, 2, 3), q_our, atol=1e-4, rtol=0)
+    assert torch.allclose(k_ref.permute(1, 0, 2, 3), k_our, atol=1e-4, rtol=0)
+
+    # check that same-position dot products are invariant under the rotation
+    original_dot = (q[0, :, 0] * k[0, :, 0]).sum(-1)
+    rotated_dot = (our_layer(q)[0, :, 0] * our_layer(k)[0, :, 0]).sum(-1)
+    assert torch.allclose(original_dot, rotated_dot, atol=1e-4, rtol=0)
+
+
+class Rotary(torch.nn.Module):
+    """ Reference implementation by ElutherAI """
+    def __init__(self, dim, base=10000):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+        self.seq_len_cached = None
+        self.cos_cached = None
+        self.sin_cached = None
+
+    def forward(self, x, seq_dim=1):
+        seq_len = x.shape[seq_dim]
+        if seq_len != self.seq_len_cached:
+            self.seq_len_cached = seq_len
+            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+            self.cos_cached = emb.cos()[:, None, None, :]
+            self.sin_cached = emb.sin()[:, None, None, :]
+        return self.cos_cached, self.sin_cached
+
+
+def rotate_half(x):
+    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+    return torch.cat(
+        (-x2, x1), dim=x1.ndim - 1
+    )  # dim=-1 triggers a bug in torch < 1.8.0
+
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
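
A short note on why the final assertion in test_rotary_embeddings holds: apply_rotary_pos_emb applies, at each position $m$, the same orthogonal rotation $R_m$ (pairs of channels rotated by angle $m\theta_i$) to both $q_m$ and $k_m$ via $x\cos(m\theta) + \mathrm{rotate\_half}(x)\sin(m\theta)$. Orthogonal maps preserve inner products, so

    $\langle R_m q_m,\; R_m k_m \rangle = \langle q_m,\; k_m \rangle, \qquad R_m^{\top} R_n = R_{\,n-m},$

which is why same-position dot products are unchanged and attention scores depend only on the relative position $n - m$, the property rotary embeddings are designed to provide.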

+ 72 - 0
utils.py

@@ -0,0 +1,72 @@
+from typing import Dict, List, Tuple
+
+import transformers.utils.logging
+from multiaddr import Multiaddr
+from pydantic import BaseModel, StrictFloat, confloat, conint
+
+from hivemind import choose_ip_address
+from hivemind.dht.crypto import RSASignatureValidator
+from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
+from hivemind.dht.validation import RecordValidatorBase
+from hivemind.utils.logging import get_logger
+from transformers.trainer_utils import is_main_process
+
+logger = get_logger(__name__)
+
+
+class LocalMetrics(BaseModel):
+    step: conint(ge=0, strict=True)
+    samples_per_second: confloat(ge=0.0, strict=True)
+    samples_accumulated: conint(ge=0, strict=True)
+    loss: StrictFloat
+    mini_steps: conint(ge=0, strict=True)
+
+
+class MetricSchema(BaseModel):
+    metrics: Dict[BytesWithPublicKey, LocalMetrics]
+
+
+def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase], bytes]:
+    signature_validator = RSASignatureValidator()
+    validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix), signature_validator]
+    return validators, signature_validator.local_public_key
+
+
+class TextStyle:
+    BOLD = "\033[1m"
+    BLUE = "\033[34m"
+    RESET = "\033[0m"
+
+
+def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
+    if only_p2p:
+        unique_addrs = {addr["p2p"] for addr in visible_maddrs}
+        initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
+    else:
+        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr]
+        if available_ips:
+            preferred_ip = choose_ip_address(available_ips)
+            selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)]
+        else:
+            selected_maddrs = visible_maddrs
+        initial_peers_str = " ".join(str(addr) for addr in selected_maddrs)
+
+    logger.info(
+        f"Running a DHT peer. To connect other peers to this one over the Internet, use "
+        f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}"
+    )
+    logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}")
+
+
+def setup_logging(training_args):
+    # Log a brief summary of the distributed setup on each process:
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+    )
+    # Set the Transformers logger verbosity to INFO (on the main process only):
+    if is_main_process(training_args.local_rank):
+        transformers.utils.logging.set_verbosity_info()
+        transformers.utils.logging.enable_default_handler()
+        transformers.utils.logging.enable_explicit_format()
+    logger.info("Training/evaluation parameters %s", training_args)