# run_first_peer.py — coordinator ("first peer") launcher for a collaborative training run
  1. #!/usr/bin/env python
  2. import argparse
  3. import time
  4. import hivemind
  5. import wandb
  6. from hivemind.utils.logging import get_logger
  7. from whatsmyip.ip import get_ip
  8. from whatsmyip.providers import GoogleDnsProvider
  9. import metrics_utils
  10. logger = get_logger(__name__)
  11. if __name__ == '__main__':
  12. parser = argparse.ArgumentParser()
  13. parser.add_argument('--address', type=str, required=False,
  14. help="this machine's network address. Use public IP for global experiments, "
  15. "local address for private runs.")
  16. parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
  17. help="'localhost' for local connections only, '0.0.0.0' for ipv4 '[::]' for ipv6")
  18. parser.add_argument('--refresh_period', type=float, default=30, required=False,
  19. help="coordinator will fetch keys from DHT once in this many seconds")
  20. parser.add_argument('--experiment_prefix', type=str, required=True,
  21. help="a prefix where peers store their metrics for aggregation")
  22. parser.add_argument('--wandb_project', type=str, required=True,
  23. help="Weights & Biases project name to publish learning curves")
  24. args = parser.parse_args()
  25. if args.address is None:
  26. logger.warning("No address specified. Attempting to infer address from DNS.")
  27. args.address = get_ip(GoogleDnsProvider)
  28. validators, local_public_key = metrics_utils.make_validators(args.experiment_prefix)
  29. dht = hivemind.DHT(start=True, listen_on=args.listen_on, endpoint=f"{args.address}:*",
  30. record_validators=validators)
  31. logger.info(f"Running DHT root at {args.address}:{dht.port}")
  32. wandb.init(project=args.wandb_project)
  33. current_step = 0
  34. while True:
  35. metrics_dict = dht.get(args.experiment_prefix + '_metrics', latest=True)
  36. if metrics_dict is not None:
  37. metrics_dict = metrics_dict.value
  38. metrics = [metrics_utils.LocalMetrics.parse_obj(metrics_dict[peer].value)
  39. for peer in metrics_dict]
  40. latest_step = max(item.step for item in metrics)
  41. if latest_step != current_step:
  42. current_step = latest_step
  43. alive_peers = 0
  44. num_batches = 0
  45. sum_loss = 0
  46. num_samples = 0
  47. sum_perf = 0
  48. sum_mini_steps = 0
  49. for item in metrics:
  50. sum_loss += item.loss
  51. alive_peers += 1
  52. sum_perf += item.samples_per_second
  53. num_samples += item.samples_accumulated
  54. sum_mini_steps += item.mini_steps
  55. wandb.log({
  56. "loss": sum_loss / sum_mini_steps,
  57. "alive peers": alive_peers,
  58. "samples": num_samples,
  59. "performance": sum_perf
  60. })
  61. logger.info(f"Step #{current_step}\tloss = {sum_loss / alive_peers:.5f}")
  62. logger.debug("Peer is still alive...")
  63. time.sleep(args.refresh_period)