
Fix monitor state_averager

Yi Zhou 3 years ago
parent
commit
e7ac8de30f

+ 1 - 1
examples/albert/arguments.py

@@ -6,7 +6,7 @@ from transformers import TrainingArguments
 
 @dataclass
 class BaseTrainingArguments:
-    experiment_prefix: str = field(
+    run_id: str = field(
         default="albert", metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
     )
     initial_peers: List[str] = field(
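
Since the example scripts build their command lines from these dataclasses with transformers.HfArgumentParser, the rename also changes the flag users pass: --run_id instead of --experiment_prefix. A minimal sketch of that mapping (the real BaseTrainingArguments has more fields; only the renamed one is shown):

# Sketch only: HfArgumentParser turns the renamed dataclass field into a --run_id CLI flag.
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class BaseTrainingArguments:
    run_id: str = field(
        default="albert", metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
    )

(args,) = HfArgumentParser(BaseTrainingArguments).parse_args_into_dataclasses()
print(args.run_id)  # e.g. `python run_trainer.py --run_id my-albert-run` prints "my-albert-run"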

+ 2 - 2
examples/albert/run_trainer.py

@@ -215,7 +215,7 @@ def main():
     # This data collator will take care of randomly masking the tokens.
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
 
-    validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
+    validators, local_public_key = utils.make_validators(collaboration_args.run_id)
 
     dht = DHT(
         start=True,
@@ -265,7 +265,7 @@ def main():
 
     optimizer = Optimizer(
         dht=dht,
-        run_id=collaboration_args.experiment_prefix,
+        run_id=collaboration_args.run_id,
         target_batch_size=adjusted_target_batch_size,
         batch_size_per_step=total_batch_size_per_step,
         optimizer=opt,
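
This is the heart of the fix below in run_training_monitor.py: hivemind.Optimizer namespaces its DHT keys under run_id, and, judging from this commit, its internal state averager uses a prefix that appends "_state_averager" to that run_id. The monitor constructs its own TrainingStateAverager, so its prefix must match exactly or it will never find the trainers' state. A hedged sketch of the invariant; the exact prefix format is inferred from this change, not taken from hivemind's documented API:

# Hedged sketch: the trainer-side Optimizer and the monitor-side
# TrainingStateAverager must agree on one DHT key prefix per run.
def state_averager_prefix(run_id: str) -> str:
    # Inferred from this commit: Optimizer appears to store averaged state
    # under f"{run_id}_state_averager".
    return f"{run_id}_state_averager"

# run_trainer.py:           Optimizer(dht=dht, run_id=run_id, ...)
# run_training_monitor.py:  TrainingStateAverager(..., prefix=state_averager_prefix(run_id), ...)
assert state_averager_prefix("albert") == "albert_state_averager"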

+ 4 - 4
examples/albert/run_training_monitor.py

@@ -99,7 +99,7 @@ class CheckpointHandler:
         self.state_averager = TrainingStateAverager(
             dht=dht,
             optimizer=opt,
-            prefix=experiment_prefix,
+            prefix=f"{run_id}_state_averager",
             state_compression=hivemind.Float16Compression(),
             bandwidth=optimizer_args.bandwidth,
             client_mode=optimizer_args.client_mode,
@@ -155,8 +155,8 @@ if __name__ == "__main__":
         version = ip_address(address).version
         monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0"]
 
-    experiment_prefix = monitor_args.experiment_prefix
-    validators, local_public_key = utils.make_validators(experiment_prefix)
+    run_id = monitor_args.run_id
+    validators, local_public_key = utils.make_validators(run_id)
 
     dht = hivemind.DHT(
         start=True,
@@ -177,7 +177,7 @@ if __name__ == "__main__":
         checkpoint_handler = CheckpointHandler(monitor_args, optimizer_args, averager_args, dht)
 
     while True:
-        metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
+        metrics_dict = dht.get(run_id + "_metrics", latest=True)
         if metrics_dict is not None:
             metrics_dict = metrics_dict.value
             metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
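
The metrics loop relies on the same kind of key agreement: trainer peers publish their local metrics under a key derived from run_id, and the monitor polls that key. A rough sketch of both sides, assuming the trainers store under run_id + "_metrics" as run_trainer.py does (signatures, schema validation, and real peer subkeys are simplified here):

# Rough sketch of the metrics flow; the real example signs each record and
# validates it with the SchemaValidator built in utils.make_validators.
import hivemind

run_id = "albert"
dht = hivemind.DHT(start=True)

# Trainer side: each peer periodically stores its metrics under the shared
# key, keyed by its own public key as the subkey (placeholder bytes here).
dht.store(
    key=run_id + "_metrics",
    subkey=b"peer-public-key",
    value={"loss": 2.3, "samples_accumulated": 512},
    expiration_time=hivemind.get_dht_time() + 60,
)

# Monitor side: poll the same key and aggregate whatever the peers reported.
record = dht.get(run_id + "_metrics", latest=True)
if record is not None:
    for peer, entry in record.value.items():
        print(peer, entry.value)

dht.shutdown()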

+ 2 - 2
examples/albert/utils.py

@@ -24,9 +24,9 @@ class MetricSchema(BaseModel):
     metrics: Dict[BytesWithPublicKey, LocalMetrics]
 
 
-def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase], bytes]:
+def make_validators(run_id: str) -> Tuple[List[RecordValidatorBase], bytes]:
     signature_validator = RSASignatureValidator()
-    validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix), signature_validator]
+    validators = [SchemaValidator(MetricSchema, prefix=run_id), signature_validator]
     return validators, signature_validator.local_public_key
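
Both the trainer and the monitor call make_validators with the same run_id, so metrics written under run_id + "_metrics" are checked against MetricSchema and each signed entry can be attributed to a peer's public key. A short usage sketch, mirroring the DHT construction already shown in run_trainer.py (the import path assumes you are running from examples/albert/):

# Usage sketch: pass the returned validators to hivemind.DHT so every peer
# enforces the same schema and signature rules for this run.
import hivemind
from utils import make_validators  # examples/albert/utils.py

run_id = "albert"
validators, local_public_key = make_validators(run_id)

dht = hivemind.DHT(start=True, record_validators=validators)
# local_public_key is the subkey under which this peer signs its own metrics.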