|
@@ -99,7 +99,7 @@ class CheckpointHandler:
|
|
|
self.state_averager = TrainingStateAverager(
|
|
|
dht=dht,
|
|
|
optimizer=opt,
|
|
|
- prefix=experiment_prefix,
|
|
|
+ prefix=f"{run_id}_state_averager",
|
|
|
state_compression=hivemind.Float16Compression(),
|
|
|
bandwidth=optimizer_args.bandwidth,
|
|
|
client_mode=optimizer_args.client_mode,
|
|
@@ -155,8 +155,8 @@ if __name__ == "__main__":
|
|
|
version = ip_address(address).version
|
|
|
monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0"]
|
|
|
|
|
|
- experiment_prefix = monitor_args.experiment_prefix
|
|
|
- validators, local_public_key = utils.make_validators(experiment_prefix)
|
|
|
+ run_id = monitor_args.run_id
|
|
|
+ validators, local_public_key = utils.make_validators(run_id)
|
|
|
|
|
|
dht = hivemind.DHT(
|
|
|
start=True,
|
|
@@ -177,7 +177,7 @@ if __name__ == "__main__":
|
|
|
checkpoint_handler = CheckpointHandler(monitor_args, optimizer_args, averager_args, dht)
|
|
|
|
|
|
while True:
|
|
|
- metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
|
|
|
+ metrics_dict = dht.get(run_id + "_metrics", latest=True)
|
|
|
if metrics_dict is not None:
|
|
|
metrics_dict = metrics_dict.value
|
|
|
metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
|