Parcourir la source

Add example for collaborative ALBERT training (#226)

* Create a basic tutorial for running collaborative training with albert
* Add automatic metrics report on WandB

Co-authored-by: Michael Diskin <yhn1124@gmail.com>
Co-authored-by: Roman Zhytar <sir.roma2012@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Alexey Bukhtiyarov il y a 4 ans
Parent
commit
27ea94e3f9

+ 90 - 0
examples/albert/README.md

@@ -0,0 +1,90 @@
+# Training ALBERT with decentralized averaging
+
+This tutorial will walk you through the steps to set up collaborative training with the ALBERT-large-v2 model and the WikiText103 dataset. It uses huggingface [datasets](https://github.com/huggingface/datasets) and [transformers](https://github.com/huggingface/transformers/) libraries to compute local updates, using `hivemind.CollaborativeOptimizer` to exchange information between peers.
+
+### Preparation
+* Install hivemind: `pip install git+https://github.com/learning-at-home/hivemind.git`
+* Dependencies: `pip install -r requirements.txt`
+* Preprocess data: `python tokenize_wikitext103.py`
+* Upload an archive preprocessed data to somewhere volunteers can reach, example: `https://hivemind-data.s3.us-east-2.amazonaws.com/wikitext103_preprocessed.tar`
+
+
+## Running an experiment
+- Run the first DHT peer to welcome trainers and record training statistics (e.g. loss, performance):
+   - In this example, we use [wandb.ai](https://wandb.ai/site) to plot training metrics; If you're unfamiliar with Weights & Biases, here's a [quickstart tutorial](https://docs.wandb.ai/quickstart).
+   - Run `python run_first_peer.py --listen_on '[::]:*' --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
+   - `NAME_YOUR_EXPERIMENT` must be a unique name of this training run, e.g. `my-first-albert`. It cannot contain `.` due to naming conventions.
+   - `WANDB_PROJECT_HERE` is a name of wandb project used to track training metrics. Multiple experiments can have the same project name.
+   - This peer will run a DHT node on a certain IP/port (`Running DHT root at ...`). You will need this address for next steps
+```
++ python ./run_first_peer.py --listen_on '[::]:31209' --experiment_prefix ysda_albert_v10 --wandb_project Demo-run
+[2021/04/19 02:30:06.051][WARN][root.<module>:36] No address specified. Attempting to infer address from DNS.
+[2021/04/19 02:30:06.088][INFO][root.<module>:44] Running DHT root at 18.217.13.97:31209
+wandb: Currently logged in as: ??? (use `wandb login --relogin` to force relogin)
+wandb: Tracking run with wandb version 0.10.26
+wandb: Syncing run wandering-sky-58
+wandb: ⭐ View project at https://wandb.ai/yhn112/Demo-run
+wandb: 🚀 View run at https://wandb.ai/yhn112/Demo-run/runs/38ygvt3n
+wandb: Run data is saved locally in /home/hivemind/examples/albert/wandb/run-20210419_023006-38ygvt3n
+wandb: Run `wandb offline` to turn off syncing.
+[2021/04/19 02:37:37.246][INFO][root.<module>:74] 11.05164
+[2021/04/19 02:39:37.441][INFO][root.<module>:74] 11.03771
+[2021/04/19 02:40:37.541][INFO][root.<module>:74] 11.02886
+```
+
+- To join a collaboration with a GPU trainer, 
+  - install the same dependencies (minus the `wandb` and `whatsmyip`), download the data and unpack it to the experiment folder,
+  - if necessary, specify paths: `--dataset_path ./path/to/unpacked/data --tokenizer ./path/to/tokenizer/config` (see [default paths](https://github.com/learning-at-home/hivemind/blob/collaborative_albert_example/examples/albert/run_trainer.py#L63-L69) for reference)
+  - run:
+```shell
+ CUDA_VISIBLE_DEVICES=0 HIVEMIND_THREADS=64 python ./hivemind/examples/albert/run_trainer.py \
+ --experiment_prefix SAME_AS_IN_RUN_FIRST_PEER --initial_peers ONE_OR_MORE_PEERS --seed 42 \
+ --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
+```
+Here, `ONE_OR_MORE_PEERS` stands for either your coordinator endpoint (e.g. `123.123.123.123:1337`), an endpoint of any pre-existing trainer or multiple endpoints for stability. See tips & tricks section below for more information on setting up collaborative training.
+
+As the peer begins training, it will periodically report training logs in the following form:
+```
+{'loss': 4.3577, 'learning_rate': 0.001318944, 'epoch': 0.0}
+[...][INFO][...] Collaboration accumulated 448 samples from 17 peers; ETA 18.88 seconds (refresh in 15.73s.)
+[...][INFO][...] Collaboration accumulated 4096 samples from 16 peers; ETA 0.00 seconds (refresh in 0.50s.)
+[...][INFO][optim.collaborative.step:195] Averaged tensors successfully with 17 peers
+[...][INFO][optim.collaborative.step:211] Optimizer step: done!
+```
+
+__Sanity check:__ a healthy peer will periodically report `Averaged tensors successfully with [N > 1]` peers.
+
+For convenience, you can view (and share!) the learning curves of your collaborative experiments in wandb:
+![image](https://user-images.githubusercontent.com/3491902/115177859-bed5e100-a0d8-11eb-82bc-55d1b12d335d.png)
+
+
+## Tips and tricks
+
+Finally, we provide best practices for running collaborative experiments of different sizes.
+
+### Hosting the data
+For small experiments (3-16 peers, <1GB data), you can use a free-tier file hosting that has a convenient way to [download with curl/wget](https://superuser.com/questions/470664/how-to-download-dropbox-files-using-wget-command). However, these services are not meant for high load and could ban you for generating too much traffic. If you want to scale up, you could either use an S3-like storage from [any](https://aws.amazon.com/s3/) [cloud](https://cloud.google.com/storage) [provider](https://cloud.google.com/storage) or host the data [yourself]((https://gist.github.com/willurd/5720255)). Large data files (>5GB) will take long to download; we recommend splitting them into chunks and implementing a custom dataloader that can load chunks on the fly. Finally, the most _comme il faut_ solution to sharing large datasets is to use [academic torrents](https://academictorrents.com/).
+ 
+### run_first_peer.py
+This peer exists solely to welcome other peers onto the DHT and track learning progress. It requires neither GPU nor high bandwidth, the only prerequisite is that coordinator should have high uptime. If no high uptime server is available, one can also run multiple coordinators on different servers and list all of them as `--initial_peers`. The system will stay up as long as at least one coordinator is available. For short- to mid-term experiments you can host coordinator on a [free-tier VM](https://www.quora.com/Are-there-any-free-online-virtual-machines).
+
+### Tuning for hardware/network
+The optimal training parameters for each peer depend on its GPU and internet connection. If a peer cannot accept incoming connections (e.g. when in colab or behind a firewall), add `--client_mode` to the training script (see example below). In case of high network latency, you may want to increase `--averaging_expiration` by a few seconds or set `--batch_size_lead` to start averaging a bit earlier than the rest of the collaboration. GPU-wise, each peer should be able to process one local microbatch each `0.5~1` seconds (see trainer's progress bar). To achieve that, we recommend tuning `--per_device_train_batch_size` and `--gradient_accumulation_steps`. The example trainer supports multiple GPUs via DataParallel. However, using advanced distributed training strategies (e.g. [ZeRO-3](https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html)) will require changes in `run_trainer.py`.
+
+### Using public GPU providers
+There are awesome services like [Google Colab](https://colab.research.google.com/), [Kaggle kernels](https://www.kaggle.com/dansbecker/running-kaggle-kernels-with-a-gpu) or[Paperspace](https://gradient.paperspace.com/free-gpu) that provide free GPUs. These services usually come with significant limitations (e.g. last gen GPUs, reset every few hours), but they allow just about anyone to join your collaborative experiment. Here's how to best use them.
+  - before you begin, __read the rules carefully__. Most free-tier GPU services allow only one GPU per user and using more than one account will get you banned. It is **your** duty to make sure that collaborators won't get in trouble for helping you.
+  - most free GPUs are running behind a firewall, which requires you to run trainer with `--client_mode` (see example below). Such peers can only exchange gradients if there is at least one non-client-mode peer (GPU server or desktop with public IP). We recommend using a few preemptible instances with the cheapest GPU you can find. For example, we tested this code on preemptible [`g4dn.xlarge`](https://aws.amazon.com/blogs/aws/now-available-ec2-instances-g4-with-nvidia-t4-tensor-core-gpus/) nodes for around $0.15/h apiece with 8 AWS nodes and up to 61 Colab/Kaggle participants.
+  - you can create starter notebooks to make it more convenient for collaborators to join your training run ([example](https://colab.research.google.com/gist/yhn112/e858cb841c73879d8ef98a84e03b43e7/collaborative-training-v0-10.ipynb)). Ideally, joining collaboration should take at most a couple of clicks.
+
+Here's an example of a full trainer script for Google Colab:
+```
+!pip install transformers datasets sentencepiece torch_optimizer==0.1.0
+!git clone https://github.com/learning-at-home/hivemind && cd hivemind && pip install -e .
+!curl -L YOUR_HOSTED_DATA | tar xzf -     # example: https://hivemind-data.s3.us-east-2.amazonaws.com/wikitext103.tar.gz
+!ulimit -n 4096 && HIVEMIND_THREADS=256 python ./hivemind/examples/albert/run_trainer.py \
+ --client_mode --initial_peers ONE_OR_MORE_PEERS  --averaging_expiration 10 \
+ --batch_size_lead 300 --per_device_train_batch_size 4 --gradient_accumulation_steps 1 \
+ --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs \
+ --experiment_prefix EXPERIMENT_NAME_HERE --seed 42
+```

+ 6 - 0
examples/albert/requirements.txt

@@ -0,0 +1,6 @@
+transformers>=4.5.1
+datasets>=1.5.0
+torch_optimizer>=0.1.0
+wandb>=0.10.26
+sentencepiece
+whatsmyip

+ 67 - 0
examples/albert/run_first_peer.py

@@ -0,0 +1,67 @@
+#!/usr/bin/env python
+
+import time
+import argparse
+import wandb
+from whatsmyip.providers import GoogleDnsProvider
+from whatsmyip.ip import get_ip
+
+import hivemind
+from hivemind.utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--address', type=str, required=False,
+                        help="this machine's network address. Use public IP for global experiments, "
+                             "local address for private runs.")
+    parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
+                        help="'localhost' for local connections only, '0.0.0.0' for ipv4 '[::]' for ipv6")
+    parser.add_argument('--refresh_period', type=float, default=30, required=False,
+                        help="coordinator will fetch keys from DHT once in this many seconds")
+    parser.add_argument('--experiment_prefix', type=str, required=True,
+                        help="a prefix where peers store their metrics for aggregation")
+    parser.add_argument('--wandb_project', type=str, required=True,
+                        help="Weights & Biases project name to publish learning curves")
+
+    args = parser.parse_args()
+    if args.address is None:
+        logger.warning("No address specified. Attempting to infer address from DNS.")
+        args.address = get_ip(GoogleDnsProvider)
+
+    dht = hivemind.DHT(start=True, listen_on=args.listen_on, endpoint=f"{args.address}:*")
+    logger.info(f"Running DHT root at {args.address}:{dht.port}")
+
+    wandb.init(project=args.wandb_project)
+    current_step = 0
+
+    while True:
+        metrics_dict = dht.get(args.experiment_prefix + '_metrics', latest=True)
+        if metrics_dict is not None:
+            metrics_dict = metrics_dict.value
+            metrics = [metrics_dict[peer].value for peer in metrics_dict]
+            latest_step = max(metrics)[0]
+            if latest_step != current_step:
+                current_step = latest_step
+                alive_peers = 0
+                num_batches = 0
+                sum_loss = 0
+                num_samples = 0
+                sum_perf = 0
+                for step, perf, samples, loss in metrics:
+                    sum_loss += loss
+                    alive_peers += 1
+                    sum_perf += perf
+                    num_samples += samples
+                wandb.log({
+                    "loss": sum_loss / alive_peers,
+                    "alive peers": alive_peers,
+                    "samples": num_samples,
+                    "performance": sum_perf
+                })
+                logger.info(f"Step #{current_step}\tloss = {sum_loss / alive_peers:.5f}")
+        logger.debug("Peer is still alive...")
+        time.sleep(args.refresh_period)

+ 319 - 0
examples/albert/run_trainer.py

@@ -0,0 +1,319 @@
+#!/usr/bin/env python
+
+import logging
+import os
+from dataclasses import dataclass, field, asdict
+from pathlib import Path
+from typing import Optional, Dict, Any, List
+import uuid
+
+from datasets import load_from_disk
+import transformers
+from torch.utils.data import DataLoader
+from transformers import (set_seed, HfArgumentParser, TrainingArguments,
+                          DataCollatorForLanguageModeling, AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining)
+from transformers.optimization import get_linear_schedule_with_warmup
+from transformers.trainer_utils import is_main_process
+from transformers.trainer import Trainer
+from torch_optimizer import Lamb
+import torch
+
+import hivemind
+
+logger = logging.getLogger(__name__)
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
+
+
+@dataclass
+class CollaborationArguments:
+    """ define how peers interact with each other while training"""
+
+    # primary parameters
+    initial_peers: List[str]  # one or more peers (comma-separated) that will welcome you into the collaboration
+    experiment_prefix: str  # a unique "name" of this experiment, used to store metadata on the DHT
+    averaging_expiration: float = 5.0  # averaging group will wait for stragglers for at most this many seconds
+    averaging_timeout: float = 30.0  # give up on averaging step after this many seconds
+    target_batch_size: int = 4096  # perform optimizer step after all peers collectively accumulate this many samples
+    client_mode: bool = False  # if True, runs training without incoming connections, in a firewall-compatible mode
+    trainer_uuid: str = uuid.uuid4().hex  # this peer's name - used when publishing metadata to DHT, default = random
+
+    # optional tweaks
+    target_group_size: int = 64  # maximum group size for all-reduce
+    metadata_expiration: float = 30  # peer's metadata will be removed if not updated in this many seconds
+    statistics_expiration: float = 600  # statistics will be removed if not updated in this many seconds
+    dht_listen_on: str = '[::]:*'  # network interface used for incoming DHT communication. Default: all ipv6
+    listen_on: str = '[::]:*'  # network interface used for incoming averager communication. Default: all ipv6
+    endpoint: Optional[str] = None  # this node's IP for inbound connections, used when running from behind a proxy
+    batch_size_lead: int = 0  # optional: begin looking for group in advance, this many samples before target_batch_size
+    compression: str = 'FLOAT16'  # use this compression when averaging parameters/gradients
+
+    min_refresh_period: float = 0.5  # wait for at least this many seconds before fetching new collaboration state
+    max_refresh_period: float = 30  # wait for at most this many seconds before fetching new collaboration state
+    default_refresh_period: float = 3  # attempt to fetch collaboration state every this often until successful
+    expected_drift_peers: float = 3  # trainer assumes that this many new peers can join per step
+    expected_drift_rate: float = 0.2  # trainer assumes that this fraction of current size can join per step
+
+    bandwidth: float = 100.0  # available network bandwidth, in mbps (used for load balancing in all-reduce)
+    performance_ema_alpha: float = 0.1  # uses this alpha for moving average estimate of samples per second
+
+
+@dataclass
+class DatasetArguments:
+    dataset_path: Optional[str] = field(default='./data/albert_tokenized_wikitext',
+                                        metadata={"help": "Path to the tokenized dataset"})
+    tokenizer_path: Optional[str] = field(default='./data/tokenizer',
+                                          metadata={"help": "Path to the tokenizer"})
+    config_path: Optional[str] = field(
+        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
+        metadata={"help": "Path to the model config"})
+    cache_dir: Optional[str] = field(default='./data', metadata={"help": "Path to the cache"})
+
+
+@dataclass
+class AlbertTrainingArguments(TrainingArguments):
+    dataloader_num_workers: int = 4
+    per_device_train_batch_size: int = 4
+    per_device_eval_batch_size: int = 4
+    gradient_accumulation_steps: int = 2
+    seq_length: int = 512
+
+    max_steps: int = 1_000_000  # Albert is actually ready after 125000 steps
+    learning_rate: float = 0.00176
+    warmup_steps: int = 5000
+    adam_epsilon: float = 1e-6
+    weight_decay: float = 0.01
+    max_grad_norm: float = 1.0
+    clamp_value: float = 10000.0
+
+    fp16: bool = True
+    fp16_opt_level: str = 'O2'
+    do_train: bool = True
+
+    logging_steps: int = 100
+    save_total_limit: int = 2
+    save_steps: int = 500
+
+
+def setup_logging(training_args):
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
+    )
+
+    # Log on each process the small summary:
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+    )
+    # Set the verbosity to info of the Transformers logger (on main process only):
+    if is_main_process(training_args.local_rank):
+        transformers.utils.logging.set_verbosity_info()
+        transformers.utils.logging.enable_default_handler()
+        transformers.utils.logging.enable_explicit_format()
+    logger.info("Training/evaluation parameters %s", training_args)
+
+
+def get_model(training_args, config, tokenizer):
+    # Find latest checkpoint in output_dir
+    output_dir = Path(training_args.output_dir)
+    logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
+    latest_checkpoint_dir = max(output_dir.glob('checkpoint*'), default=None, key=os.path.getctime)
+
+    if latest_checkpoint_dir is not None:
+        logger.info(f'Loading model from {latest_checkpoint_dir}')
+        model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
+    else:
+        logger.info(f'Training from scratch')
+        model = AlbertForPreTraining(config)
+        model.resize_token_embeddings(len(tokenizer))
+
+    return model
+
+
+def get_optimizer_and_scheduler(training_args, model):
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": training_args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+
+    opt = Lamb(
+        optimizer_grouped_parameters,
+        lr=training_args.learning_rate,
+        betas=(training_args.adam_beta1, training_args.adam_beta2),
+        eps=training_args.adam_epsilon,
+        weight_decay=training_args.weight_decay,
+        clamp_value=training_args.clamp_value,
+        debias=True,
+    )
+
+    scheduler = get_linear_schedule_with_warmup(
+        opt,
+        num_warmup_steps=training_args.warmup_steps,
+        num_training_steps=training_args.max_steps
+    )
+
+    return opt, scheduler
+
+
+class CollaborativeCallback(transformers.TrainerCallback):
+    def __init__(self, dht: hivemind.DHT, optimizer: hivemind.CollaborativeOptimizer,
+                 model: torch.nn.Module, trainer_uuid: str, statistics_expiration: float):
+        super().__init__()
+        self.model = model
+        self.dht, self.collaborative_optimizer = dht, optimizer
+        self.trainer_uuid, self.statistics_expiration = trainer_uuid, statistics_expiration
+        self.last_reported_collaboration_step = -1
+        self.previous_state = self.get_current_state()
+        self.samples = 0
+        self.steps = 0
+        self.loss = 0
+
+    def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
+                    control: transformers.TrainerControl, **kwargs):
+        control.should_log = True
+        if not self.params_are_finite():
+            self.load_from_state(self.previous_state)
+            return control
+        self.previous_state = self.get_current_state()
+
+        if state.log_history:
+            self.loss += state.log_history[-1]['loss']
+            if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
+                self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
+
+                statistics = [self.collaborative_optimizer.local_step,
+                              self.collaborative_optimizer.performance_ema.samples_per_second,
+                              self.samples,
+                              self.loss / self.steps if self.steps else 0]
+                self.loss = 0
+
+                self.dht.store(self.collaborative_optimizer.prefix + "_metrics", subkey=self.trainer_uuid,
+                               value=statistics, expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                               return_future=True)
+        self.samples = self.collaborative_optimizer.local_samples_accumulated
+        self.steps = self.collaborative_optimizer.local_steps_accumulated
+
+        return control
+
+    @torch.no_grad()
+    def get_current_state(self) -> Dict[str, Any]:
+        return {
+            'model': self.model.state_dict(),
+            'opt': self.collaborative_optimizer.opt.state_dict()
+        }
+
+    @torch.no_grad()
+    def load_from_state(self, state):
+        self.model.load_state_dict(state['model'])
+        self.collaborative_optimizer.opt.load_state_dict(state['opt'])
+
+    @torch.no_grad()
+    def params_are_finite(self):
+        for param in self.model.parameters():
+            if not torch.all(torch.isfinite(param)):
+                return False
+        return True
+
+
+class NoOpScheduler(LRSchedulerBase):
+    """ Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler """
+
+    def get_lr(self):
+        return [group['lr'] for group in self.optimizer.param_groups]
+
+    def print_lr(self, *args, **kwargs):
+        if self.optimizer.scheduler:
+            return self.optimizer.scheduler.print_lr(*args, **kwargs)
+
+    def step(self):
+        logger.debug("Called NoOpScheduler.step")
+        self._last_lr = self.get_lr()
+
+    def state_dict(self):
+        return {}
+
+    def load_state_dict(self, *args, **kwargs):
+        logger.debug("Called NoOpScheduler.load_state_dict")
+
+
+def main():
+    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments))
+    training_args, dataset_args, collaboration_args = parser.parse_args_into_dataclasses()
+
+    logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
+    if len(collaboration_args.initial_peers) == 0:
+        raise ValueError("Please specify at least one network endpoint in initial peers.")
+
+    collaboration_args_dict = asdict(collaboration_args)
+    setup_logging(training_args)
+
+    # Set seed before initializing model.
+    set_seed(training_args.seed)
+
+    config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
+    tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
+    model = get_model(training_args, config, tokenizer)
+    model.to(training_args.device)
+
+    tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
+    # This data collator will take care of randomly masking the tokens.
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
+
+    opt, scheduler = get_optimizer_and_scheduler(training_args, model)
+
+    dht = hivemind.DHT(
+        initial_peers=collaboration_args_dict.pop('initial_peers'),
+        listen=not collaboration_args_dict['client_mode'], listen_on=collaboration_args_dict.pop('dht_listen_on'),
+        endpoint=collaboration_args_dict.pop('endpoint'), start=True)
+
+    total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
+    trainer_uuid = collaboration_args_dict.pop('trainer_uuid')
+    statistics_expiration = collaboration_args_dict.pop('statistics_expiration')
+    adjusted_target_batch_size = collaboration_args_dict.pop('target_batch_size') \
+                                 - collaboration_args_dict.pop('batch_size_lead')
+
+    collaborative_optimizer = hivemind.CollaborativeOptimizer(
+        opt=opt, dht=dht, scheduler=scheduler, prefix=collaboration_args_dict.pop('experiment_prefix'),
+        compression_type=hivemind.utils.CompressionType.Value(collaboration_args_dict.pop('compression')),
+        batch_size_per_step=total_batch_size_per_step, throughput=collaboration_args_dict.pop('bandwidth'),
+        target_batch_size=adjusted_target_batch_size, client_mode=collaboration_args_dict.pop('client_mode'),
+        verbose=True, start=True, **collaboration_args_dict
+    )
+
+    class TrainerWithIndependentShuffling(Trainer):
+        def get_train_dataloader(self) -> DataLoader:
+            """ Shuffle data independently for each peer to avoid duplicating batches [important for quality] """
+            torch.manual_seed(hash(trainer_uuid))
+            return super().get_train_dataloader()
+
+    trainer = TrainerWithIndependentShuffling(
+        model=model, args=training_args, tokenizer=tokenizer, data_collator=data_collator,
+        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
+        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
+        optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
+        callbacks=[CollaborativeCallback(dht, collaborative_optimizer, model, trainer_uuid, statistics_expiration)]
+    )
+    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
+    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
+
+    # Training
+    if training_args.do_train:
+        latest_checkpoint_dir = max(
+            Path(training_args.output_dir).glob('checkpoint*'),
+            default=None,
+            key=os.path.getctime
+        )
+
+        trainer.train(model_path=latest_checkpoint_dir)
+
+
+if __name__ == "__main__":
+    main()

+ 104 - 0
examples/albert/tokenize_wikitext103.py

@@ -0,0 +1,104 @@
+#!/usr/bin/env python
+""" This script builds a pre-tokenized compressed representation of wikitext103 using huggingface/datasets """
+import random
+from collections import defaultdict
+from functools import partial
+from multiprocessing import cpu_count
+
+import nltk
+from datasets import load_dataset
+from transformers import AlbertTokenizerFast
+
+
+def create_instances_from_document(tokenizer, document, max_seq_length):
+    """Creates `TrainingInstance`s for a single document."""
+    # We DON'T just concatenate all of the tokens from a document into a long
+    # sequence and choose an arbitrary split point because this would make the
+    # next sentence prediction task too easy. Instead, we split the input into
+    # segments "A" and "B" based on the actual "sentences" provided by the user
+    # input.
+    instances = []
+    current_chunk = []
+    current_length = 0
+
+    segmented_sents = list(nltk.sent_tokenize(document))
+
+    for i, sent in enumerate(segmented_sents):
+        current_chunk.append(sent)
+        current_length += len(tokenizer.tokenize(sent))
+        if i == len(segmented_sents) - 1 or current_length >= max_seq_length:
+            if len(current_chunk) > 1:
+                # `a_end` is how many segments from `current_chunk` go into the `A`
+                # (first) sentence.
+                a_end = random.randint(1, len(current_chunk) - 1)
+
+                tokens_a = []
+                for j in range(a_end):
+                    tokens_a.append(current_chunk[j])
+
+                tokens_b = []
+
+                for j in range(a_end, len(current_chunk)):
+                    tokens_b.append(current_chunk[j])
+
+                if random.random() < 0.5:
+                    # Random next
+                    is_random_next = True
+                    # Note(mingdachen): in this case, we just swap tokens_a and tokens_b
+                    tokens_a, tokens_b = tokens_b, tokens_a
+                else:
+                    # Actual next
+                    is_random_next = False
+
+                assert len(tokens_a) >= 1
+                assert len(tokens_b) >= 1
+
+                instance = tokenizer(
+                    ' '.join(tokens_a),
+                    ' '.join(tokens_b),
+                    truncation='longest_first',
+                    max_length=max_seq_length,
+                    # We use this option because DataCollatorForLanguageModeling
+                    # is more efficient when it receives the `special_tokens_mask`.
+                    return_special_tokens_mask=True,
+                )
+                assert len(instance['input_ids']) <= max_seq_length
+                instance["sentence_order_label"] = 1 if is_random_next else 0
+                instances.append(instance)
+
+            current_chunk = []
+            current_length = 0
+
+    return instances
+
+
+def tokenize_function(tokenizer, examples):
+    # Remove empty texts
+    texts = (text for text in examples["text"] if len(text) > 0 and not text.isspace())
+
+    new_examples = defaultdict(list)
+
+    for text in texts:
+        instances = create_instances_from_document(tokenizer, text, max_seq_length=512)
+        for instance in instances:
+            for key, value in instance.items():
+                new_examples[key].append(value)
+
+    return new_examples
+
+
+if __name__ == '__main__':
+    random.seed(0)
+    nltk.download('punkt')
+    tokenizer = AlbertTokenizerFast.from_pretrained('albert-large-v2')
+    wikitext = load_dataset('wikitext', 'wikitext-103-v1', cache_dir='./data/cache')
+
+    tokenized_datasets = wikitext.map(
+        partial(tokenize_function, tokenizer),
+        batched=True,
+        num_proc=cpu_count(),
+        remove_columns=["text"],
+    )
+
+    tokenized_datasets.save_to_disk('./data/albert_tokenized_wikitext')
+    tokenizer.save_pretrained('./data/tokenizer')

+ 25 - 12
hivemind/optim/collaborative.py

@@ -31,8 +31,8 @@ class CollaborationState:
     def ready_for_step(self):
         return self.samples_accumulated >= self.target_batch_size or get_dht_time() >= self.eta_next_step
 
-    def register_step(self):
-        self.optimizer_step += 1
+    def register_step(self, local_step: int):
+        self.optimizer_step = max(local_step, self.optimizer_step)
         self.samples_accumulated = 0
         self.eta_next_step = float('inf')
 
@@ -62,6 +62,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :note: the expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
       refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
     :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
+    :param step_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
     :param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
     :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
     :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
@@ -72,6 +73,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
      the cost of extra time per step. If reuse_gradient_accumulators is True, this parameter has no effect.
+    :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
     :param kwargs: additional parameters forwarded to DecentralizedAverager
     :note: if you are using CollaborativeOptimizer with a lr_scheduler, it is recommended to pass this scheduler
       explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
@@ -81,8 +83,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                  batch_size_per_step: Optional[int] = None, scheduler: Optional[LRSchedulerBase] = None,
                  min_refresh_period: float = 0.5, max_refresh_period: float = 30, default_refresh_period: float = 3,
                  expected_drift_peers: float = 3, expected_drift_rate: float = 0.2, performance_ema_alpha: float = 0.1,
-                 metadata_expiration: float = 30.0, averaging_timeout: Optional[float] = None, verbose: bool = False,
-                 reuse_grad_buffers: bool = False, accumulate_grads_on: Optional[torch.device] = None, **kwargs):
+                 metadata_expiration: float = 30.0, averaging_timeout: Optional[float] = None, step_tolerance: int = 1,
+                 reuse_grad_buffers: bool = False, accumulate_grads_on: Optional[torch.device] = None,
+                 client_mode: bool = False, verbose: bool = False, **kwargs):
         super().__init__(opt, dht)
         if reuse_grad_buffers and accumulate_grads_on is not None:
             logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
@@ -93,6 +96,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
         self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
         self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
+        self.client_mode, self.step_tolerance = client_mode, step_tolerance
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.averager = self._make_averager(**kwargs)
 
@@ -113,7 +117,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     def _make_averager(self, **kwargs):
         return TrainingAverager(self.opt, dht=self.dht, average_parameters=True, average_gradients=True,
-                                prefix=f"{self.prefix}_averaging", allreduce_timeout=self.averaging_timeout, **kwargs)
+                                prefix=f"{self.prefix}_averaging", allreduce_timeout=self.averaging_timeout,
+                                listen=not self.client_mode, **kwargs)
 
     @property
     def local_step(self) -> int:
@@ -121,7 +126,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     @property
     def is_synchronized(self) -> bool:
-        return self.local_step >= self.collaboration_state.optimizer_step
+        return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
 
     def is_alive(self) -> bool:
         return self.averager.is_alive()
@@ -149,6 +154,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
         if not self.is_synchronized:
+            logger.log(self.status_loglevel, "Peer is out of sync.")
             self.load_state_from_peers()
             return
 
@@ -157,6 +163,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                            f"but metadata expired in {self.metadata_expiration} s.")
 
         self.accumulate_grads_(batch_size)
+
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_steps_accumulated += 1
@@ -166,7 +173,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         if not self.collaboration_state.ready_for_step:
             return
 
-        logger.log(self.status_loglevel, "Averaging parameters and gradients with peers...")
+        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
         self.collaboration_state = self.fetch_collaboration_state()
         self.collaboration_state_updated.set()
 
@@ -177,26 +184,32 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.performance_ema.pause(), self.lock_collaboration_state:
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             self.apply_accumulated_grads_(scale_by=1. / self.local_steps_accumulated)
+            current_step, group_info = self.averager.local_step, None
 
             if self.collaboration_state.num_peers > 1:
                 mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                 weight = self.local_samples_accumulated / mean_samples_per_worker
-                output = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
+                try:
+                    group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
+                    if group_info:
+                        logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+                except Exception as e:
+                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {e}.")
+
             else:
                 logger.log(self.status_loglevel, f"Skipped averaging: collaboration consists of "
                                                  f"{self.collaboration_state.num_peers} peer(s).")
-                output = None
-                self.averager.local_step += 1
 
             self.opt.step()
             self.reset_accumulated_grads_()
             self.local_samples_accumulated = self.local_steps_accumulated = 0
-            self.collaboration_state.register_step()
+            self.collaboration_state.register_step(current_step + 1)
+            self.averager.local_step = current_step + 1
             self.collaboration_state_updated.set()
             self.update_scheduler()
 
             logger.log(self.status_loglevel, f"Optimizer step: done!")
-            return output
+            return group_info
 
     def _grad_buffers(self) -> Iterator[torch.Tensor]:
         """ pytorch-internal gradient buffers """