|
@@ -75,7 +75,7 @@ def main():
|
|
|
# Patch sys.argv
|
|
|
sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)]
|
|
|
|
|
|
- xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
|
|
|
+ xmp.spawn(mod.main, args=(), nprocs=args.num_cores)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|