
Minor style updates in examples (#321)

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic · 4 years ago
commit def7038401
5 files changed, 12 insertions(+), 12 deletions(-)
  1. examples/albert/arguments.py (+1 -1)
  2. examples/albert/run_trainer.py (+7 -7)
  3. examples/albert/utils.py (+1 -2)
  4. hivemind/optim/simple.py (+1 -0)
  5. hivemind/p2p/p2p_daemon.py (+2 -2)

+ 1 - 1
examples/albert/arguments.py

@@ -102,7 +102,7 @@ class CollaborationArguments(CollaborativeOptimizerArguments, BaseTrainingArgume
         default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
     )
     backup_every_steps: int = field(
-        default=10, metadata={"help": "In case of NaN, training restore from a backup updated with this frequency."}
+        default=10, metadata={"help": "Frequency of backups to restore from in case of encountering NaN values"}
     )
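For context, this field is consumed by CollaborativeCallback in run_trainer.py (next file in this commit). A minimal sketch of how the option is typically set, assuming the example's usual HfArgumentParser entry point (the invocation below is illustrative, not from this diff):

    from transformers import HfArgumentParser
    from arguments import CollaborationArguments

    # e.g. launched as: python run_trainer.py --backup_every_steps 25
    parser = HfArgumentParser(CollaborationArguments)
    (collab_args,) = parser.parse_args_into_dataclasses()
    print(collab_args.backup_every_steps)  # a backup is taken every N local steps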

+ 7 - 7
examples/albert/run_trainer.py

@@ -5,7 +5,6 @@ import os
 import pickle
 from dataclasses import asdict
 from pathlib import Path
-from typing import Any

 import torch
 import transformers
@@ -97,8 +96,8 @@ def get_optimizer_and_scheduler(training_args, model):

 class CollaborativeCallback(transformers.TrainerCallback):
     """
-    This callback monitors and reports collaborative training progress,
-    In case of a catastrophic failure, it can also revert training to a backup
+    This callback monitors and reports collaborative training progress.
+    In case of a catastrophic failure, it can also revert training to a backup.
     """

     def __init__(
@@ -153,6 +152,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
                 )
                 logger.info(f"Step {self.collaborative_optimizer.local_step}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
+                logger.info(f"Performance: {samples_per_second} samples per second.")
                 if self.steps:
                     logger.info(f"Local loss: {self.loss / self.steps}")
                 if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
@@ -181,16 +181,16 @@ class CollaborativeCallback(transformers.TrainerCallback):
         return True

     @torch.no_grad()
-    def backup_state(self) -> Any:
+    def backup_state(self) -> bytes:
         return pickle.dumps(
-            {"model": self.model.state_dict(), "training": self.collaborative_optimizer.opt.state_dict()}
+            {"model": self.model.state_dict(), "optimizer": self.collaborative_optimizer.opt.state_dict()}
         )

     @torch.no_grad()
-    def restore_from_backup(self, backup):
+    def restore_from_backup(self, backup: bytes):
         state = pickle.loads(backup)
         self.model.load_state_dict(state["model"])
-        self.collaborative_optimizer.opt.load_state_dict(state["training"])
+        self.collaborative_optimizer.opt.load_state_dict(state["optimizer"])


 class NoOpScheduler(LRSchedulerBase):
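The tightened annotations match how these methods are used: backup_state() pickles both state dicts into bytes, and restore_from_backup() reverses it. A hedged round-trip sketch, where `callback` stands in for a constructed CollaborativeCallback (construction not shown in this diff):

    import pickle

    backup = callback.backup_state()             # -> bytes, per the new annotation
    state = pickle.loads(backup)
    assert set(state) == {"model", "optimizer"}  # key renamed from "training" in this commit

    # If a later step hits a NaN loss, roll model and optimizer back together:
    callback.restore_from_backup(backup)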

+ 1 - 2
examples/albert/utils.py

@@ -42,8 +42,7 @@ def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
         unique_addrs = {addr["p2p"] for addr in visible_maddrs}
         initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
     else:
-        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr]
-        available_ips += [Multiaddr(addr) for addr in visible_maddrs if "ip6" in addr]
+        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr]
         if available_ips:
             preferred_ip = choose_ip_address(available_ips)
             selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)]
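Besides being shorter, the merged comprehension differs in one small way: it keeps the original order of visible_maddrs, whereas the two-pass version grouped every ip4 address ahead of the ip6 ones. A toy demonstration with plain strings standing in for Multiaddr objects:

    visible = ["/ip6/::1/tcp/1", "/ip4/1.2.3.4/tcp/2"]

    old = [a for a in visible if "ip4" in a] + [a for a in visible if "ip6" in a]
    new = [a for a in visible if "ip4" in a or "ip6" in a]

    print(old)  # ['/ip4/1.2.3.4/tcp/2', '/ip6/::1/tcp/1'] -- ip4 grouped first
    print(new)  # ['/ip6/::1/tcp/1', '/ip4/1.2.3.4/tcp/2'] -- input order kept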

+ 1 - 0
hivemind/optim/simple.py

@@ -79,6 +79,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     def step(self, *args, **kwargs):
         with self.lock_parameters:
             loss = self.opt.step(*args, **kwargs)
+
         self.local_step += 1
         if self.local_step % self.averaging_step_period == 0:
             self.update_event.set()
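For readers unfamiliar with this class: the surrounding code counts local optimizer steps and wakes a background averaging worker once every averaging_step_period steps. A simplified, self-contained sketch of that pattern (not hivemind's actual implementation):

    import threading

    update_event = threading.Event()
    averaging_step_period = 4

    def averager_loop():  # stands in for the background averaging worker
        while True:
            update_event.wait()   # woken once every `averaging_step_period` steps
            update_event.clear()
            print("averaging parameters with peers...")

    threading.Thread(target=averager_loop, daemon=True).start()

    for local_step in range(1, 9):  # stands in for successive optimizer steps
        if local_step % averaging_step_period == 0:
            update_event.set()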

+ 2 - 2
hivemind/p2p/p2p_daemon.py

@@ -297,11 +297,11 @@ class P2P:
         name: str,
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
         input_protobuf_type: type,
-        max_prefetch: int = 0,
+        max_prefetch: int = 5,
     ) -> None:
         """
         :param max_prefetch: Maximum number of items to prefetch from the request stream.
-          ``max_prefetch <= 0`` means unlimited (default).
+          ``max_prefetch <= 0`` means unlimited.

         :note:  Since the cancel messages are sent via the input stream,
           they will not be received while the prefetch buffer is full.
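The note above follows from ordinary bounded-buffer semantics: once the prefetch buffer is full, the reader stops pulling from the input stream, so a cancel message queued behind buffered requests is not seen until the handler drains some items. A minimal, self-contained illustration with asyncio.Queue (not hivemind's internals; note that maxsize=0 also means unbounded for asyncio.Queue, mirroring the max_prefetch <= 0 convention):

    import asyncio

    async def reader(buffer: asyncio.Queue):
        for i in range(10):
            await buffer.put(i)  # suspends while the buffer is full, so nothing
            print(f"prefetched item {i}")  # further is read from the stream

    async def handler(buffer: asyncio.Queue):
        for _ in range(10):
            item = await buffer.get()
            await asyncio.sleep(0.01)  # slow handler: reader stays ~5 items ahead
            print(f"handled item {item}")

    async def main():
        max_prefetch = 5
        buffer = asyncio.Queue(maxsize=max_prefetch)
        await asyncio.gather(reader(buffer), handler(buffer))

    asyncio.run(main())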