
Merge remote-tracking branch 'origin/master' into usability_tweaks

justheuristic 3 years ago
parent
commit
77be4960bb

+ 4 - 0
README.md

@@ -12,6 +12,10 @@ large model on hundreds of computers from different universities, companies, and

 ![img](https://i.imgur.com/GPxolxb.gif)

+## Live Demo
+
+Check out our NeurIPS 2021 demonstration ["Training Transformers Together"](https://training-transformers-together.github.io/) to see hivemind in action, join an ongoing collaborative experiment, and learn more about the technologies behind it!
+
 ## Key Features

 * Distributed training without a master node: Distributed Hash Table allows connecting computers in a decentralized

+ 1 - 2
benchmarks/benchmark_optimizer.py

@@ -6,7 +6,6 @@ from dataclasses import dataclass
 from functools import partial
 from typing import Callable

-import numpy as np
 import torch
 import torchvision
 from torch import nn as nn
@@ -14,7 +13,7 @@ from torch.nn import functional as F
 from torch.utils.data import Dataset

 import hivemind
-from hivemind.optim.experimental.optimizer import Optimizer
+from hivemind.optim.optimizer import Optimizer
 from hivemind.utils.crypto import RSAPrivateKey
 
 
 
 

+ 2 - 2
docs/modules/optim.rst

@@ -9,8 +9,8 @@

   <br><br>

-.. automodule:: hivemind.optim.experimental.optimizer
-.. currentmodule:: hivemind.optim.experimental.optimizer
+.. automodule:: hivemind.optim.optimizer
+.. currentmodule:: hivemind.optim.optimizer

 **hivemind.Optimizer**
 ----------------------

+ 13 - 12
examples/albert/arguments.py

@@ -45,12 +45,11 @@ class BaseTrainingArguments:

 @dataclass
 class AveragerArguments:
-    averaging_expiration: float = field(
-        default=5.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
-    )
-    averaging_timeout: float = field(
-        default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
-    )
+    target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
+
+
+@dataclass
+class ProgressTrackerArguments:
     min_refresh_period: float = field(
         default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
     )
@@ -66,17 +65,13 @@
     expected_drift_rate: float = field(
         default=0.2, metadata={"help": "Trainer assumes that this fraction of current size can join per step"}
     )
-    performance_ema_alpha: float = field(
-        default=0.1, metadata={"help": "Uses this alpha for moving average estimate of samples per second"}
-    )
-    target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
     metadata_expiration: float = field(
         default=120, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
     )


 @dataclass
-class CollaborativeOptimizerArguments:
+class OptimizerArguments:
     target_batch_size: int = field(
         default=4096,
         metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
@@ -93,10 +88,16 @@
         default=100.0,
         metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"},
     )
+    averaging_timeout: float = field(
+        default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
+    )
+    matchmaking_time: float = field(
+        default=5.0, metadata={"help": "When looking for group, wait for requests for at least this many seconds"}
+    )


 @dataclass
-class CollaborationArguments(CollaborativeOptimizerArguments, BaseTrainingArguments):
+class CollaborationArguments(OptimizerArguments, BaseTrainingArguments):
     statistics_expiration: float = field(
         default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
     )
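
The regrouped dataclasses above map onto `hivemind.Optimizer` keyword arguments: `OptimizerArguments` fields become direct constructor arguments, while `AveragerArguments` and `ProgressTrackerArguments` are forwarded wholesale as `averager_opts` / `tracker_opts`. A minimal sketch of that mapping, using only the field names shown in this diff (the parser call mirrors `run_trainer.py` below); everything else is illustrative:

```python
# Sketch only: how the regrouped argument dataclasses are expected to feed hivemind.Optimizer.
# Field and class names come from this diff; the rest is illustrative.
from dataclasses import asdict

from transformers import HfArgumentParser

from arguments import AveragerArguments, OptimizerArguments, ProgressTrackerArguments

parser = HfArgumentParser((OptimizerArguments, AveragerArguments, ProgressTrackerArguments))
optimizer_args, averager_args, tracker_args = parser.parse_args_into_dataclasses()

optimizer_kwargs = dict(
    target_batch_size=optimizer_args.target_batch_size,
    matchmaking_time=optimizer_args.matchmaking_time,
    averaging_timeout=optimizer_args.averaging_timeout,
    # bandwidth is not a direct Optimizer argument; it rides along in averager_opts
    averager_opts={"bandwidth": optimizer_args.bandwidth, **asdict(averager_args)},
    tracker_opts=asdict(tracker_args),
)
```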

+ 55 - 35
examples/albert/run_trainer.py

@@ -16,11 +16,17 @@ from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process

-import hivemind
+from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler

 import utils
-from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
+from arguments import (
+    AlbertTrainingArguments,
+    AveragerArguments,
+    CollaborationArguments,
+    DatasetArguments,
+    ProgressTrackerArguments,
+)

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -90,8 +96,8 @@ class CollaborativeCallback(transformers.TrainerCallback):

     def __init__(
         self,
-        dht: hivemind.DHT,
-        optimizer: hivemind.CollaborativeOptimizer,
+        dht: DHT,
+        optimizer: Optimizer,
         model: torch.nn.Module,
         local_public_key: bytes,
         statistics_expiration: float,
@@ -99,7 +105,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
     ):
     ):
         super().__init__()
         super().__init__()
         self.model = model
         self.model = model
-        self.dht, self.collaborative_optimizer = dht, optimizer
+        self.dht, self.optimizer = dht, optimizer
         self.local_public_key = local_public_key
         self.local_public_key = local_public_key
         self.statistics_expiration = statistics_expiration
         self.statistics_expiration = statistics_expiration
         self.last_reported_collaboration_step = -1
         self.last_reported_collaboration_step = -1
@@ -114,7 +120,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
         self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
     ):
         logger.info("Loading state from peers")
-        self.collaborative_optimizer.load_state_from_peers()
+        self.optimizer.load_state_from_peers()

     def on_step_end(
         self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
@@ -124,40 +130,43 @@ class CollaborativeCallback(transformers.TrainerCallback):
             self.restore_from_backup(self.latest_backup)
             return control

+        local_progress = self.optimizer.local_progress
+
         if state.log_history:
             self.loss += state.log_history[-1]["loss"]
             self.steps += 1
-            if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
-                self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
+
+            if self.optimizer.local_epoch != self.last_reported_collaboration_step:
+                self.last_reported_collaboration_step = self.optimizer.local_epoch
                 self.total_samples_processed += self.samples
-                samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
+                samples_per_second = local_progress.samples_per_second
                 statistics = utils.LocalMetrics(
-                    step=self.collaborative_optimizer.local_step,
+                    step=self.optimizer.local_epoch,
                     samples_per_second=samples_per_second,
                     samples_accumulated=self.samples,
                     loss=self.loss,
                     mini_steps=self.steps,
                 )
-                logger.info(f"Step #{self.collaborative_optimizer.local_step}")
+                logger.info(f"Step #{self.optimizer.local_epoch}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 logger.info(f"Performance: {samples_per_second} samples per second.")
                 if self.steps:
                     logger.info(f"Local loss: {self.loss / self.steps}")
-                if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
+                if self.optimizer.local_epoch % self.backup_every_steps == 0:
                     self.latest_backup = self.backup_state()

                 self.loss = 0
                 self.steps = 0
-                if self.collaborative_optimizer.is_synchronized:
+                if self.optimizer.is_synchronized_with_peers():
                     self.dht.store(
-                        key=self.collaborative_optimizer.prefix + "_metrics",
+                        key=self.optimizer.run_id + "_metrics",
                         subkey=self.local_public_key,
                         value=statistics.dict(),
-                        expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                        expiration_time=get_dht_time() + self.statistics_expiration,
                         return_future=True,
                     )

-        self.samples = self.collaborative_optimizer.local_samples_accumulated
+        self.samples = local_progress.samples_accumulated

         return control
 
 
@@ -170,19 +179,17 @@ class CollaborativeCallback(transformers.TrainerCallback):

     @torch.no_grad()
     def backup_state(self) -> bytes:
-        return pickle.dumps(
-            {"model": self.model.state_dict(), "optimizer": self.collaborative_optimizer.opt.state_dict()}
-        )
+        return pickle.dumps({"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()})

     @torch.no_grad()
     def restore_from_backup(self, backup: bytes):
         state = pickle.loads(backup)
         self.model.load_state_dict(state["model"])
-        self.collaborative_optimizer.opt.load_state_dict(state["optimizer"])
+        self.optimizer.load_state_dict(state["optimizer"])


 class NoOpScheduler(LRSchedulerBase):
-    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
+    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in Optimizer.scheduler"""

     def get_lr(self):
         return [group["lr"] for group in self.optimizer.param_groups]
@@ -202,8 +209,16 @@ class NoOpScheduler(LRSchedulerBase):


 def main():
-    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments, AveragerArguments))
-    training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser(
+        (
+            AlbertTrainingArguments,
+            DatasetArguments,
+            CollaborationArguments,
+            AveragerArguments,
+            ProgressTrackerArguments,
+        )
+    )
+    training_args, dataset_args, collaboration_args, averager_args, tracker_args = parser.parse_args_into_dataclasses()

     logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
     if len(collaboration_args.initial_peers) == 0:
@@ -228,7 +243,7 @@ def main():

     validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)

-    dht = hivemind.DHT(
+    dht = DHT(
         start=True,
         initial_peers=collaboration_args.initial_peers,
         client_mode=collaboration_args.client_mode,
@@ -246,19 +261,24 @@ def main():

     adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead

-    collaborative_optimizer = hivemind.CollaborativeOptimizer(
-        opt=opt,
+    optimizer = Optimizer(
         dht=dht,
-        scheduler=scheduler,
-        prefix=collaboration_args.experiment_prefix,
-        compression=hivemind.Float16Compression(),
-        batch_size_per_step=total_batch_size_per_step,
-        bandwidth=collaboration_args.bandwidth,
+        run_id=collaboration_args.experiment_prefix,
         target_batch_size=adjusted_target_batch_size,
+        batch_size_per_step=total_batch_size_per_step,
+        optimizer=opt,
+        scheduler=scheduler,
+        matchmaking_time=collaboration_args.matchmaking_time,
+        averaging_timeout=collaboration_args.averaging_timeout,
+        offload_optimizer=True,
+        delay_optimizer_step=True,
+        delay_grad_averaging=True,
         client_mode=collaboration_args.client_mode,
+        grad_compression=Float16Compression(),
+        state_averaging_compression=Float16Compression(),
+        averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)},
+        tracker_opts=asdict(tracker_args),
         verbose=True,
-        start=True,
-        **asdict(averager_args),
     )

     class TrainerWithIndependentShuffling(Trainer):
@@ -274,11 +294,11 @@ def main():
         data_collator=data_collator,
         train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
         eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
-        optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
+        optimizers=(optimizer, NoOpScheduler(optimizer)),
         callbacks=[
             CollaborativeCallback(
                 dht,
-                collaborative_optimizer,
+                optimizer,
                 model,
                 local_public_key,
                 collaboration_args.statistics_expiration,

+ 16 - 19
examples/albert/run_training_monitor.py

@@ -12,10 +12,11 @@ from torch_optimizer import Lamb
 from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser

 import hivemind
+from hivemind.optim.state_averager import TrainingStateAverager
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler

 import utils
-from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
+from arguments import AveragerArguments, BaseTrainingArguments, OptimizerArguments

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -55,14 +56,14 @@ class TrainingMonitorArguments(BaseTrainingArguments):
     upload_interval: Optional[float] = field(
         default=None, metadata={"help": "Frequency (in seconds) of uploading the model to Hub"}
     )
-    store_checkpoins: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
+    store_checkpoints: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})


 class CheckpointHandler:
     def __init__(
         self,
         monitor_args: TrainingMonitorArguments,
-        collab_optimizer_args: CollaborativeOptimizerArguments,
+        optimizer_args: OptimizerArguments,
         averager_args: AveragerArguments,
         dht: hivemind.DHT,
     ):
@@ -95,17 +96,13 @@ class CheckpointHandler:
             debias=True,
         )

-        adjusted_target_batch_size = collab_optimizer_args.target_batch_size - collab_optimizer_args.batch_size_lead
-
-        self.collaborative_optimizer = hivemind.CollaborativeOptimizer(
-            opt=opt,
+        self.state_averager = TrainingStateAverager(
             dht=dht,
+            optimizer=opt,
             prefix=experiment_prefix,
-            compression_type=hivemind.Float16Compression(),
-            bandwidth=collab_optimizer_args.bandwidth,
-            target_batch_size=adjusted_target_batch_size,
-            client_mode=collab_optimizer_args.client_mode,
-            verbose=True,
+            state_compression=hivemind.Float16Compression(),
+            bandwidth=optimizer_args.bandwidth,
+            client_mode=optimizer_args.client_mode,
             start=True,
             **asdict(averager_args),
         )
@@ -121,7 +118,7 @@

     def save_state(self, cur_step):
         logger.info("Saving state from peers")
-        self.collaborative_optimizer.load_state_from_peers()
+        self.state_averager.load_state_from_peers()
         self.previous_step = cur_step

     def is_time_to_upload(self):
@@ -134,7 +131,7 @@

     def upload_checkpoint(self, current_loss):
         logger.info("Saving optimizer")
-        torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
+        torch.save(self.state_averager.optimizer.state_dict(), f"{self.repo_path}/optimizer_state.pt")
         self.previous_timestamp = time.time()
         logger.info("Started uploading to Model Hub")
         self.model.push_to_hub(
@@ -146,8 +143,8 @@


 if __name__ == "__main__":
-    parser = HfArgumentParser((TrainingMonitorArguments, CollaborativeOptimizerArguments, AveragerArguments))
-    monitor_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser((TrainingMonitorArguments, OptimizerArguments, AveragerArguments))
+    monitor_args, optimizer_args, averager_args = parser.parse_args_into_dataclasses()

     if monitor_args.use_google_dns:
         request = requests.get("https://api.ipify.org")
@@ -176,8 +173,8 @@
         wandb.init(project=monitor_args.wandb_project)

     current_step = 0
-    if monitor_args.store_checkpoins:
-        checkpoint_handler = CheckpointHandler(monitor_args, collab_optimizer_args, averager_args, dht)
+    if monitor_args.store_checkpoints:
+        checkpoint_handler = CheckpointHandler(monitor_args, optimizer_args, averager_args, dht)

     while True:
         metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
@@ -219,7 +216,7 @@
                         }
                     )

-                if monitor_args.store_checkpoins:
+                if monitor_args.store_checkpoints:
                     if checkpoint_handler.is_time_to_save_state(current_step):
                         checkpoint_handler.save_state(current_step)
                         if checkpoint_handler.is_time_to_upload():

+ 1 - 1
hivemind/__init__.py

@@ -23,4 +23,4 @@ from hivemind.optim import (
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *

-__version__ = "1.0.0dev0"
+__version__ = "1.1.0dev0"

+ 1 - 1
hivemind/optim/__init__.py

@@ -1,7 +1,7 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
-from hivemind.optim.experimental.optimizer import Optimizer
 from hivemind.optim.grad_scaler import GradScaler, HivemindGradScaler
+from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
 from hivemind.optim.training_averager import TrainingAverager

+ 8 - 0
hivemind/optim/base.py

@@ -1,3 +1,5 @@
+from warnings import warn
+
 import torch

 from hivemind.dht import DHT
@@ -8,6 +10,12 @@ class DecentralizedOptimizerBase(torch.optim.Optimizer):

     def __init__(self, opt: torch.optim.Optimizer, dht: DHT):
         self.opt, self.dht = opt, dht
+        warn(
+            "DecentralizedOptimizerBase and its subclasses have been deprecated and will be removed "
+            "in hivemind 1.1.0. Use hivemind.Optimizer instead",
+            FutureWarning,
+            stacklevel=2,
+        )

     @property
     def state(self):

+ 5 - 5
hivemind/optim/collaborative.py

@@ -57,15 +57,15 @@ class TrainingProgressSchema(BaseModel):

 class CollaborativeOptimizer(DecentralizedOptimizerBase):
     """
-    :note: **For new projects please use hivemind.Optimizer**. CollaborativeOptimizer is an older version of that.
-      Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and a many advanced ones.
-      CollaborativeOptimizer will still be supported for a while, but it will be deprecated eventually.
-
-    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers
+    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers.

     These optimizers use DHT to track how much progress did the collaboration make towards target batch size.
     Once enough samples were accumulated, optimizers will compute a weighted average of their statistics.

+    :note: **For new projects, please use hivemind.Optimizer**. CollaborativeOptimizer is an older version of that.
+      Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and many advanced ones.
+      CollaborativeOptimizer will still be supported for a while, but it will be deprecated in v1.1.0.
+
    :note: This optimizer behaves unlike regular pytorch optimizers in two ways:

       * calling .step will periodically zero-out gradients w.r.t. model parameters after each step
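
To make the note above concrete, here is a rough migration sketch (not part of this commit) from `CollaborativeOptimizer` to `hivemind.Optimizer`, limited to keyword arguments that appear elsewhere in this diff; the model, base optimizer, and run id are placeholders:

```python
# Rough migration sketch with placeholder model/optimizer/run id; only arguments
# that appear elsewhere in this diff are used.
import torch
import hivemind

dht = hivemind.DHT(start=True)
model = torch.nn.Linear(16, 2)                      # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=0.1)   # placeholder base optimizer

# Before (deprecated):
# optimizer = hivemind.CollaborativeOptimizer(
#     opt=opt, dht=dht, prefix="my_run",
#     target_batch_size=4096, batch_size_per_step=32, start=True,
# )

# After:
optimizer = hivemind.Optimizer(
    dht=dht,
    run_id="my_run",          # replaces `prefix`
    optimizer=opt,            # replaces `opt`
    target_batch_size=4096,
    batch_size_per_step=32,
)
```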

+ 0 - 0
hivemind/optim/experimental/__init__.py


+ 0 - 0
hivemind/optim/experimental/grad_averager.py → hivemind/optim/grad_averager.py


+ 15 - 8
hivemind/optim/experimental/optimizer.py → hivemind/optim/optimizer.py

@@ -11,9 +11,10 @@ import torch
 from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
-from hivemind.optim.experimental.grad_averager import GradientAverager
-from hivemind.optim.experimental.progress_tracker import ProgressTracker
-from hivemind.optim.experimental.state_averager import (
+from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_scaler import GradScaler
+from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
+from hivemind.optim.state_averager import (
     LRSchedulerBase,
     OptimizerFactory,
     Parameters,
@@ -22,7 +23,6 @@ from hivemind.optim.experimental.state_averager import (
     TorchOptimizer,
     TrainingStateAverager,
 )
-from hivemind.optim.grad_scaler import GradScaler
 from hivemind.utils import PerformanceEMA, get_dht_time, get_logger

 logger = get_logger(__name__)
@@ -154,7 +154,7 @@ class Optimizer(torch.optim.Optimizer):

     :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
     :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
-    :param performance_ema_alpha: moving average alpha  in ProgressTracer, TrainingStateAverager and Optimizer
+    :param performance_ema_alpha: moving average alpha in ProgressTracker, TrainingStateAverager and Optimizer
     :param verbose: if True, report internal events such as accumilating gradients and running background tasks

     :note: in a large-scale training, peers will inevitably fail and you will see error messages. hivemind.Optimizer
@@ -345,6 +345,10 @@
         """
         return self.state_averager.local_epoch

+    @property
+    def local_progress(self) -> LocalTrainingProgress:
+        return self.tracker.local_progress
+
     @property
     def use_local_updates(self) -> bool:
         return self.grad_averager is None
@@ -384,7 +388,7 @@
             with torch.enable_grad():
                 loss = closure()

-        if not self.auxiliary and self.should_load_state_from_peers():
+        if not self.auxiliary and self._should_load_state_from_peers():
             logger.log(self.status_loglevel, "Peer is out of sync")
             self.load_state_from_peers()
             return loss  # local gradients were computed with out-of-sync parameters, must start over
@@ -564,7 +568,6 @@

         if eta_seconds_to_averaging <= self.matchmaking_time:
             if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
-
                 min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
                 actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
                 logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f} sec")
@@ -626,7 +629,7 @@
                 else:
                     param.grad.zero_()

-    def should_load_state_from_peers(self) -> bool:
+    def _should_load_state_from_peers(self) -> bool:
         """
         If true, peer will discard local progress and attempt to download state from peers.
         This method allows peer to continue training in two cases:
@@ -646,6 +649,10 @@
             return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
         return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch

+    def is_synchronized_with_peers(self) -> bool:
+        """Checks whether the current peer is up-to-date with others in terms of the epoch (step) number."""
+        return self.local_epoch >= self.tracker.global_epoch - 1
+
     def load_state_from_peers(self, **kwargs):
         """
         Attempt to load the newest collaboration state from other peers within the same run_id.
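
The `local_progress` property and `is_synchronized_with_peers()` method added above are the public replacements for the `performance_ema` / `is_synchronized` accesses removed from the ALBERT example earlier in this commit. A short usage sketch, assuming `optimizer` is an already-running `hivemind.Optimizer` built as in `run_trainer.py`:

```python
# Usage sketch for the new public members above; `optimizer` is assumed to be an
# already-constructed hivemind.Optimizer (see examples/albert/run_trainer.py in this commit).
progress = optimizer.local_progress          # LocalTrainingProgress snapshot from the tracker
print(progress.samples_accumulated, progress.samples_per_second)

if optimizer.is_synchronized_with_peers():
    # local_epoch is within one epoch of the tracker's global_epoch, so it is
    # reasonable to publish metrics tagged with the current step
    print(f"Step #{optimizer.local_epoch}")
```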

+ 0 - 0
hivemind/optim/experimental/progress_tracker.py → hivemind/optim/progress_tracker.py


+ 0 - 0
hivemind/optim/experimental/state_averager.py → hivemind/optim/state_averager.py


+ 4 - 4
tests/test_optimizer.py

@@ -11,10 +11,10 @@ import torch.nn.functional as F

 import hivemind
 from hivemind.averaging.control import AveragingStage
-from hivemind.optim.experimental.grad_averager import GradientAverager
-from hivemind.optim.experimental.optimizer import Optimizer
-from hivemind.optim.experimental.progress_tracker import ProgressTracker
-from hivemind.optim.experimental.state_averager import TrainingStateAverager
+from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.optimizer import Optimizer
+from hivemind.optim.progress_tracker import ProgressTracker
+from hivemind.optim.state_averager import TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey