run_first_peer.py

#!/usr/bin/env python
import time
import argparse

import wandb
from whatsmyip.providers import GoogleDnsProvider
from whatsmyip.ip import get_ip

import hivemind
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--address', type=str, required=False,
                        help="this machine's network address. Use a public IP for global experiments, "
                             "a local address for private runs.")
    parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
                        help="'localhost' for local connections only, '0.0.0.0' for ipv4, '[::]' for ipv6")
    parser.add_argument('--refresh_period', type=float, default=30, required=False,
                        help="the coordinator will fetch keys from the DHT once every this many seconds")
    parser.add_argument('--experiment_prefix', type=str, required=True,
                        help="a prefix under which peers store their metrics for aggregation")
    parser.add_argument('--wandb_project', type=str, required=True,
                        help="Weights & Biases project name for publishing learning curves")

    args = parser.parse_args()
    if args.address is None:
        logger.warning("No address specified. Attempting to infer address from DNS.")
        args.address = get_ip(GoogleDnsProvider)

    # Start a DHT node that serves as the initial peer of the experiment
    dht = hivemind.DHT(start=True, listen_on=args.listen_on, endpoint=f"{args.address}:*")
    logger.info(f"Running DHT root at {args.address}:{dht.port}")

    wandb.init(project=args.wandb_project)
    current_step = 0

    while True:
        # Fetch the latest metrics that training peers publish under the shared DHT key
        metrics_dict = dht.get(args.experiment_prefix + '_metrics', latest=True)
        if metrics_dict is not None:
            metrics_dict = metrics_dict.value
            metrics = [metrics_dict[peer].value for peer in metrics_dict]
            latest_step = max(metrics)[0]  # each metrics entry starts with the peer's step number
            if latest_step != current_step:
                current_step = latest_step
                alive_peers = 0
                sum_loss = 0
                num_samples = 0
                sum_perf = 0
                sum_mini_steps = 0
                # Aggregate (step, performance, samples, loss, mini_steps) tuples across peers
                for step, perf, samples, loss, mini_steps in metrics:
                    sum_loss += loss
                    alive_peers += 1
                    sum_perf += perf
                    num_samples += samples
                    sum_mini_steps += mini_steps
                wandb.log({
                    "loss": sum_loss / sum_mini_steps,
                    "alive peers": alive_peers,
                    "samples": num_samples,
                    "performance": sum_perf
                })
                # Log the same per-mini-step loss that is reported to wandb above
                logger.info(f"Step #{current_step}\tloss = {sum_loss / sum_mini_steps:.5f}")
        logger.debug("Peer is still alive...")
        time.sleep(args.refresh_period)
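
# Illustration (comments only, not part of the coordinator's logic): a rough sketch of how a
# training peer might publish its metrics so that the loop above can aggregate them. The
# subkey and the 60-second expiration are assumptions made for this example; the key name and
# the (step, perf, samples, loss, mini_steps) tuple layout come from the aggregation loop above.
#
#   dht.store(key=args.experiment_prefix + '_metrics',
#             subkey=f"peer_{dht.port}",
#             value=(step, perf, samples, loss, mini_steps),
#             expiration_time=hivemind.get_dht_time() + 60)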