123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- import argparse
- from importlib.resources import path
- import pathlib
- import torch
- import hivemind
- from hivemind import Float16Compression, SizeAdaptiveCompression, Uniform8BitQuantization
- from stable_baselines3 import PPO
- from stable_baselines3.common.env_util import make_atari_env
- from stable_baselines3.common.vec_env import VecFrameStack
def parse_args(argv=None):
    """Parse command-line options for the collaborative PPO training script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]`` exactly as before; passing an
            explicit list makes the parser usable from code and tests.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--n-steps', type=int, default=128, help='Number of rollout steps per each agent')
    parser.add_argument('--n-envs', type=int, default=8, help='Number of training envs')
    parser.add_argument('--batch-size', type=int, default=256,
                        help='PPO minibatch size; also used as hivemind batch_size_per_step')
    parser.add_argument('--target-batch-size', type=int, default=32768,
                        help='Passed to hivemind.Optimizer as target_batch_size')
    parser.add_argument('--n-epochs', type=int, default=1, help='Number of training epochs per each rollout')
    parser.add_argument('--learning-rate', type=float, default=2.5e-4, help='Learning rate passed to PPO')
    # argparse applies `type` to string defaults, so the default ends up as a pathlib.Path too.
    parser.add_argument('--tb-logs-path', type=pathlib.Path, default='./logs', help='Path to tensorboard logs folder')
    parser.add_argument('--experiment-prefix', type=str, help='Experiment prefix for tensorboard logs')
    parser.add_argument('--initial-peers', nargs='+', default=[],
                        help='Multiaddrs of already running peers to join (passed to hivemind.DHT)')
    parser.add_argument('--averaging-compression', action='store_true',
                        help='Compress averaged gradients and optimizer state')
    return parser.parse_args(argv)
def generate_experiment_name(args):
    """Build a dotted tensorboard run name from the main training hyperparameters.

    Each hyperparameter is rendered as ``key-value`` and the pieces are joined
    with ``.``. Values ending in ``000`` are abbreviated with a ``k`` suffix
    (``32000`` -> ``32k``). If ``args.experiment_prefix`` is set, it is
    prepended verbatim and is never abbreviated.

    Args:
        args: Namespace with ``batch_size``, ``target_batch_size``, ``n_envs``,
            ``n_steps``, ``n_epochs`` and ``experiment_prefix`` attributes.

    Returns:
        The experiment name string.
    """
    def _abbreviate(value):
        # Shorten each value on its own. Doing a string-wide replace would also
        # rewrite the user's prefix and would skip the last field, which has no
        # trailing '.'.
        text = str(value)
        return f'{text[:-3]}k' if text.endswith('000') else text

    exp_name_dict = {
        'bs': args.batch_size,
        'target_bs': args.target_batch_size,
        'n_envs': args.n_envs,
        'n_steps': args.n_steps,
        'n_epochs': args.n_epochs,
    }
    exp_name = '.'.join(f'{key}-{_abbreviate(value)}' for key, value in exp_name_dict.items())
    if args.experiment_prefix:
        exp_name = f'{args.experiment_prefix}.{exp_name}'
    return exp_name
class AdamWithClipping(torch.optim.Adam):
    """``torch.optim.Adam`` that clips the global gradient norm before every update.

    Args:
        *args, **kwargs: Forwarded unchanged to ``torch.optim.Adam``.
        max_grad_norm: Maximum total L2 norm of the gradients across all
            parameter groups, enforced with ``torch.nn.utils.clip_grad_norm_``.
    """

    def __init__(self, *args, max_grad_norm: float, **kwargs):
        self.max_grad_norm = max_grad_norm
        super().__init__(*args, **kwargs)

    def step(self, closure=None):
        """Clip gradients in place, then perform one Adam update.

        If ``closure`` is given, it is evaluated first (with autograd enabled,
        as ``Adam.step`` would do). Clipping then applies to the gradients the
        closure just produced, not to stale ones left from a previous step.

        Args:
            closure: Optional callable that recomputes the loss and gradients.

        Returns:
            The loss returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # Clip over every parameter of every group so the norm is global.
        params = [param for group in self.param_groups for param in group["params"]]
        torch.nn.utils.clip_grad_norm_(params, self.max_grad_norm)
        super().step()
        return loss
def configure_dht_opts(args):
    """Return the keyword arguments used to construct ``hivemind.DHT``.

    The DHT is always started. ``initial_peers`` is included only when the
    user supplied a non-empty list. Otherwise the option is omitted and this
    node starts without bootstrap peers.
    """
    dht_kwargs = dict(start=True)
    if not args.initial_peers:
        return dht_kwargs
    dht_kwargs['initial_peers'] = args.initial_peers
    return dht_kwargs
if __name__ == "__main__":
    # Script entry point: join or start a hivemind DHT, build a PPO agent for
    # Atari Breakout, swap its optimizer for a collaborative hivemind.Optimizer,
    # then train.
    args = parse_args()

    dht_opts = configure_dht_opts(args)
    dht = hivemind.DHT(**dht_opts)
    # The flag defined in parse_args is `--initial-peers` (dash). argparse would
    # reject the underscore spelling this hint used to print.
    print("To connect other peers to this one, use --initial-peers", *[str(addr) for addr in dht.get_visible_maddrs()])

    env = make_atari_env('BreakoutNoFrameskip-v4', n_envs=args.n_envs)
    env = VecFrameStack(env, n_stack=4)

    model = PPO(
        'CnnPolicy', env,
        verbose=1,
        batch_size=args.batch_size,
        n_steps=args.n_steps,
        n_epochs=args.n_epochs,
        learning_rate=args.learning_rate,
        clip_range=0.1,
        vf_coef=0.5,
        ent_coef=0.01,
        tensorboard_log=args.tb_logs_path,
        # A huge limit presumably makes PPO's own gradient clipping a no-op.
        # The effective clip (0.5) is applied in AdamWithClipping.step, i.e.
        # inside the optimizer that hivemind wraps below. TODO confirm intent.
        max_grad_norm=10000.0,
        policy_kwargs={'optimizer_class': AdamWithClipping, 'optimizer_kwargs': {'max_grad_norm': 0.5}}
    )

    compression_opts = {}
    if args.averaging_compression:
        # Per the keyword names: tensors below the threshold use fp16, larger
        # ones use 8-bit uniform quantization.
        averaging_compression = SizeAdaptiveCompression(
            threshold=2 ** 10 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization()
        )
        compression_opts.update({
            'grad_compression': averaging_compression,
            'state_averaging_compression': averaging_compression
        })

    # Replace the policy's local Adam with a hivemind.Optimizer that wraps it.
    model.policy.optimizer_class = hivemind.Optimizer
    model.policy.optimizer = hivemind.Optimizer(
        dht=dht,
        optimizer=model.policy.optimizer,
        run_id='ppo_hivemind',
        batch_size_per_step=args.batch_size,
        target_batch_size=args.target_batch_size,
        offload_optimizer=False,
        verbose=True,
        use_local_updates=False,
        matchmaking_time=4,
        averaging_timeout=15,
        **compression_opts,
    )

    try:
        model.policy.optimizer.load_state_from_peers()
        model.learn(total_timesteps=int(5e11), tb_log_name=generate_experiment_name(args))
    finally:
        # Training effectively runs until interrupted (e.g. Ctrl-C). Release the
        # optimizer's averagers, the DHT and the envs on the way out.
        model.policy.optimizer.shutdown()
        dht.shutdown()
        env.close()
|