Forráskód Böngészése

get the country of each session

SaulLu 3 éve
szülő
commit
fab0bb12c5
2 módosított fájl, 19 hozzáadás és 0 törlés
  1. 2 0
      arguments.py
  2. 17 0
      lib/training/hf_trainer.py

+ 2 - 0
arguments.py

@@ -36,6 +36,8 @@ class HFTrainerArguments(TrainingArguments):
 
     output_dir: str = "outputs"
 
+    run_country: str = ""
+
     @property
     def batch_size_per_step(self):
         """Compute the number of training sequences contributed by each .step() from this peer"""

+ 17 - 0
lib/training/hf_trainer.py

@@ -1,4 +1,7 @@
 """A catch-all module for the dirty hacks required to make HF Trainer work with collaborative training"""
+import json
+import urllib
+
 import torch
 from torch import nn
 from torch.utils.data import DataLoader
@@ -10,6 +13,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger()
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
+URL_IP_INFO = "http://ipinfo.io/json"
 
 
 class CollaborativeHFTrainer(Trainer):
@@ -21,6 +25,9 @@ class CollaborativeHFTrainer(Trainer):
     def __init__(self, *, data_seed: int, collaborative_optimizer: CollaborativeOptimizer, **kwargs):
         self.data_seed = data_seed
         self.collaborative_optimizer = collaborative_optimizer
+
+        args = kwargs["args"]
+        setattr(args, "run_country", self.get_country_info())
         super().__init__(optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)), **kwargs)
 
         if self.fp16_backend is not None:
@@ -37,6 +44,16 @@ class CollaborativeHFTrainer(Trainer):
         return IgnoreGradManipulations(super()._wrap_model(model, training=training),
                                        override_zero_grad=self.collaborative_optimizer.grad_averager.reuse_grad_buffers)
 
+    def get_country_info(self, timeout: float = 5.0) -> str:
+        """Best-effort lookup of the country this peer is running from.
+
+        Queries the JSON endpoint at ``URL_IP_INFO`` and returns the value of its
+        ``"country"`` field. This is only a nice-to-have, so any failure (no network,
+        timeout, malformed response, missing key, ...) yields an empty string instead
+        of raising.
+
+        :param timeout: maximum number of seconds to wait for the HTTP request, so that
+            an unresponsive endpoint cannot block trainer construction indefinitely
+        :returns: the reported country, or ``""`` if it could not be determined
+        """
+        # A bare `import urllib` does not load the `request` submodule; import it
+        # explicitly so the lookup does not depend on another module having done so.
+        import urllib.request
+
+        try:
+            # The context manager closes the HTTP response even if parsing fails.
+            with urllib.request.urlopen(URL_IP_INFO, timeout=timeout) as response:
+                data = json.load(response)
+            return data["country"]
+        except Exception as e:
+            # Deliberately broad: the lookup must never prevent training from starting.
+            logger.debug("Could not determine the run country: %r", e)
+            return ""
+
 
 class NoOpScheduler(LRSchedulerBase):
     """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""