# ppo.py

import argparse
import pathlib

import torch

import hivemind
from hivemind import Float16Compression, SizeAdaptiveCompression, Uniform8BitQuantization
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--n-steps', type=int, default=128, help='Number of rollout steps per agent')
    parser.add_argument('--n-envs', type=int, default=8, help='Number of training environments')
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--target-batch-size', type=int, default=32768)
    parser.add_argument('--n-epochs', type=int, default=1, help='Number of training epochs per rollout')
    parser.add_argument('--learning-rate', type=float, default=2.5e-4)
    parser.add_argument('--tb-logs-path', type=pathlib.Path, default='./logs', help='Path to the TensorBoard logs folder')
    parser.add_argument('--experiment-prefix', type=str, help='Experiment prefix for TensorBoard logs')
    parser.add_argument('--initial-peers', nargs='+', default=[], help='Multiaddresses of existing DHT peers to join')
    parser.add_argument('--averaging-compression', action='store_true', help='Compress tensors exchanged during averaging')
    args = parser.parse_args()
    return args


def generate_experiment_name(args):
    exp_name_dict = {
        'bs': args.batch_size,
        'target_bs': args.target_batch_size,
        'n_envs': args.n_envs,
        'n_steps': args.n_steps,
        'n_epochs': args.n_epochs,
    }
    exp_name = [f'{key}-{value}' for key, value in exp_name_dict.items()]
    exp_name = '.'.join(exp_name)
    if args.experiment_prefix:
        exp_name = f'{args.experiment_prefix}.{exp_name}'
    # Abbreviate round thousands: a segment like 'bs-256000.' becomes 'bs-256k.'
    exp_name = exp_name.replace('000.', 'k.')
    return exp_name
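

# With the default arguments and no prefix, this yields
# 'bs-256.target_bs-32768.n_envs-8.n_steps-128.n_epochs-1'
# (no segment ends in '000.', so the 'k' abbreviation does not fire here).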


class AdamWithClipping(torch.optim.Adam):
    """Adam that clips the global gradient norm right before each step.

    hivemind accumulates gradients over many local steps, so clipping must
    happen here, on the gradient that is actually applied, rather than via
    stable-baselines3's per-minibatch max_grad_norm (disabled below).
    """

    def __init__(self, *args, max_grad_norm: float, **kwargs):
        self.max_grad_norm = max_grad_norm
        super().__init__(*args, **kwargs)

    def step(self, *args, **kwargs):
        iter_params = (param for group in self.param_groups for param in group['params'])
        torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)
        return super().step(*args, **kwargs)
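

# Usage sketch outside SB3 (hypothetical `net`): the class is a drop-in torch
# optimizer, e.g. AdamWithClipping(net.parameters(), lr=2.5e-4, max_grad_norm=0.5).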


def configure_dht_opts(args):
    opts = {
        'start': True,
    }
    if args.initial_peers:
        opts['initial_peers'] = args.initial_peers
    return opts


if __name__ == '__main__':
    args = parse_args()

    dht_opts = configure_dht_opts(args)
    dht = hivemind.DHT(**dht_opts)
    print('To connect other peers to this one, use --initial-peers', *[str(addr) for addr in dht.get_visible_maddrs()])

    env = make_atari_env('BreakoutNoFrameskip-v4', n_envs=args.n_envs)
    env = VecFrameStack(env, n_stack=4)
    model = PPO(
        'CnnPolicy', env,
        verbose=1,
        batch_size=args.batch_size,
        n_steps=args.n_steps,
        n_epochs=args.n_epochs,
        learning_rate=args.learning_rate,
        clip_range=0.1,
        vf_coef=0.5,
        ent_coef=0.01,
        tensorboard_log=args.tb_logs_path,
        # Effectively disable SB3's own per-minibatch clipping; the real
        # clipping happens in AdamWithClipping right before each step.
        max_grad_norm=10000.0,
        policy_kwargs={'optimizer_class': AdamWithClipping, 'optimizer_kwargs': {'max_grad_norm': 0.5}},
    )

    compression_opts = {}
    if args.averaging_compression:
        # Small tensors are sent as float16; tensors above the size threshold
        # are quantized to 8 bits to cut averaging traffic further.
        averaging_compression = SizeAdaptiveCompression(
            threshold=2 ** 10 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization()
        )
        compression_opts.update({
            'grad_compression': averaging_compression,
            'state_averaging_compression': averaging_compression,
        })

    # Wrap the policy's Adam in hivemind.Optimizer: peers accumulate gradients
    # locally and apply an averaged update once target_batch_size samples have
    # been collected across the run.
    model.policy.optimizer_class = hivemind.Optimizer
    model.policy.optimizer = hivemind.Optimizer(
        dht=dht,
        optimizer=model.policy.optimizer,
        run_id='ppo_hivemind',
        batch_size_per_step=args.batch_size,
        target_batch_size=args.target_batch_size,
        offload_optimizer=False,
        verbose=True,
        use_local_updates=False,
        matchmaking_time=4,
        averaging_timeout=15,
        **compression_opts,
    )

    # Sync with peers that are already training, then train (the huge
    # total_timesteps means: run until interrupted).
    model.policy.optimizer.load_state_from_peers()
    model.learn(total_timesteps=int(5e11), tb_log_name=generate_experiment_name(args))
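
# Example launch (placeholder multiaddress). Start the first peer:
#   python ppo.py --averaging-compression
# It prints its visible multiaddresses; pass one to each additional peer:
#   python ppo.py --averaging-compression --initial-peers /ip4/<host>/tcp/<port>/p2p/<peer_id>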