Browse Source

Simplify argument parsing, update docs in ALBERT example (#315)

* Simplify argument parsing, update docs in ALBERT example

* Replace whatsmyip with requests.get to find external IP address


Co-authored-by: justheuristic <justheuristic@gmail.com>
Co-authored-by: Michael Diskin <yhn1124@gmail.com>
Max Ryabinin 4 years ago
parent
commit
d97fede72c

+ 111 - 44
examples/albert/README.md

@@ -1,20 +1,31 @@
 # Training ALBERT with decentralized averaging
 # Training ALBERT with decentralized averaging
 
 
-This tutorial will walk you through the steps to set up collaborative training with the ALBERT-large-v2 model and the WikiText103 dataset. It uses huggingface [datasets](https://github.com/huggingface/datasets) and [transformers](https://github.com/huggingface/transformers/) libraries to compute local updates, using `hivemind.CollaborativeOptimizer` to exchange information between peers.
+This tutorial will walk you through the steps to set up collaborative training with the ALBERT-large-v2 model and the
+WikiText103 dataset. It uses Hugging Face [datasets](https://github.com/huggingface/datasets)
+and [transformers](https://github.com/huggingface/transformers/) libraries to compute local updates,
+using `hivemind.CollaborativeOptimizer` to exchange information between peers.
+
+## Preparation
 
 
-### Preparation
 * Install hivemind: `pip install git+https://github.com/learning-at-home/hivemind.git`
 * Install hivemind: `pip install git+https://github.com/learning-at-home/hivemind.git`
 * Dependencies: `pip install -r requirements.txt`
 * Dependencies: `pip install -r requirements.txt`
 * Preprocess data: `python tokenize_wikitext103.py`
 * Preprocess data: `python tokenize_wikitext103.py`
-* Upload an archive preprocessed data to somewhere volunteers can reach, example: `https://hivemind-data.s3.us-east-2.amazonaws.com/wikitext103_preprocessed.tar`
-
+* Upload the data to a publicly available location or ask volunteers to preprocess it locally
 
 
 ## Running an experiment
 ## Running an experiment
-- Run the first DHT peer to welcome trainers and record training statistics (e.g. loss, performance):
-   - In this example, we use [wandb.ai](https://wandb.ai/site) to plot training metrics; If you're unfamiliar with Weights & Biases, here's a [quickstart tutorial](https://docs.wandb.ai/quickstart).
-   - Run `python run_training_monitor.py --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
-   - `NAME_YOUR_EXPERIMENT` must be a unique name of this training run, e.g. `my-first-albert`. It cannot contain `.` due to naming conventions.
-   - `WANDB_PROJECT_HERE` is a name of wandb project used to track training metrics. Multiple experiments can have the same project name.
+
+### First peer
+
+Run the first DHT peer to welcome trainers and record training statistics (e.g., loss and performance):
+
+- In this example, we use [wandb.ai](https://wandb.ai/site) to plot training metrics. If you're unfamiliar with Weights
+  & Biases, here's a [quickstart tutorial](https://docs.wandb.ai/quickstart).
+- Run `python run_training_monitor.py --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
+- `NAME_YOUR_EXPERIMENT` must be a unique name of this training run, e.g. `my-first-albert`. It cannot contain `.`
+  due to naming conventions.
+- `WANDB_PROJECT_HERE` is a name of wandb project used to track training metrics. Multiple experiments can have the
+  same project name.
+
 ```
 ```
 $ python run_training_monitor.py --experiment_prefix my-albert-v1 --wandb_project Demo-run
 $ python run_training_monitor.py --experiment_prefix my-albert-v1 --wandb_project Demo-run
 [2021/06/17 16:26:36.083][INFO][root.log_visible_maddrs:42] Running a DHT peer. To connect other peers to this one, use --initial_peers /ip4/8.8.8.8/tcp/1337/p2p/XXXX /ip4/8.8.8.8/udp/31337/quic/p2p/XXXX
 [2021/06/17 16:26:36.083][INFO][root.log_visible_maddrs:42] Running a DHT peer. To connect other peers to this one, use --initial_peers /ip4/8.8.8.8/tcp/1337/p2p/XXXX /ip4/8.8.8.8/udp/31337/quic/p2p/XXXX
@@ -33,34 +44,47 @@ wandb: Run `wandb offline` to turn off syncing.
 [2021/04/19 02:40:37.541][INFO][__main__.<module>:194] Step #3  loss = 11.02886
 [2021/04/19 02:40:37.541][INFO][__main__.<module>:194] Step #3  loss = 11.02886
 ```
 ```
 
 
-- To join a collaboration with a GPU trainer,
-  - install the same dependencies (minus the `wandb` and `whatsmyip`), download the data and unpack it to the experiment folder,
-  - if necessary, specify paths: `--dataset_path ./path/to/unpacked/data --tokenizer ./path/to/tokenizer/config` (see [default paths](https://github.com/learning-at-home/hivemind/blob/collaborative_albert_example/examples/albert/run_trainer.py#L63-L69) for reference)
-  - run:
-    ```bash
-    python run_trainer.py \
-    --experiment_prefix SAME_AS_IN_RUN_TRAINING_MONITOR --initial_peers ONE_OR_MORE_PEERS --seed 42 \
-    --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
-    ```
-
-    Here, `ONE_OR_MORE_PEERS` stands for multiaddresses of one or multiple existing peers (training monitors or existing trainers)
-    collected from the first lines of their terminal output. For the example above, the multiaddresses would be:
-    ```
-    --initial_peers /ip4/8.8.8.8/tcp/1337/p2p/XXXX /ip4/8.8.8.8/udp/31337/quic/p2p/XXXX
-    ```
-
-    __Note:__ a [multiaddress](https://docs.libp2p.io/concepts/addressing/) is a format for encoding multiple layers of addressing information
-    that supports a number of different protocols. In hivemind, we typically operate with multiaddresses
-    that contain a [libp2p](https://libp2p.io/) peer ID (e.g. `/p2p/XXXX`) together with the information about how to reach it
-    (e.g. the IPv4 address and TCP port `/ip4/8.8.8.8/tcp/31337` or
-    the information about a relay used for [NAT traversal](https://docs.libp2p.io/concepts/nat/)).
-
-    You may need to change the IP address to a publicly visible one if some of the initial peers are located behind NAT.
-    If you have any trouble doing this, consider the ["Using IPFS"](#using-ipfs) section.
+### GPU trainers
+
+To join the collaboration with a GPU trainer,
+
+- Install the same dependencies (without `wandb` and `requests`), download the data and unpack it to the experiment
+  folder;
+- If necessary, specify paths: `--dataset_path ./path/to/unpacked/data --tokenizer ./path/to/tokenizer/config`
+  (see [default paths](./arguments.py#L117-L134) for reference)
+- Run:
+  ```bash
+  python run_trainer.py \
+  --experiment_prefix SAME_AS_IN_RUN_TRAINING_MONITOR --initial_peers ONE_OR_MORE_PEERS --seed 42 \
+  --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
+  ```
+
+  Here, `ONE_OR_MORE_PEERS` stands for multiaddresses of one or multiple existing peers (training monitors or existing
+  trainers)
+  collected from the first lines of their terminal output. For the example above, the multiaddresses would be:
+  ```
+  --initial_peers /ip4/8.8.8.8/tcp/1337/p2p/XXXX /ip4/8.8.8.8/udp/31337/quic/p2p/XXXX
+  ```
+
+  <details>
+    <summary>What is a multiaddress?</summary>
+
+  A [multiaddress](https://docs.libp2p.io/concepts/addressing/) is a format for encoding multiple layers of addressing
+  information that supports a number of different protocols.
+
+  In hivemind, we typically operate with multiaddresses that contain a [libp2p](https://libp2p.io/) peer ID (
+  e.g. `/p2p/XXXX`) together with the information about how to reach it
+  (e.g., the IPv4 address and TCP port `/ip4/8.8.8.8/tcp/31337` or the information about a relay used
+  for [NAT traversal](https://docs.libp2p.io/concepts/nat/)).
+  </details>
+
+  You may need to change the IP address to a publicly visible one if some of the initial peers are located behind NAT.
+  If you have any trouble doing this, consider the ["Using IPFS"](#using-ipfs) section.
 
 
 See the ["Tips and tricks"](#tips-and-tricks) section for more information on setting up collaborative training.
 See the ["Tips and tricks"](#tips-and-tricks) section for more information on setting up collaborative training.
 
 
 As the peer begins training, it will periodically report training logs in the following form:
 As the peer begins training, it will periodically report training logs in the following form:
+
 ```
 ```
 [...][INFO][...] Collaboration accumulated 448 samples from 17 peers; ETA 18.88 seconds (refresh in 15.73s.)
 [...][INFO][...] Collaboration accumulated 448 samples from 17 peers; ETA 18.88 seconds (refresh in 15.73s.)
 [...][INFO][...] Collaboration accumulated 4096 samples from 16 peers; ETA 0.00 seconds (refresh in 0.50s.)
 [...][INFO][...] Collaboration accumulated 4096 samples from 16 peers; ETA 0.00 seconds (refresh in 0.50s.)
@@ -80,31 +104,69 @@ For convenience, you can view (and share!) the learning curves of your collabora
   <img src="https://user-images.githubusercontent.com/3491902/115177859-bed5e100-a0d8-11eb-82bc-55d1b12d335d.png">
   <img src="https://user-images.githubusercontent.com/3491902/115177859-bed5e100-a0d8-11eb-82bc-55d1b12d335d.png">
 </p>
 </p>
 
 
-
 ## Tips and tricks
 ## Tips and tricks
 
 
 Finally, we provide best practices for running collaborative experiments of different sizes.
 Finally, we provide best practices for running collaborative experiments of different sizes.
 
 
 ### Hosting the data
 ### Hosting the data
-For small experiments (3-16 peers, <1GB data), you can use a free-tier file hosting that has a convenient way to [download with curl/wget](https://superuser.com/questions/470664/how-to-download-dropbox-files-using-wget-command). However, these services are not meant for high load and could ban you for generating too much traffic. If you want to scale up, you could either use an S3-like storage from [any](https://aws.amazon.com/s3/) [cloud](https://cloud.google.com/storage) [provider](https://cloud.google.com/storage) or host the data [yourself]((https://gist.github.com/willurd/5720255)). Large data files (>5GB) will take long to download; we recommend splitting them into chunks and implementing a custom dataloader that can load chunks on the fly. Finally, the most _comme il faut_ solution to sharing large datasets is to use [academic torrents](https://academictorrents.com/).
+
+For small experiments (3-16 peers, <1GB data), you can use a free-tier file hosting that has a convenient way
+to [download with curl/wget](https://superuser.com/questions/470664/how-to-download-dropbox-files-using-wget-command).
+However, these services are not meant for high load and could ban you for generating too much traffic. If you want to
+scale up, you could either use an S3-like storage from [any](https://aws.amazon.com/s3/)
+[cloud](https://cloud.google.com/storage) [provider](https://cloud.yandex.com/en-ru/services/storage) or host the data
+[yourself]((https://gist.github.com/willurd/5720255)). Large data files (>5GB) will take long to download; we recommend
+splitting them into chunks and implementing a custom dataloader that can load chunks on the fly. Finally, the most _
+comme il faut_ solution to sharing large datasets is to use [academic torrents](https://academictorrents.com/).
 
 
 ### run_training_monitor.py
 ### run_training_monitor.py
-This peer exists solely to welcome other peers onto the DHT and track learning progress. It requires neither GPU nor high bandwidth, the only prerequisite is that coordinator should have high uptime. If no high uptime server is available, one can also run multiple coordinators on different servers and list all of them as `--initial_peers`. The system will stay up as long as at least one coordinator is available. For short- to mid-term experiments you can host coordinator on a [free-tier VM](https://www.quora.com/Are-there-any-free-online-virtual-machines).
+
+This peer exists solely to welcome other peers onto the DHT and track learning progress. It requires neither GPU nor
+high bandwidth, the only prerequisite is high uptime. If no high uptime server is available, one can also run multiple
+monitors on different servers and list all of them as `--initial_peers`. The system will maintain its integrity as long
+as at least one externally accessible participant is available. For short- to mid-term experiments you can host the
+monitor on a [free-tier VM](https://www.quora.com/Are-there-any-free-online-virtual-machines).
 
 
 ### Tuning for hardware/network
 ### Tuning for hardware/network
-The optimal training parameters for each peer depend on its GPU and internet connection. If a peer cannot accept incoming connections (e.g. when in colab or behind a firewall), add `--client_mode` to the training script (see example below). In case of high network latency, you may want to increase `--averaging_expiration` by a few seconds or set `--batch_size_lead` to start averaging a bit earlier than the rest of the collaboration. GPU-wise, each peer should be able to process one local microbatch each `0.5~1` seconds (see trainer's progress bar). To achieve that, we recommend tuning `--per_device_train_batch_size` and `--gradient_accumulation_steps`. The example trainer supports multiple GPUs via DataParallel. However, using advanced distributed training strategies (e.g. [ZeRO-3](https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html)) will require changes in `run_trainer.py`.
+
+The optimal training parameters for each peer depend on its GPU and internet connection. If a peer cannot accept
+incoming connections (e.g. when in colab or behind a firewall), add `--client_mode` to the training script (see example
+below). In case of high network latency, you may want to increase `--averaging_expiration` by a few seconds or
+set `--batch_size_lead` to start averaging a bit earlier than the rest of the collaboration. GPU-wise, each peer should
+be able to process one local microbatch each 0.5–1 seconds (see trainer's progress bar). To achieve that, we
+recommend tuning `--per_device_train_batch_size` and `--gradient_accumulation_steps`. 
+
+The example trainer supports
+multiple GPUs via DataParallel. However, using advanced distributed training strategies (
+e.g. [ZeRO-3](https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html)) will require changes in `run_trainer.py`.
 
 
 ### Using public GPU providers
 ### Using public GPU providers
-There are awesome services like [Google Colab](https://colab.research.google.com/), [Kaggle kernels](https://www.kaggle.com/dansbecker/running-kaggle-kernels-with-a-gpu) or[Paperspace](https://gradient.paperspace.com/free-gpu) that provide free GPUs. These services usually come with significant limitations (e.g. last gen GPUs, reset every few hours), but they allow just about anyone to join your collaborative experiment. Here's how to best use them.
-  - before you begin, __read the rules carefully__. Most free-tier GPU services allow only one GPU per user and using more than one account will get you banned. It is **your** duty to make sure that collaborators won't get in trouble for helping you.
-  - most free GPUs are running behind a firewall, which requires you to run trainer with `--client_mode` (see example below). Such peers can only exchange gradients if there is at least one non-client-mode peer (GPU server or desktop with public IP). We recommend using a few preemptible instances with the cheapest GPU you can find. For example, we tested this code on preemptible [`g4dn.xlarge`](https://aws.amazon.com/blogs/aws/now-available-ec2-instances-g4-with-nvidia-t4-tensor-core-gpus/) nodes for around $0.15/h apiece with 8 AWS nodes and up to 61 Colab/Kaggle participants.
-  - you can create starter notebooks to make it more convenient for collaborators to join your training run ([example](https://colab.research.google.com/gist/yhn112/e858cb841c73879d8ef98a84e03b43e7/collaborative-training-v0-10.ipynb)). Ideally, joining collaboration should take at most a couple of clicks.
+
+There are awesome services like [Google Colab](https://colab.research.google.com/),
+[Kaggle kernels](https://www.kaggle.com/dansbecker/running-kaggle-kernels-with-a-gpu)
+or [Paperspace](https://gradient.paperspace.com/free-gpu) that provide free GPUs. These services usually come with
+significant limitations (e.g., last gen GPUs, reset every few hours), but they allow just about anyone to join your
+collaborative experiment. Here's how to best use them:
+
+- Before you begin, __read the rules carefully__. Most free-tier GPU services allow only one GPU per user and using
+  more than one account will get you banned. It is **your** duty to make sure that collaborators won't get in trouble
+  for helping you.
+- Most free GPUs are running behind a firewall, which requires you to run trainer with `--client_mode` (see example
+  below). Such peers can only exchange gradients if there is at least one non-client-mode peer (GPU server or desktop
+  with public IP). We recommend using a few preemptible instances with the cheapest GPU you can find. For example, we
+  tested this code on preemptible 
+  [`g4dn.xlarge`](https://aws.amazon.com/blogs/aws/now-available-ec2-instances-g4-with-nvidia-t4-tensor-core-gpus/)
+  nodes for around $0.15/h apiece with 8 AWS nodes and up to 61 Colab/Kaggle participants.
+- You can create starter notebooks to make it more convenient for collaborators to join your training
+  run ([example](https://colab.research.google.com/gist/yhn112/e858cb841c73879d8ef98a84e03b43e7/collaborative-training-v0-10.ipynb)).
+  Ideally, joining collaboration should take at most a couple of clicks.
 
 
 Here's an example of a full trainer script for Google Colab:
 Here's an example of a full trainer script for Google Colab:
+
 ```bash
 ```bash
 !pip install transformers datasets sentencepiece torch_optimizer==0.1.0
 !pip install transformers datasets sentencepiece torch_optimizer==0.1.0
 !git clone https://github.com/learning-at-home/hivemind && cd hivemind && pip install -e .
 !git clone https://github.com/learning-at-home/hivemind && cd hivemind && pip install -e .
-!curl -L YOUR_HOSTED_DATA | tar xzf -     # example: https://hivemind-data.s3.us-east-2.amazonaws.com/wikitext103.tar.gz
+!curl -L YOUR_HOSTED_DATA | tar xzf -
 !ulimit -n 4096 && python ./hivemind/examples/albert/run_trainer.py \
 !ulimit -n 4096 && python ./hivemind/examples/albert/run_trainer.py \
  --client_mode --initial_peers ONE_OR_MORE_PEERS  --averaging_expiration 10 \
  --client_mode --initial_peers ONE_OR_MORE_PEERS  --averaging_expiration 10 \
  --batch_size_lead 300 --per_device_train_batch_size 4 --gradient_accumulation_steps 1 \
  --batch_size_lead 300 --per_device_train_batch_size 4 --gradient_accumulation_steps 1 \
@@ -113,6 +175,11 @@ Here's an example of a full trainer script for Google Colab:
 ```
 ```
 
 
 ### Using IPFS
 ### Using IPFS
-If the initial peers for your experiment are located behind NAT and/or you have any trouble with figuring out their public IP addresses and ports, you can set up hivemind to use the [IPFS](https://ipfs.io) network to find the route to your peers automatically. To do this, you should specify the `--use_ipfs` option on all peers (both training monitors and trainers) you are starting.
 
 
-After that, it is enough to provide only a [libp2p](https://libp2p.io/) peer ID (e.g. `/p2p/XXXX`) for each initial peer. No other information (like IP addresses or TCP/UDP ports) is required.
+If the initial peers for your experiment are located behind NAT and/or you have any trouble with figuring out their
+public IP addresses and ports, you can set up hivemind to use the [IPFS](https://ipfs.io) network to find the route to
+your peers automatically. To do this, you should specify the `--use_ipfs` option on all peers you are starting
+(both trainers and monitors).
+
+After that, it is enough to provide only a [libp2p](https://libp2p.io/) peer ID (e.g. `/p2p/XXXX`) for each initial
+peer. No other information (like IP addresses or TCP/UDP ports) is required.

+ 1 - 1
examples/albert/arguments.py

@@ -97,7 +97,7 @@ class CollaborativeOptimizerArguments:
 
 
 
 
 @dataclass
 @dataclass
-class CollaborationArguments(AveragerArguments, CollaborativeOptimizerArguments, BaseTrainingArguments):
+class CollaborationArguments(CollaborativeOptimizerArguments, BaseTrainingArguments):
     statistics_expiration: float = field(
     statistics_expiration: float = field(
         default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
         default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
     )
     )

+ 1 - 1
examples/albert/requirements.txt

@@ -3,5 +3,5 @@ datasets>=1.5.0
 torch_optimizer>=0.1.0
 torch_optimizer>=0.1.0
 wandb>=0.10.26
 wandb>=0.10.26
 sentencepiece
 sentencepiece
-whatsmyip
+requests
 nltk>=3.6.2
 nltk>=3.6.2

+ 24 - 33
examples/albert/run_trainer.py

@@ -10,24 +10,16 @@ import torch
 import transformers
 import transformers
 from datasets import load_from_disk
 from datasets import load_from_disk
 from torch.utils.data import DataLoader
 from torch.utils.data import DataLoader
-from transformers import (
-    set_seed,
-    HfArgumentParser,
-    TrainingArguments,
-    DataCollatorForLanguageModeling,
-    AlbertTokenizerFast,
-    AlbertConfig,
-    AlbertForPreTraining,
-)
+from torch_optimizer import Lamb
+from transformers import set_seed, HfArgumentParser, TrainingArguments, DataCollatorForLanguageModeling
+from transformers.models.albert import AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining
 from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.optimization import get_linear_schedule_with_warmup
-from transformers.trainer_utils import is_main_process
 from transformers.trainer import Trainer
 from transformers.trainer import Trainer
-from torch_optimizer import Lamb
+from transformers.trainer_utils import is_main_process
 
 
 import hivemind
 import hivemind
 import utils
 import utils
-from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments
-
+from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments, AveragerArguments
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
@@ -209,14 +201,13 @@ class NoOpScheduler(LRSchedulerBase):
 
 
 
 
 def main():
 def main():
-    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments))
-    training_args, dataset_args, collaboration_args = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments, AveragerArguments))
+    training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()
 
 
     logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
     logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
     if len(collaboration_args.initial_peers) == 0:
     if len(collaboration_args.initial_peers) == 0:
         raise ValueError("Please specify at least one network endpoint in initial peers.")
         raise ValueError("Please specify at least one network endpoint in initial peers.")
 
 
-    collaboration_args_dict = asdict(collaboration_args)
     setup_logging(training_args)
     setup_logging(training_args)
 
 
     # Set seed before initializing model.
     # Set seed before initializing model.
@@ -233,40 +224,38 @@ def main():
 
 
     opt, scheduler = get_optimizer_and_scheduler(training_args, model)
     opt, scheduler = get_optimizer_and_scheduler(training_args, model)
 
 
-    validators, local_public_key = utils.make_validators(collaboration_args_dict["experiment_prefix"])
+    validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
+
     dht = hivemind.DHT(
     dht = hivemind.DHT(
         start=True,
         start=True,
-        initial_peers=collaboration_args_dict.pop("initial_peers"),
-        listen=not collaboration_args_dict["client_mode"],
+        initial_peers=collaboration_args.initial_peers,
+        listen=not collaboration_args.client_mode,
         record_validators=validators,
         record_validators=validators,
-        use_ipfs=collaboration_args_dict["use_ipfs"],
-        host_maddrs=collaboration_args_dict.pop("host_maddrs"),
-        announce_maddrs=collaboration_args_dict.pop("announce_maddrs"),
+        use_ipfs=collaboration_args.use_ipfs,
+        host_maddrs=collaboration_args.host_maddrs,
+        announce_maddrs=collaboration_args.announce_maddrs,
     )
     )
-    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args_dict.pop("use_ipfs"))
+    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
 
 
     total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
     total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
     if torch.cuda.device_count() != 0:
     if torch.cuda.device_count() != 0:
         total_batch_size_per_step *= torch.cuda.device_count()
         total_batch_size_per_step *= torch.cuda.device_count()
 
 
-    statistics_expiration = collaboration_args_dict.pop("statistics_expiration")
-    adjusted_target_batch_size = collaboration_args_dict.pop("target_batch_size") - collaboration_args_dict.pop(
-        "batch_size_lead"
-    )
+    adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead
 
 
     collaborative_optimizer = hivemind.CollaborativeOptimizer(
     collaborative_optimizer = hivemind.CollaborativeOptimizer(
         opt=opt,
         opt=opt,
         dht=dht,
         dht=dht,
         scheduler=scheduler,
         scheduler=scheduler,
-        prefix=collaboration_args_dict.pop("experiment_prefix"),
-        compression_type=hivemind.utils.CompressionType.Value(collaboration_args_dict.pop("compression")),
+        prefix=collaboration_args.experiment_prefix,
+        compression_type=hivemind.utils.CompressionType.Value(collaboration_args.compression),
         batch_size_per_step=total_batch_size_per_step,
         batch_size_per_step=total_batch_size_per_step,
-        throughput=collaboration_args_dict.pop("bandwidth"),
+        throughput=collaboration_args.bandwidth,
         target_batch_size=adjusted_target_batch_size,
         target_batch_size=adjusted_target_batch_size,
-        client_mode=collaboration_args_dict.pop("client_mode"),
+        client_mode=collaboration_args.client_mode,
         verbose=True,
         verbose=True,
         start=True,
         start=True,
-        **collaboration_args_dict,
+        **asdict(averager_args),
     )
     )
 
 
     class TrainerWithIndependentShuffling(Trainer):
     class TrainerWithIndependentShuffling(Trainer):
@@ -284,7 +273,9 @@ def main():
         eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
         eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
         optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
         optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
         callbacks=[
         callbacks=[
-            CollaborativeCallback(dht, collaborative_optimizer, model, local_public_key, statistics_expiration)
+            CollaborativeCallback(
+                dht, collaborative_optimizer, model, local_public_key, collaboration_args.statistics_expiration
+            )
         ],
         ],
     )
     )
     trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
     trainer.remove_callback(transformers.trainer_callback.PrinterCallback)

+ 44 - 43
examples/albert/run_training_monitor.py

@@ -6,23 +6,21 @@ from dataclasses import asdict, dataclass, field
 from ipaddress import ip_address
 from ipaddress import ip_address
 from typing import Optional
 from typing import Optional
 
 
+import requests
 import torch
 import torch
 import wandb
 import wandb
 from torch_optimizer import Lamb
 from torch_optimizer import Lamb
 from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
 from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
-from whatsmyip.ip import get_ip
-from whatsmyip.providers import GoogleDnsProvider
 
 
 import hivemind
 import hivemind
 import utils
 import utils
 from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
 from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
 
 
-
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
 @dataclass
 @dataclass
-class CoordinatorArguments(BaseTrainingArguments):
+class TrainingMonitorArguments(BaseTrainingArguments):
     """
     """
     Note: You might want to have several initial peers so that if one dies,
     Note: You might want to have several initial peers so that if one dies,
     new workers still can join the collaboration via alive initial peers' addresses.
     new workers still can join the collaboration via alive initial peers' addresses.
@@ -35,29 +33,25 @@ class CoordinatorArguments(BaseTrainingArguments):
             "help": "Use Google DNS to determine the public IP address of this machine (and add it to --announce_maddrs)"
             "help": "Use Google DNS to determine the public IP address of this machine (and add it to --announce_maddrs)"
         },
         },
     )
     )
-    refresh_period: float = field(
-        default=30, metadata={"help": "Coordinator will fetch keys from DHT once in this many seconds"}
+    refresh_period: float = field(default=30, metadata={"help": "Period (in seconds) for fetching the keys from DHT"})
+    wandb_project: Optional[str] = field(
+        default=None, metadata={"help": "Name of Weights & Biases project to report the training progress to"}
     )
     )
-    wandb_project: Optional[str] = field(default=None, metadata={"help": "Learning curves will be published there"})
     save_checkpoint_step_interval: int = field(
     save_checkpoint_step_interval: int = field(
-        default=5, metadata={"help": "Coordinator will load and save state from peers once every that many steps"}
+        default=5, metadata={"help": "Frequency (in steps) of fetching and saving state from peers"}
     )
     )
     model_config_path: str = field(
     model_config_path: str = field(
         default="https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
         default="https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
         metadata={"help": "Path to the model config"},
         metadata={"help": "Path to the model config"},
     )
     )
     repo_path: Optional[str] = field(
     repo_path: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to HuggingFace repo in which coordinator will upload the model and optimizer states"},
+        default=None, metadata={"help": "Path to local repository to store the model and optimizer states"}
     )
     )
     repo_url: Optional[str] = field(
     repo_url: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "URL to Hugging Face repository to which the coordinator will upload the model and optimizer states"
-        },
+        default=None, metadata={"help": "URL of Hugging Face Hub repository to upload the model and optimizer states"}
     )
     )
     upload_interval: Optional[float] = field(
     upload_interval: Optional[float] = field(
-        default=None, metadata={"help": "Coordinator will upload model once in this many seconds"}
+        default=None, metadata={"help": "Frequency (in seconds) of uploading the model to Hub"}
     )
     )
     store_checkpoins: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
     store_checkpoins: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
 
 
@@ -65,18 +59,18 @@ class CoordinatorArguments(BaseTrainingArguments):
 class CheckpointHandler:
 class CheckpointHandler:
     def __init__(
     def __init__(
         self,
         self,
-        coordinator_args: CoordinatorArguments,
+        monitor_args: TrainingMonitorArguments,
         collab_optimizer_args: CollaborativeOptimizerArguments,
         collab_optimizer_args: CollaborativeOptimizerArguments,
         averager_args: AveragerArguments,
         averager_args: AveragerArguments,
         dht: hivemind.DHT,
         dht: hivemind.DHT,
     ):
     ):
-        self.save_checkpoint_step_interval = coordinator_args.save_checkpoint_step_interval
-        self.repo_path = coordinator_args.repo_path
-        self.repo_url = coordinator_args.repo_url
-        self.upload_interval = coordinator_args.upload_interval
+        self.save_checkpoint_step_interval = monitor_args.save_checkpoint_step_interval
+        self.repo_path = monitor_args.repo_path
+        self.repo_url = monitor_args.repo_url
+        self.upload_interval = monitor_args.upload_interval
         self.previous_step = -1
         self.previous_step = -1
 
 
-        config = AlbertConfig.from_pretrained(coordinator_args.model_config_path)
+        config = AlbertConfig.from_pretrained(monitor_args.model_config_path)
         self.model = AlbertForPreTraining(config)
         self.model = AlbertForPreTraining(config)
 
 
         no_decay = ["bias", "LayerNorm.weight"]
         no_decay = ["bias", "LayerNorm.weight"]
@@ -140,43 +134,47 @@ class CheckpointHandler:
         logger.info("Saving optimizer")
         logger.info("Saving optimizer")
         torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
         torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
         self.previous_timestamp = time.time()
         self.previous_timestamp = time.time()
-        logger.info("Started uploading model to Hub")
+        logger.info("Started uploading to Model Hub")
         self.model.push_to_hub(
         self.model.push_to_hub(
             repo_name=self.repo_path,
             repo_name=self.repo_path,
             repo_url=self.repo_url,
             repo_url=self.repo_url,
             commit_message=f"Step {current_step}, loss {current_loss:.3f}",
             commit_message=f"Step {current_step}, loss {current_loss:.3f}",
         )
         )
-        logger.info("Finished uploading model to Hub")
+        logger.info("Finished uploading to Model Hub")
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    parser = HfArgumentParser((CoordinatorArguments, CollaborativeOptimizerArguments, AveragerArguments))
-    coordinator_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser((TrainingMonitorArguments, CollaborativeOptimizerArguments, AveragerArguments))
+    monitor_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()
 
 
-    if coordinator_args.use_google_dns:
-        address = get_ip(GoogleDnsProvider)
-        logger.info(f"Received public IP address of this machine from Google DNS: {address}")
+    if monitor_args.use_google_dns:
+        request = requests.get("https://api.ipify.org")
+        request.raise_for_status()
+
+        address = request.text
+        logger.info(f"Received public IP address of this machine: {address}")
         version = ip_address(address).version
         version = ip_address(address).version
-        coordinator_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0", f"/ip{version}/{address}/udp/0/quic"]
+        monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0", f"/ip{version}/{address}/udp/0/quic"]
 
 
-    experiment_prefix = coordinator_args.experiment_prefix
+    experiment_prefix = monitor_args.experiment_prefix
     validators, local_public_key = utils.make_validators(experiment_prefix)
     validators, local_public_key = utils.make_validators(experiment_prefix)
+
     dht = hivemind.DHT(
     dht = hivemind.DHT(
         start=True,
         start=True,
-        initial_peers=coordinator_args.initial_peers,
+        initial_peers=monitor_args.initial_peers,
         record_validators=validators,
         record_validators=validators,
-        use_ipfs=coordinator_args.use_ipfs,
-        host_maddrs=coordinator_args.host_maddrs,
-        announce_maddrs=coordinator_args.announce_maddrs,
+        use_ipfs=monitor_args.use_ipfs,
+        host_maddrs=monitor_args.host_maddrs,
+        announce_maddrs=monitor_args.announce_maddrs,
     )
     )
-    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=coordinator_args.use_ipfs)
+    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=monitor_args.use_ipfs)
 
 
-    if coordinator_args.wandb_project is not None:
-        wandb.init(project=coordinator_args.wandb_project)
+    if monitor_args.wandb_project is not None:
+        wandb.init(project=monitor_args.wandb_project)
 
 
     current_step = 0
     current_step = 0
-    if coordinator_args.store_checkpoins:
-        checkpoint_handler = CheckpointHandler(coordinator_args, collab_optimizer_args, averager_args, dht)
+    if monitor_args.store_checkpoins:
+        checkpoint_handler = CheckpointHandler(monitor_args, collab_optimizer_args, averager_args, dht)
 
 
     while True:
     while True:
         metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
         metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
@@ -184,18 +182,20 @@ if __name__ == "__main__":
             metrics_dict = metrics_dict.value
             metrics_dict = metrics_dict.value
             metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
             metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
             latest_step = max(item.step for item in metrics)
             latest_step = max(item.step for item in metrics)
+
             if latest_step != current_step:
             if latest_step != current_step:
                 logger.debug(f"Got metrics from {len(metrics)} peers")
                 logger.debug(f"Got metrics from {len(metrics)} peers")
 
 
                 for i, metrics_for_peer in enumerate(metrics):
                 for i, metrics_for_peer in enumerate(metrics):
                     logger.debug(f"{i} peer {metrics_for_peer}")
                     logger.debug(f"{i} peer {metrics_for_peer}")
+
                 current_step = latest_step
                 current_step = latest_step
                 alive_peers = 0
                 alive_peers = 0
-                num_batches = 0
                 sum_loss = 0
                 sum_loss = 0
                 num_samples = 0
                 num_samples = 0
                 sum_perf = 0
                 sum_perf = 0
                 sum_mini_steps = 0
                 sum_mini_steps = 0
+
                 for item in metrics:
                 for item in metrics:
                     sum_loss += item.loss
                     sum_loss += item.loss
                     alive_peers += 1
                     alive_peers += 1
@@ -205,7 +205,7 @@ if __name__ == "__main__":
                 current_loss = sum_loss / sum_mini_steps
                 current_loss = sum_loss / sum_mini_steps
                 logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
                 logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
 
 
-                if coordinator_args.wandb_project is not None:
+                if monitor_args.wandb_project is not None:
                     wandb.log(
                     wandb.log(
                         {
                         {
                             "loss": current_loss,
                             "loss": current_loss,
@@ -215,10 +215,11 @@ if __name__ == "__main__":
                             "step": latest_step,
                             "step": latest_step,
                         }
                         }
                     )
                     )
-                if coordinator_args.store_checkpoins:
+
+                if monitor_args.store_checkpoins:
                     if checkpoint_handler.is_time_to_save_state(current_step):
                     if checkpoint_handler.is_time_to_save_state(current_step):
                         checkpoint_handler.save_state(current_step)
                         checkpoint_handler.save_state(current_step)
                         if checkpoint_handler.is_time_to_upload():
                         if checkpoint_handler.is_time_to_upload():
                             checkpoint_handler.upload_checkpoint(current_loss)
                             checkpoint_handler.upload_checkpoint(current_loss)
         logger.debug("Peer is still alive...")
         logger.debug("Peer is still alive...")
-        time.sleep(coordinator_args.refresh_period)
+        time.sleep(monitor_args.refresh_period)

+ 6 - 4
examples/albert/tokenize_wikitext103.py

@@ -1,19 +1,21 @@
 #!/usr/bin/env python
 #!/usr/bin/env python
-""" This script builds a pre-tokenized compressed representation of wikitext103 using huggingface/datasets """
+""" This script builds a pre-tokenized compressed representation of WikiText-103 using huggingface/datasets """
 import random
 import random
 from functools import partial
 from functools import partial
-from multiprocessing import cpu_count
 
 
 import nltk
 import nltk
 from datasets import load_dataset
 from datasets import load_dataset
 from transformers import AlbertTokenizerFast
 from transformers import AlbertTokenizerFast
 
 
-
 COLUMN_NAMES = ("attention_mask", "input_ids", "sentence_order_label", "special_tokens_mask", "token_type_ids")
 COLUMN_NAMES = ("attention_mask", "input_ids", "sentence_order_label", "special_tokens_mask", "token_type_ids")
 
 
 
 
 def create_instances_from_document(tokenizer, document, max_seq_length):
 def create_instances_from_document(tokenizer, document, max_seq_length):
-    """Creates `TrainingInstance`s for a single document."""
+    """
+    Creates training instances from a single document.
+    Reuses code from the original ALBERT implementation (Google AI, 2018)
+    https://github.com/google-research/albert/blob/master/create_pretraining_data.py#L267
+    """
     # We DON'T just concatenate all of the tokens from a document into a long
     # We DON'T just concatenate all of the tokens from a document into a long
     # sequence and choose an arbitrary split point because this would make the
     # sequence and choose an arbitrary split point because this would make the
     # next sentence prediction task too easy. Instead, we split the input into
     # next sentence prediction task too easy. Instead, we split the input into