
Merge branch 'TPU'

Michael Diskin 4 years ago
parent commit 2d6a502927

+ 82 - 0
examples/albert/TPU.py

@@ -0,0 +1,82 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+A simple launcher script for TPU training
+Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py
+::
+    >>> python TPU.py --num_cores=NUM_CORES_YOU_HAVE
+               YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
+               arguments of your training script)
+"""
+
+
+import importlib
+import sys
+from argparse import REMAINDER, ArgumentParser
+from pathlib import Path
+
+import torch_xla.distributed.xla_multiprocessing as xmp
+
+
+def parse_args():
+    """
+    Helper function parsing the command line options
+    @retval ArgumentParser
+    """
+    parser = ArgumentParser(
+        description=(
+            "PyTorch TPU distributed training launch "
+            "helper utility that will spawn up "
+            "multiple distributed processes"
+        )
+    )
+
+    # Optional arguments for the launch helper
+    parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use (1 or 8).")
+
+    # positional
+    parser.add_argument(
+        "training_script",
+        type=str,
+        help=(
+            "The full path to the single TPU training "
+            "program/script to be launched in parallel, "
+            "followed by all the arguments for the "
+            "training script"
+        ),
+    )
+
+    # rest from the training program
+    parser.add_argument("training_script_args", nargs=REMAINDER)
+
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+
+    # Import training_script as a module.
+    script_fpath = Path(args.training_script)
+    sys.path.append(str(script_fpath.parent.resolve()))
+    mod_name = script_fpath.stem
+    mod = importlib.import_module(mod_name)
+
+    # Patch sys.argv
+    sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)]
+
+    xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
+
+
+if __name__ == "__main__":
+    main()
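
The launcher above imports the given training script as a Python module, appends --tpu_num_cores to sys.argv, and hands the module's _mp_fn to torch_xla's process spawner. Below is a minimal sketch of the contract a training script has to satisfy; the file name and its argument handling are hypothetical, shown only for illustration:

# minimal_trainer.py -- hypothetical script showing what TPU.py expects
import sys


def main():
    # By the time this runs, TPU.py has already appended "--tpu_num_cores <N>"
    # to sys.argv, so an argument parser in the real script can pick it up.
    print(f"training process started with argv: {sys.argv[1:]}")


def _mp_fn(index):
    # Called once per TPU core by xmp.spawn(); `index` is the process ordinal.
    main()


if __name__ == "__main__":
    main()

Launching python examples/albert/TPU.py --num_cores=8 minimal_trainer.py would then start eight such processes, which is exactly the hook run_trainer.py gains further down in this commit.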

+ 12 - 11
examples/albert/run_trainer.py

@@ -16,6 +16,8 @@ from transformers.models.albert import AlbertConfig, AlbertForPreTraining, Alber
 from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
+from transformers.optimization import Adafactor, AdafactorSchedule
+
 
 import hivemind
 from hivemind.utils.compression import CompressionType
@@ -77,18 +79,13 @@ def get_optimizer_and_scheduler(training_args, model):
         },
     ]
 
-    opt = Lamb(
+    opt = Adafactor(
         optimizer_grouped_parameters,
-        lr=training_args.learning_rate,
-        betas=(training_args.adam_beta1, training_args.adam_beta2),
-        eps=training_args.adam_epsilon,
-        weight_decay=training_args.weight_decay,
-        clamp_value=training_args.clamp_value,
-        debias=True,
+        scale_parameter=True, relative_step=True, warmup_init=True, lr=None
     )
 
-    scheduler = get_linear_schedule_with_warmup(
-        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
+    scheduler = AdafactorSchedule(
+        opt
     )
 
     return opt, scheduler
@@ -219,8 +216,6 @@ def main():
     training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()
 
     logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
-    if len(collaboration_args.initial_peers) == 0:
-        raise ValueError("Please specify at least one network endpoint in initial peers.")
 
     setup_logging(training_args)
 
@@ -231,6 +226,7 @@ def main():
     tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
     model = get_model(training_args, config, tokenizer)
     model.to(training_args.device)
+    model.tie_weights()
 
     tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
     # This data collator will take care of randomly masking the tokens.
@@ -310,5 +306,10 @@ def main():
         trainer.train(model_path=latest_checkpoint_dir)
 
 
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
 if __name__ == "__main__":
     main()
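
For context on the optimizer change above: with relative_step=True and lr=None, Adafactor derives its learning rate internally, which is why Lamb and the linear warmup schedule are dropped and AdafactorSchedule merely mirrors the internal rate (e.g. for the Trainer's logging). A minimal standalone sketch of this configuration, with a toy model standing in for ALBERT:

import torch
from transformers.optimization import Adafactor, AdafactorSchedule

model = torch.nn.Linear(8, 2)  # placeholder for AlbertForPreTraining

# In this mode Adafactor chooses its own time-dependent step size,
# so no external warmup/decay schedule is constructed anymore.
opt = Adafactor(
    model.parameters(),
    scale_parameter=True,
    relative_step=True,
    warmup_init=True,
    lr=None,
)

# AdafactorSchedule is a proxy scheduler: it reports the optimizer's internal
# learning rate without changing the updates themselves.
scheduler = AdafactorSchedule(opt)

model(torch.randn(4, 8)).sum().backward()
opt.step()
scheduler.step()
print(scheduler.get_last_lr())  # the step size Adafactor actually used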

+ 10 - 1
hivemind/averaging/averager.py

@@ -325,11 +325,13 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if weight is None:
             weight = float(self.mode != AveragingMode.AUX)
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
-
+        logger.debug("step: creating MPFuture for the result")
         future = MPFuture()
+        logger.debug("step: serializing gather data")
         gather_binary = self.serializer.dumps(
             gather
         )  # serialize here to avoid loading modules in the averager process
+        logger.debug("step: sending step request to the averager process")
         self._outer_pipe.send(
             (
                 "_step",
@@ -343,21 +345,28 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 ),
             )
         )
+        logger.debug("step: request sent, waiting for the result")
         return future.result() if wait else future
 
     async def _step(
         self, *, future: MPFuture, gather_binary: bytes, weight: float, allow_retries: bool, timeout: Optional[float]
     ):
+        logger.debug("_step: background averaging step started")
         start_time = get_dht_time()
 
         try:
             while not future.done():
                 try:
+                    logger.debug("_step: resetting pending_group_assembled")
                     self._pending_group_assembled.clear()
+                    logger.debug("_step: serializing data_for_gather")
+
                     data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
+                    logger.debug("_step: looking for a group")
                     group_info = await self._matchmaking.look_for_group(
                         timeout=timeout, data_for_gather=data_for_gather
                     )
+                    logger.debug("_step: matchmaking finished")
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
 

+ 6 - 1
hivemind/averaging/training.py

@@ -68,17 +68,20 @@ class TrainingAverager(DecentralizedAverager):
         """
         if not wait:
             return self.step_executor.submit(self.step, data_lock, wait=True, **kwargs)
+        logger.debug("TrainingAverager.step: starting averaging step")
 
         # if data_lock is supplied, tensors might change during averaging, so we need to copy them
         use_old_local_tensors = data_lock is not None
         if data_lock is None:
             data_lock = nullcontext()
-
+        logger.debug("TrainingAverager.step: collecting local tensors")
         local_tensors = list(self.local_tensors())
+        logger.debug("TrainingAverager.step: local tensors collected")
         with self.lock_averager_step, torch.no_grad():
             # fill averager's tensors with current local tensors
             self.pending_updates_done.clear()
             with data_lock, self.get_tensors() as averaged_tensors:
+                logger.debug("TrainingAverager.step: copying local tensors into the averager")
                 if use_old_local_tensors:
                     old_local_tensors = tuple(x.cpu().float().clone() for x in local_tensors)
                 assert len(local_tensors) == len(
@@ -86,7 +89,9 @@ class TrainingAverager(DecentralizedAverager):
                 ), "The number of optimized parameters should not change."
                 for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
                     averaged_tensor[...] = local_tensor.cpu().float()
+            logger.debug("TrainingAverager.step: local tensors copied")
             self.pending_updates_done.set()
+            logger.debug("TrainingAverager.step: pending updates marked as done")
 
             # find a group and hopefully average tensors with peers, use batch sizes as weights
             gathered = super().step(**kwargs)

+ 7 - 1
hivemind/optim/collaborative.py

@@ -265,14 +265,20 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     self.status_loglevel,
                     f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s).",
                 )
-
+            logger.log(self.status_loglevel, "Optimizer step: applying optimizer update")
             self.opt.step()
+            logger.log(self.status_loglevel, "Optimizer step: optimizer update applied")
+
             self.reset_accumulated_grads_()
+            logger.log(self.status_loglevel, "Optimizer step: accumulated gradients reset")
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.collaboration_state.register_step(current_step + 1)
+            logger.log(self.status_loglevel, "Optimizer step: collaborative step registered")
             self.averager.local_step = current_step + 1
             self.collaboration_state_updated.set()
+            logger.log(self.status_loglevel, "Optimizer step: collaboration state updated")
             self.update_scheduler()
+            logger.log(self.status_loglevel, "Optimizer step: learning rate scheduler updated")
 
         logger.log(self.status_loglevel, f"Optimizer step: done!")