
Add state checkpointing and uploading in coordinator (#241)

Resolves #237

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Alexey Bukhtiyarov 4 years ago
Parent
Commit
01103cf991

+ 150 - 0
examples/albert/arguments.py

@@ -0,0 +1,150 @@
+from typing import Optional, List
+from dataclasses import dataclass, field
+
+from transformers import TrainingArguments
+
+
+@dataclass
+class BaseTrainingArguments:
+    experiment_prefix: str = field(
+        metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
+    )
+    initial_peers: List[str] = field(
+        default_factory=list,
+        metadata={"help": "One or more peers (comma-separated) that will welcome you into the collaboration"}
+    )
+    dht_listen_on: str = field(
+        default="[::]:*",
+        metadata={"help": "Network interface used for incoming DHT communication. Default: all ipv6"}
+    )
+
+
+@dataclass
+class AveragerArguments:
+    averaging_expiration: float = field(
+        default=5.0,
+        metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
+    )
+    averaging_timeout: float = field(
+        default=30.0,
+        metadata={"help": "Give up on averaging step after this many seconds"}
+    )
+    listen_on: str = field(
+        default="[::]:*",
+        metadata={"help": "Network interface used for incoming averager communication. Default: all ipv6"}
+    )
+    min_refresh_period: float = field(
+        default=0.5,
+        metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
+    )
+    max_refresh_period: float = field(
+        default=30,
+        metadata={"help": "Wait for at most this many seconds before fetching new collaboration state"}
+    )
+    default_refresh_period: float = field(
+        default=3,
+        metadata={"help": "Attempt to fetch collaboration state every this often until successful"}
+    )
+    expected_drift_peers: float = field(
+        default=3,
+        metadata={"help": "Trainer assumes that this many new peers can join per step"}
+    )
+    expected_drift_rate: float = field(
+        default=0.2,
+        metadata={"help": "Trainer assumes that this fraction of current size can join per step"}
+    )
+    performance_ema_alpha: float = field(
+        default=0.1,
+        metadata={"help": "Uses this alpha for moving average estimate of samples per second"}
+    )
+    target_group_size: int = field(
+        default=256,
+        metadata={"help": "Maximum group size for all-reduce"}
+    )
+    metadata_expiration: float = field(
+        default=30,
+        metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
+    )
+
+
+@dataclass
+class CollaborativeOptimizerArguments:
+    target_batch_size: int = field(
+        default=4096,
+        metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"}
+    )
+    client_mode: bool = field(
+        default=False,
+        metadata={"help": "Of True, runs training without incoming connections, in a firewall-compatible mode"}
+    )
+    batch_size_lead: int = field(
+        default=0,
+        metadata={"help": "Optional: begin looking for group in advance, this many samples before target_batch_size"}
+    )
+    bandwidth: float = field(
+        default=100.0,
+        metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"}
+    )
+    compression: str = field(
+        default="FLOAT16",
+        metadata={"help": "Use this compression when averaging parameters/gradients"}
+    )
+
+
+@dataclass
+class CollaborationArguments(AveragerArguments, CollaborativeOptimizerArguments, BaseTrainingArguments):
+    statistics_expiration: float = field(
+        default=600,
+        metadata={"help": "Statistics will be removed if not updated in this many seconds"}
+    )
+    endpoint: Optional[str] = field(
+        default=None,
+        metadata={"help": "This node's IP for inbound connections, used when running from behind a proxy"}
+    )
+
+
+@dataclass
+class DatasetArguments:
+    dataset_path: Optional[str] = field(
+        default='data/albert_tokenized_wikitext',
+        metadata={"help": "Path to the tokenized dataset"}
+    )
+    tokenizer_path: Optional[str] = field(
+        default='data/tokenizer',
+        metadata={"help": "Path to the tokenizer"}
+    )
+    config_path: Optional[str] = field(
+        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
+        metadata={"help": "Path to the model config"}
+    )
+    cache_dir: Optional[str] = field(
+        default='data',
+        metadata={"help": "Path to the cache"}
+    )
+
+
+@dataclass
+class AlbertTrainingArguments(TrainingArguments):
+    dataloader_num_workers: int = 4
+    per_device_train_batch_size: int = 4
+    per_device_eval_batch_size: int = 4
+    gradient_accumulation_steps: int = 2
+    seq_length: int = 512
+
+    max_steps: int = 1_000_000  # Albert is actually ready after 125000 steps
+    learning_rate: float = 0.00176
+    warmup_steps: int = 5000
+    adam_epsilon: float = 1e-6
+    weight_decay: float = 0.01
+    max_grad_norm: float = 1.0
+    clamp_value: float = 10000.0
+
+    fp16: bool = True
+    fp16_opt_level: str = 'O2'
+    do_train: bool = True
+
+    logging_steps: int = 100
+    save_total_limit: int = 2
+    save_steps: int = 500
+
+    output_dir: str = 'outputs'

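The dataclasses above are consumed through transformers' HfArgumentParser, which merges their fields into a single command line; run_first_peer.py and run_trainer.py below do exactly this, with run_trainer.py also adding AlbertTrainingArguments to the same parser. A minimal sketch, not part of the commit, with purely illustrative argument values:

from dataclasses import asdict

from transformers import HfArgumentParser

from arguments import CollaborationArguments, DatasetArguments

# Merge both dataclasses into one CLI; each flag maps to a field of the same name
parser = HfArgumentParser((CollaborationArguments, DatasetArguments))
collaboration_args, dataset_args = parser.parse_args_into_dataclasses([
    "--experiment_prefix", "albert_demo",      # required field inherited from BaseTrainingArguments
    "--initial_peers", "192.0.2.1:31337",      # hypothetical peer endpoint
    "--target_batch_size", "8192",             # overrides the CollaborativeOptimizerArguments default
])

# AveragerArguments fields are later forwarded to CollaborativeOptimizer via asdict(...)
print(asdict(collaboration_args)["target_group_size"])  # 256, the AveragerArguments default
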
+ 151 - 33
examples/albert/run_first_peer.py

@@ -1,49 +1,162 @@
 #!/usr/bin/env python
 
-import argparse
+from dataclasses import dataclass, field, asdict
+import subprocess
 import time
+from typing import Optional
 
-import hivemind
+import torch
+from torch_optimizer import Lamb
+from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
 import wandb
-from hivemind.utils.logging import get_logger
-from whatsmyip.ip import get_ip
 from whatsmyip.providers import GoogleDnsProvider
+from whatsmyip.ip import get_ip
 
+from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
+import hivemind
+from hivemind.utils.logging import get_logger
 import metrics_utils
 
 
 logger = get_logger(__name__)
 
+
+@dataclass
+class CoordinatorArguments(BaseTrainingArguments):
+    """
+    Note: You might want to have several initial peers so that if one dies,
+    new workers can still join the collaboration via the remaining initial peers' addresses.
+    Specify the initial_peers argument for that purpose.
+    """
+    address: Optional[str] = field(
+        default=None,
+        metadata={"help": "This machine's network address. Use public IP for global experiments, "
+                          "local address for private runs"}
+    )
+    refresh_period: float = field(
+        default=30,
+        metadata={"help": "Coordinator will fetch keys from DHT once in this many seconds"}
+    )
+    wandb_project: Optional[str] = field(
+        default=None,
+        metadata={"help": "Learning curves will be published there"}
+    )
+    save_checkpoint_step_interval: int = field(
+        default=5,
+        metadata={"help": "Coordinator will load and save state from peers once every that many steps"}
+    )
+    model_config_path: str = field(
+        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
+        metadata={"help": "Path to the model config"}
+    )
+    repo_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to HuggingFace repo in which coordinator will upload the model and optimizer states"}
+    )
+    upload_interval: Optional[float] = field(
+        default=None,
+        metadata={"help": "Coordinator will upload model once in this many seconds"}
+    )
+
+
+class CheckpointHandler:
+    def __init__(self, coordinator_args: CoordinatorArguments, collab_optimizer_args: CollaborativeOptimizerArguments,
+                 averager_args: AveragerArguments, dht: hivemind.DHT):
+        self.save_checkpoint_step_interval = coordinator_args.save_checkpoint_step_interval
+        self.repo_path = coordinator_args.repo_path
+        self.upload_interval = coordinator_args.upload_interval
+        self.previous_step = -1
+
+        config = AlbertConfig.from_pretrained(coordinator_args.model_config_path)
+        self.model = AlbertForPreTraining(config)
+
+        no_decay = ["bias", "LayerNorm.weight"]
+        optimizer_grouped_parameters = [
+            {
+                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "weight_decay": 0.01,
+            },
+            {
+                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+                "weight_decay": 0.0,
+            },
+        ]
+
+        opt = Lamb(
+            optimizer_grouped_parameters,
+            lr=0.00176, weight_decay=0.01, clamp_value=10000.0, debias=True,
+        )
+
+        adjusted_target_batch_size = collab_optimizer_args.target_batch_size - collab_optimizer_args.batch_size_lead
+
+        self.collaborative_optimizer = hivemind.CollaborativeOptimizer(
+            opt=opt, dht=dht, prefix=experiment_prefix,
+            compression_type=hivemind.utils.CompressionType.Value(collab_optimizer_args.compression),
+            throughput=collab_optimizer_args.bandwidth,
+            target_batch_size=adjusted_target_batch_size, client_mode=collab_optimizer_args.client_mode,
+            verbose=True, start=True, **asdict(averager_args)
+        )
+        self.previous_timestamp = time.time()
+
+    def is_time_to_save_state(self, cur_step):
+        if self.save_checkpoint_step_interval is None:
+            return False
+        elif cur_step - self.previous_step >= self.save_checkpoint_step_interval:
+            return True
+        else:
+            return False
+
+    def save_state(self, cur_step):
+        self.collaborative_optimizer.load_state_from_peers()
+        self.previous_step = cur_step
+
+    def is_time_to_upload(self):
+        if self.repo_path is None:
+            return False
+        elif time.time() - self.previous_timestamp >= self.upload_interval:
+            return True
+        else:
+            return False
+
+    def upload_checkpoint(self, current_loss):
+        self.model.save_pretrained(self.repo_path)
+        torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
+        self.previous_timestamp = time.time()
+        try:
+            subprocess.run("git add --all", shell=True, check=True, cwd=self.repo_path)
+            current_step = self.collaborative_optimizer.collaboration_state.optimizer_step
+            subprocess.run(f"git commit -m 'Step {current_step}, loss {current_loss:.3f}'",
+                           shell=True, check=True, cwd=self.repo_path)
+            subprocess.run("git push", shell=True, check=True, cwd=self.repo_path)
+        except subprocess.CalledProcessError as e:
+            logger.warning("Error while uploading model:", e.output)
+
+
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument('--address', type=str, required=False,
-                        help="this machine's network address. Use public IP for global experiments, "
-                             "local address for private runs.")
-    parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
-                        help="'localhost' for local connections only, '0.0.0.0' for ipv4 '[::]' for ipv6")
-    parser.add_argument('--refresh_period', type=float, default=30, required=False,
-                        help="coordinator will fetch keys from DHT once in this many seconds")
-    parser.add_argument('--experiment_prefix', type=str, required=True,
-                        help="a prefix where peers store their metrics for aggregation")
-    parser.add_argument('--wandb_project', type=str, required=True,
-                        help="Weights & Biases project name to publish learning curves")
-
-    args = parser.parse_args()
-    if args.address is None:
+    parser = HfArgumentParser((CoordinatorArguments, CollaborativeOptimizerArguments, AveragerArguments))
+    coordinator_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()
+
+    if coordinator_args.address is None:
         logger.warning("No address specified. Attempting to infer address from DNS.")
-        args.address = get_ip(GoogleDnsProvider)
+        coordinator_args.address = get_ip(GoogleDnsProvider)
 
-    validators, local_public_key = metrics_utils.make_validators(args.experiment_prefix)
-    dht = hivemind.DHT(start=True, listen_on=args.listen_on, endpoint=f"{args.address}:*",
+    experiment_prefix = coordinator_args.experiment_prefix
+    validators, local_public_key = metrics_utils.make_validators(experiment_prefix)
+    dht = hivemind.DHT(start=True, listen_on=coordinator_args.dht_listen_on,
+                       endpoint=f"{coordinator_args.address}:*", initial_peers=coordinator_args.initial_peers,
                        record_validators=validators)
-    logger.info(f"Running DHT root at {args.address}:{dht.port}")
 
-    wandb.init(project=args.wandb_project)
+    logger.info(f"Running DHT root at {coordinator_args.address}:{dht.port}")
+
+    if coordinator_args.wandb_project is not None:
+        wandb.init(project=coordinator_args.wandb_project)
+
     current_step = 0
 
+    checkpoint_handler = CheckpointHandler(coordinator_args, collab_optimizer_args, averager_args, dht)
+
     while True:
-        metrics_dict = dht.get(args.experiment_prefix + '_metrics', latest=True)
+        metrics_dict = dht.get(experiment_prefix + '_metrics', latest=True)
         if metrics_dict is not None:
             metrics_dict = metrics_dict.value
             metrics = [metrics_utils.LocalMetrics.parse_obj(metrics_dict[peer].value)
@@ -63,12 +176,17 @@ if __name__ == '__main__':
                     sum_perf += item.samples_per_second
                     num_samples += item.samples_accumulated
                     sum_mini_steps += item.mini_steps
-                wandb.log({
-                    "loss": sum_loss / sum_mini_steps,
-                    "alive peers": alive_peers,
-                    "samples": num_samples,
-                    "performance": sum_perf
-                })
+                if coordinator_args.wandb_project is not None:
+                    wandb.log({
+                        "loss": sum_loss / sum_mini_steps,
+                        "alive peers": alive_peers,
+                        "samples": num_samples,
+                        "performance": sum_perf
+                    })
+                if checkpoint_handler.is_time_to_save_state(current_step):
+                    checkpoint_handler.save_state(current_step)
+                    if checkpoint_handler.is_time_to_upload():
+                        checkpoint_handler.upload_checkpoint(sum_loss / sum_mini_steps)
                 logger.info(f"Step #{current_step}\tloss = {sum_loss / alive_peers:.5f}")
         logger.debug("Peer is still alive...")
-        time.sleep(args.refresh_period)
+        time.sleep(coordinator_args.refresh_period)

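The checkpointing cadence implemented by CheckpointHandler above is: pull the averaged state from peers (load_state_from_peers) every save_checkpoint_step_interval collaborative steps, and commit-and-push to repo_path at most once per upload_interval seconds, and only at one of those save points. A small self-contained sketch of just that timing logic (the CheckpointCadence class is hypothetical; no DHT, model, or git is involved):

import time


class CheckpointCadence:
    """Standalone restatement of CheckpointHandler's two timing predicates."""

    def __init__(self, save_checkpoint_step_interval=5, upload_interval=3600.0, repo_path="my-albert-repo"):
        self.save_checkpoint_step_interval = save_checkpoint_step_interval
        self.upload_interval = upload_interval
        self.repo_path = repo_path
        self.previous_step = -1
        self.previous_timestamp = time.time()

    def is_time_to_save_state(self, cur_step: int) -> bool:
        if self.save_checkpoint_step_interval is None:
            return False
        return cur_step - self.previous_step >= self.save_checkpoint_step_interval

    def is_time_to_upload(self) -> bool:
        if self.repo_path is None:
            return False
        return time.time() - self.previous_timestamp >= self.upload_interval


cadence = CheckpointCadence(save_checkpoint_step_interval=5, upload_interval=0.0)
for step in range(12):
    if cadence.is_time_to_save_state(step):
        cadence.previous_step = step                  # save_state() also calls load_state_from_peers()
        if cadence.is_time_to_upload():
            cadence.previous_timestamp = time.time()  # upload_checkpoint() also commits and pushes
            print(f"save and upload at collaborative step {step}")  # steps 4 and 9 with these settings
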
+ 9 - 72
examples/albert/run_trainer.py

@@ -2,11 +2,10 @@
 
 import logging
 import os
-from dataclasses import dataclass, field, asdict
+from dataclasses import asdict
 from pathlib import Path
-from typing import Optional, Dict, Any, List
+from typing import Dict, Any
 
-import hivemind
 import torch
 import transformers
 from datasets import load_from_disk
@@ -18,6 +17,8 @@ from transformers.trainer_utils import is_main_process
 from transformers.trainer import Trainer
 from torch_optimizer import Lamb
 
+import hivemind
+from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments
 import metrics_utils
 
 
@@ -25,75 +26,6 @@ logger = logging.getLogger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
 
 
-@dataclass
-class CollaborationArguments:
-    """ define how peers interact with each other while training"""
-
-    # primary parameters
-    initial_peers: List[str]  # one or more peers (comma-separated) that will welcome you into the collaboration
-    experiment_prefix: str  # a unique "name" of this experiment, used to store metadata on the DHT
-    averaging_expiration: float = 5.0  # averaging group will wait for stragglers for at most this many seconds
-    averaging_timeout: float = 30.0  # give up on averaging step after this many seconds
-    target_batch_size: int = 4096  # perform optimizer step after all peers collectively accumulate this many samples
-    client_mode: bool = False  # if True, runs training without incoming connections, in a firewall-compatible mode
-
-    # optional tweaks
-    target_group_size: int = 256  # maximum group size for all-reduce
-    metadata_expiration: float = 30  # peer's metadata will be removed if not updated in this many seconds
-    statistics_expiration: float = 600  # statistics will be removed if not updated in this many seconds
-    dht_listen_on: str = '[::]:*'  # network interface used for incoming DHT communication. Default: all ipv6
-    listen_on: str = '[::]:*'  # network interface used for incoming averager communication. Default: all ipv6
-    endpoint: Optional[str] = None  # this node's IP for inbound connections, used when running from behind a proxy
-    batch_size_lead: int = 0  # optional: begin looking for group in advance, this many samples before target_batch_size
-    compression: str = 'FLOAT16'  # use this compression when averaging parameters/gradients
-
-    min_refresh_period: float = 0.5  # wait for at least this many seconds before fetching new collaboration state
-    max_refresh_period: float = 30  # wait for at most this many seconds before fetching new collaboration state
-    default_refresh_period: float = 3  # attempt to fetch collaboration state every this often until successful
-    expected_drift_peers: float = 3  # trainer assumes that this many new peers can join per step
-    expected_drift_rate: float = 0.2  # trainer assumes that this fraction of current size can join per step
-
-    bandwidth: float = 100.0  # available network bandwidth, in mbps (used for load balancing in all-reduce)
-    performance_ema_alpha: float = 0.1  # uses this alpha for moving average estimate of samples per second
-
-
-@dataclass
-class DatasetArguments:
-    dataset_path: Optional[str] = field(default='./data/albert_tokenized_wikitext',
-                                        metadata={"help": "Path to the tokenized dataset"})
-    tokenizer_path: Optional[str] = field(default='./data/tokenizer',
-                                          metadata={"help": "Path to the tokenizer"})
-    config_path: Optional[str] = field(
-        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
-        metadata={"help": "Path to the model config"})
-    cache_dir: Optional[str] = field(default='./data', metadata={"help": "Path to the cache"})
-
-
-@dataclass
-class AlbertTrainingArguments(TrainingArguments):
-    dataloader_num_workers: int = 4
-    per_device_train_batch_size: int = 4
-    per_device_eval_batch_size: int = 4
-    gradient_accumulation_steps: int = 2
-    seq_length: int = 512
-
-    max_steps: int = 1_000_000  # Albert is actually ready after 125000 steps
-    learning_rate: float = 0.00176
-    warmup_steps: int = 5000
-    adam_epsilon: float = 1e-6
-    weight_decay: float = 0.01
-    max_grad_norm: float = 1.0
-    clamp_value: float = 10000.0
-
-    fp16: bool = True
-    fp16_opt_level: str = 'O2'
-    do_train: bool = True
-
-    logging_steps: int = 100
-    save_total_limit: int = 2
-    save_steps: int = 500
-
-
 def setup_logging(training_args):
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
@@ -177,6 +109,11 @@ class CollaborativeCallback(transformers.TrainerCallback):
         self.steps = 0
         self.loss = 0
 
+    def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
+                       control: transformers.TrainerControl, **kwargs):
+        logger.warning('Loading state from peers')
+        self.collaborative_optimizer.load_state_from_peers()
+
     def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
                     control: transformers.TrainerControl, **kwargs):
         control.should_log = True

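The on_train_begin hook added above runs once before the first training step, so a freshly started peer downloads the collaboration's current parameters and optimizer state before contributing any gradients. A brief sketch (hypothetical names, assuming an already-constructed hivemind.CollaborativeOptimizer) of how such a callback is wired into a HuggingFace Trainer:

import transformers
from transformers import TrainingArguments


class SyncOnStartCallback(transformers.TrainerCallback):
    """Pull the latest collaborative state before the local peer takes its first step."""

    def __init__(self, collaborative_optimizer):
        self.collaborative_optimizer = collaborative_optimizer  # e.g. hivemind.CollaborativeOptimizer

    def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
                       control: transformers.TrainerControl, **kwargs):
        self.collaborative_optimizer.load_state_from_peers()


# Registration (model/dataset/optimizer construction omitted):
# trainer = transformers.Trainer(model=model, args=training_args, train_dataset=dataset,
#                                callbacks=[SyncOnStartCallback(collaborative_optimizer)])
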
+ 4 - 1
hivemind/client/averaging/training.py

@@ -187,4 +187,7 @@ def load_optimizer_state(optimizer: torch.optim.Optimizer, flat_metadata: Dict,
         elif elem.get('type') == 'value' and 'value' in elem:
             flat_optimizer_state.append(elem['value'])
     with torch.no_grad():
-        return optimizer.load_state_dict(nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
+        try:
+            return optimizer.load_state_dict(nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
+        except StopIteration:
+            return optimizer

+ 4 - 0
hivemind/optim/collaborative.py

@@ -127,6 +127,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
         self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
+        self.samples_processed = 0
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
@@ -191,6 +192,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_steps_accumulated += 1
+            self.samples_processed += batch_size
             self.performance_ema.update(num_processed=self.batch_size_per_step)
             self.should_report_progress.set()
 
@@ -233,6 +235,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.update_scheduler()
 
             logger.log(self.status_loglevel, f"Optimizer step: done!")
+            logger.info(f"Your current contribution: {self.samples_processed} samples")
+
             return group_info
 
     def _grad_buffers(self) -> Iterator[torch.Tensor]: