@@ -1,4 +1,7 @@
 """A catch-all module for the dirty hacks required to make HF Trainer work with collaborative training"""
+import json
+import urllib.request
+
 import torch
 from torch import nn
 from torch.utils.data import DataLoader
@@ -10,6 +13,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger()
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
+URL_IP_INFO = "http://ipinfo.io/json"
 
 
 class CollaborativeHFTrainer(Trainer):
@@ -21,6 +25,9 @@ class CollaborativeHFTrainer(Trainer):
     def __init__(self, *, data_seed: int, collaborative_optimizer: CollaborativeOptimizer, **kwargs):
         self.data_seed = data_seed
         self.collaborative_optimizer = collaborative_optimizer
+
+        args = kwargs["args"]
+        setattr(args, "run_country", self.get_country_info())
         super().__init__(optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)), **kwargs)
 
         if self.fp16_backend is not None:
@@ -37,6 +44,16 @@ class CollaborativeHFTrainer(Trainer):
         return IgnoreGradManipulations(super()._wrap_model(model, training=training),
                                        override_zero_grad=self.collaborative_optimizer.grad_averager.reuse_grad_buffers)
 
+    def get_country_info(self):
+        # Getting the country is only a nice-to-have, so if the request fails for any reason we fall back to an empty string
+        try:
+            response = urllib.request.urlopen(URL_IP_INFO)
+            data = json.load(response)
+            country = data["country"]
+        except Exception:
+            country = ""
+        return country
+
 
 class NoOpScheduler(LRSchedulerBase):
     """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
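For reference, a minimal standalone sketch of the best-effort lookup that get_country_info performs, using only the ipinfo.io endpoint from the patch. The lookup_country helper name and the timeout argument are illustrative additions, not part of the change; it returns a two-letter country code such as "DE", or an empty string when the request fails.

import json
import urllib.request

URL_IP_INFO = "http://ipinfo.io/json"

def lookup_country(url: str = URL_IP_INFO) -> str:
    # Best-effort lookup mirroring CollaborativeHFTrainer.get_country_info:
    # any failure (no network, bad JSON, missing field) yields "".
    try:
        with urllib.request.urlopen(url, timeout=5) as response:  # timeout is an illustrative addition
            return json.load(response)["country"]
    except Exception:
        return ""

if __name__ == "__main__":
    print(lookup_country())  # e.g. "DE", or "" if offline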