Michael Diskin 4 år sedan
förälder
incheckning
4d80d99118
2 ändrade filer med 5 tillägg och 1 borttagningar
  1. 1 1
      examples/albert/TPU.py
  2. 4 0
      examples/albert/run_trainer.py

+ 1 - 1
examples/albert/TPU.py

@@ -75,7 +75,7 @@ def main():
     # Patch sys.argv
     sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)]
 
-    xmp.spawn(mod.main, args=(), nprocs=args.num_cores)
+    xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
 
 
 if __name__ == "__main__":

+ 4 - 0
examples/albert/run_trainer.py

@@ -309,5 +309,9 @@ def main():
         trainer.train(model_path=latest_checkpoint_dir)
 
 
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
 if __name__ == "__main__":
     main()