
Fix hivemind.Optimizer usage and minor issues in examples/albert (#433)

Resolves #431, the 1st issue from #387, and many other minor issues (see the PR's comments).
Alexander Borzunov 3 years ago
parent commit 0ea81fa43b

+ 42 - 39
examples/albert/run_trainer.py

@@ -1,7 +1,8 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 
 import os
 import pickle
+import sys
 from dataclasses import asdict
 from pathlib import Path
 
@@ -58,36 +59,6 @@ def get_model(training_args, config, tokenizer):
     return model
 
 
-def get_optimizer_and_scheduler(training_args, model):
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": training_args.weight_decay,
-        },
-        {
-            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
-
-    opt = Lamb(
-        optimizer_grouped_parameters,
-        lr=training_args.learning_rate,
-        betas=(training_args.adam_beta1, training_args.adam_beta2),
-        eps=training_args.adam_epsilon,
-        weight_decay=training_args.weight_decay,
-        clamp_value=training_args.clamp_value,
-        debias=True,
-    )
-
-    scheduler = get_linear_schedule_with_warmup(
-        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
-    )
-
-    return opt, scheduler
-
-
 class CollaborativeCallback(transformers.TrainerCallback):
     """
     This callback monitors and reports collaborative training progress.
@@ -149,9 +120,9 @@ class CollaborativeCallback(transformers.TrainerCallback):
                 )
                 logger.info(f"Step #{self.optimizer.local_epoch}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
-                logger.info(f"Performance: {samples_per_second} samples per second.")
+                logger.info(f"Performance: {samples_per_second:.3f} samples/sec")
                 if self.steps:
-                    logger.info(f"Local loss: {self.loss / self.steps}")
+                    logger.info(f"Local loss: {self.loss / self.steps:.5f}")
                 if self.optimizer.local_epoch % self.backup_every_steps == 0:
                     self.latest_backup = self.backup_state()
 
@@ -219,10 +190,7 @@ def main():
         )
     )
     training_args, dataset_args, collaboration_args, averager_args, tracker_args = parser.parse_args_into_dataclasses()
-
     logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
-    if len(collaboration_args.initial_peers) == 0:
-        raise ValueError("Please specify at least one network endpoint in initial peers.")
 
     setup_transformers_logging(training_args.local_rank)
     logger.info(f"Training/evaluation parameters:\n{training_args}")
@@ -231,7 +199,15 @@ def main():
     set_seed(training_args.seed)
 
     config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
-    tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
+    try:
+        tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
+    except OSError:
+        logger.fatal(
+            f"No tokenizer data found in {dataset_args.tokenizer_path}, "
+            f"please run ./tokenize_wikitext103.py before running this"
+        )
+        sys.exit(1)
+
     model = get_model(training_args, config, tokenizer)
     model.to(training_args.device)
 
@@ -239,8 +215,6 @@ def main():
     # This data collator will take care of randomly masking the tokens.
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
 
-    opt, scheduler = get_optimizer_and_scheduler(training_args, model)
-
     validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
 
     dht = DHT(
@@ -261,12 +235,41 @@ def main():
 
     adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead
 
+    # We need to make such a lambda function instead of just an optimizer instance
+    # to make hivemind.Optimizer(..., offload_optimizer=True) work
+    opt = lambda params: Lamb(
+        params,
+        lr=training_args.learning_rate,
+        betas=(training_args.adam_beta1, training_args.adam_beta2),
+        eps=training_args.adam_epsilon,
+        weight_decay=training_args.weight_decay,
+        clamp_value=training_args.clamp_value,
+        debias=True,
+    )
+
+    no_decay = ["bias", "LayerNorm.weight"]
+    params = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": training_args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+
+    scheduler = lambda opt: get_linear_schedule_with_warmup(
+        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
+    )
+
     optimizer = Optimizer(
         dht=dht,
         run_id=collaboration_args.experiment_prefix,
         target_batch_size=adjusted_target_batch_size,
         batch_size_per_step=total_batch_size_per_step,
         optimizer=opt,
+        params=params,
         scheduler=scheduler,
         matchmaking_time=collaboration_args.matchmaking_time,
         averaging_timeout=collaboration_args.averaging_timeout,

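For reference, the factory-based construction adopted by the updated run_trainer.py can be reproduced in a minimal standalone sketch: with offload_optimizer=True, hivemind.Optimizer takes the optimizer and scheduler as callables plus a separate params argument, so it can rebuild them around its own offloaded copy of the parameters. The toy model, Adam (standing in for Lamb), learning rate, and run_id below are illustrative placeholders, not part of this commit.

import torch
import hivemind

model = torch.nn.Linear(16, 2)  # stand-in for the ALBERT model used in the example

# Factories instead of pre-built instances, mirroring the change in run_trainer.py
opt_factory = lambda params: torch.optim.Adam(params, lr=1e-3)
scheduler_factory = lambda opt: torch.optim.lr_scheduler.LambdaLR(opt, lambda step: 1.0)

dht = hivemind.DHT(start=True)  # a single local peer, for illustration only

optimizer = hivemind.Optimizer(
    dht=dht,
    run_id="demo_run",
    target_batch_size=1024,
    batch_size_per_step=32,
    optimizer=opt_factory,        # callable(params) -> torch.optim.Optimizer
    params=model.parameters(),    # required whenever `optimizer` is a callable
    scheduler=scheduler_factory,  # callable(optimizer) -> LR scheduler
    offload_optimizer=True,
    verbose=True,
)

The point of the callable is that hivemind.Optimizer, not the caller, decides which parameter copy the inner optimizer is built around (e.g. a CPU-offloaded one), which a pre-built instance bound to the model's own parameters cannot accommodate.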
+ 1 - 1
examples/albert/run_training_monitor.py

@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 
 import time
 from dataclasses import asdict, dataclass, field

+ 1 - 1
examples/albert/tokenize_wikitext103.py

@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 """ This script builds a pre-tokenized compressed representation of WikiText-103 using huggingface/datasets """
 import random
 from functools import partial

+ 2 - 1
hivemind/optim/optimizer.py

@@ -197,6 +197,8 @@ class Optimizer(torch.optim.Optimizer):
         shutdown_timeout: float = 5,
         verbose: bool = False,
     ):
+        self._parent_pid = os.getpid()
+
         client_mode = client_mode if client_mode is not None else dht.client_mode
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
         offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
@@ -262,7 +264,6 @@ class Optimizer(torch.optim.Optimizer):
 
         self._should_check_synchronization_on_update = True  # used in self.should_load_state_from_peers
         self._schema_hash = self._compute_schema_hash()
-        self._parent_pid = os.getpid()
 
         self.delay_before_state_averaging = PerformanceEMA(alpha=performance_ema_alpha)
         # measures the average time from the beginning of self._update_global_epoch to the call to state_averager
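A note on the _parent_pid move in hivemind/optim/optimizer.py: assigning it at the very top of __init__ is a common safeguard for classes whose __del__-style cleanup consults constructor-set attributes, since __del__ may run on a partially initialized object if the constructor raises, and forked worker processes must not tear down state they merely inherited. A generic sketch of that pattern follows (not hivemind code; the exact motivation is an assumption based on this diff):

import os

class BackgroundWorker:
    def __init__(self, fail_early: bool = False):
        # Assigned before anything that can raise, so cleanup can always rely on it.
        self._parent_pid = os.getpid()
        if fail_early:
            raise ValueError("simulated failure later in __init__")
        # ... expensive setup (subprocesses, schema hashing, averagers) would go here ...

    def shutdown(self):
        print("releasing background resources")

    def __del__(self):
        # May run on a half-constructed object; _parent_pid exists either way.
        # Forked children inherit the object but skip the cleanup.
        if self._parent_pid == os.getpid():
            self.shutdown()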