1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- #!/usr/bin/env python
- import argparse
- import time
- import hivemind
- import wandb
- from hivemind.utils.logging import get_logger
- from whatsmyip.ip import get_ip
- from whatsmyip.providers import GoogleDnsProvider
- import metrics_utils
- logger = get_logger(__name__)
- if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('--address', type=str, required=False,
- help="this machine's network address. Use public IP for global experiments, "
- "local address for private runs.")
- parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
- help="'localhost' for local connections only, '0.0.0.0' for ipv4 '[::]' for ipv6")
- parser.add_argument('--refresh_period', type=float, default=30, required=False,
- help="coordinator will fetch keys from DHT once in this many seconds")
- parser.add_argument('--experiment_prefix', type=str, required=True,
- help="a prefix where peers store their metrics for aggregation")
- parser.add_argument('--wandb_project', type=str, required=True,
- help="Weights & Biases project name to publish learning curves")
- args = parser.parse_args()
- if args.address is None:
- logger.warning("No address specified. Attempting to infer address from DNS.")
- args.address = get_ip(GoogleDnsProvider)
- validators, local_public_key = metrics_utils.make_validators(args.experiment_prefix)
- dht = hivemind.DHT(start=True, listen_on=args.listen_on, endpoint=f"{args.address}:*",
- record_validators=validators)
- logger.info(f"Running DHT root at {args.address}:{dht.port}")
- wandb.init(project=args.wandb_project)
- current_step = 0
- while True:
- metrics_dict = dht.get(args.experiment_prefix + '_metrics', latest=True)
- if metrics_dict is not None:
- metrics_dict = metrics_dict.value
- metrics = [metrics_utils.LocalMetrics.parse_obj(metrics_dict[peer].value)
- for peer in metrics_dict]
- latest_step = max(item.step for item in metrics)
- if latest_step != current_step:
- current_step = latest_step
- alive_peers = 0
- num_batches = 0
- sum_loss = 0
- num_samples = 0
- sum_perf = 0
- sum_mini_steps = 0
- for item in metrics:
- sum_loss += item.loss
- alive_peers += 1
- sum_perf += item.samples_per_second
- num_samples += item.samples_accumulated
- sum_mini_steps += item.mini_steps
- wandb.log({
- "loss": sum_loss / sum_mini_steps,
- "alive peers": alive_peers,
- "samples": num_samples,
- "performance": sum_perf
- })
- logger.info(f"Step #{current_step}\tloss = {sum_loss / alive_peers:.5f}")
- logger.debug("Peer is still alive...")
- time.sleep(args.refresh_period)
|