
Minor style updates in examples (#321)

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent commit def7038401

+ 1 - 1
examples/albert/arguments.py

@@ -102,7 +102,7 @@ class CollaborationArguments(CollaborativeOptimizerArguments, BaseTrainingArgume
         default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
     )
     backup_every_steps: int = field(
-        default=10, metadata={"help": "In case of NaN, training restore from a backup updated with this frequency."}
+        default=10, metadata={"help": "Frequency of backups to restore from in case of encountering NaN values"}
     )
 
 

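For context on how this help string surfaces: a minimal sketch, assuming the surrounding arguments are parsed with transformers.HfArgumentParser (the dataclass name here is hypothetical), showing that the metadata "help" text becomes the --help description of a CLI flag:

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class BackupArguments:  # hypothetical stand-in for CollaborationArguments
    backup_every_steps: int = field(
        default=10, metadata={"help": "Frequency of backups to restore from in case of encountering NaN values"}
    )


if __name__ == "__main__":
    parser = HfArgumentParser(BackupArguments)
    # "--backup_every_steps 25" on the command line overrides the default;
    # "--help" prints the metadata text for this flag.
    (args,) = parser.parse_args_into_dataclasses()
    print(args.backup_every_steps)
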
+ 7 - 7
examples/albert/run_trainer.py

@@ -5,7 +5,6 @@ import os
 import pickle
 from dataclasses import asdict
 from pathlib import Path
-from typing import Any
 
 import torch
 import transformers
@@ -97,8 +96,8 @@ def get_optimizer_and_scheduler(training_args, model):
 
 class CollaborativeCallback(transformers.TrainerCallback):
     """
-    This callback monitors and reports collaborative training progress,
-    In case of a catastrophic failure, it can also revert training to a backup
+    This callback monitors and reports collaborative training progress.
+    In case of a catastrophic failure, it can also revert training to a backup.
     """
 
     def __init__(
@@ -153,6 +152,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
                 )
                 logger.info(f"Step {self.collaborative_optimizer.local_step}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
+                logger.info(f"Performance: {samples_per_second} samples per second.")
                 if self.steps:
                     logger.info(f"Local loss: {self.loss / self.steps}")
                 if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
@@ -181,16 +181,16 @@ class CollaborativeCallback(transformers.TrainerCallback):
         return True
 
     @torch.no_grad()
-    def backup_state(self) -> Any:
+    def backup_state(self) -> bytes:
         return pickle.dumps(
-            {"model": self.model.state_dict(), "training": self.collaborative_optimizer.opt.state_dict()}
+            {"model": self.model.state_dict(), "optimizer": self.collaborative_optimizer.opt.state_dict()}
         )
 
     @torch.no_grad()
-    def restore_from_backup(self, backup):
+    def restore_from_backup(self, backup: bytes):
         state = pickle.loads(backup)
         self.model.load_state_dict(state["model"])
-        self.collaborative_optimizer.opt.load_state_dict(state["training"])
+        self.collaborative_optimizer.opt.load_state_dict(state["optimizer"])
 
 
 class NoOpScheduler(LRSchedulerBase):

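Taken together, the renamed dict key and the bytes annotation describe a simple pickle round-trip. A self-contained sketch of the same pattern (the Linear model and SGD optimizer are placeholders, not the ALBERT setup):

import pickle

import torch

model = torch.nn.Linear(4, 2)  # stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# backup_state(): serialize both state dicts into a single bytes blob
backup = pickle.dumps({"model": model.state_dict(), "optimizer": optimizer.state_dict()})

# restore_from_backup(): after a catastrophic failure (e.g. NaN loss), roll back
state = pickle.loads(backup)
model.load_state_dict(state["model"])
optimizer.load_state_dict(state["optimizer"])
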
+ 1 - 2
examples/albert/utils.py

@@ -42,8 +42,7 @@ def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
         unique_addrs = {addr["p2p"] for addr in visible_maddrs}
         initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
     else:
-        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr]
-        available_ips += [Multiaddr(addr) for addr in visible_maddrs if "ip6" in addr]
+        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr]
         if available_ips:
             preferred_ip = choose_ip_address(available_ips)
             selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)]

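The merged comprehension in isolation, with plain strings standing in for Multiaddr objects (so the membership check here is a simple substring test, not Multiaddr protocol lookup):

visible_maddrs = [
    "/ip4/192.0.2.1/tcp/31337/p2p/QmExamplePeer",
    "/ip6/2001:db8::1/tcp/31337/p2p/QmExamplePeer",
    "/dns4/example.org/tcp/31337/p2p/QmExamplePeer",
]
# One pass instead of two: keep any address with an IPv4 or IPv6 component
available_ips = [addr for addr in visible_maddrs if "ip4" in addr or "ip6" in addr]
assert len(available_ips) == 2

One subtle difference: the old two-pass version listed all ip4 addresses before ip6 ones, while the single pass preserves the original order of visible_maddrs.
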
+ 1 - 0
hivemind/optim/simple.py

@@ -79,6 +79,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     def step(self, *args, **kwargs):
         with self.lock_parameters:
             loss = self.opt.step(*args, **kwargs)
+
         self.local_step += 1
         if self.local_step % self.averaging_step_period == 0:
             self.update_event.set()

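The surrounding logic is a common lock-plus-event handshake between a training thread and a background averager. A minimal sketch with standard threading primitives (not hivemind's actual classes):

import threading

lock_parameters = threading.Lock()
update_event = threading.Event()
averaging_step_period = 10
local_step = 0

def step(opt_step):
    global local_step
    # Hold the lock so a background averager cannot modify parameters mid-update
    with lock_parameters:
        loss = opt_step()

    local_step += 1
    # Every averaging_step_period steps, wake the averaging thread
    if local_step % averaging_step_period == 0:
        update_event.set()
    return loss

for _ in range(averaging_step_period):
    step(lambda: 0.0)
assert update_event.is_set()
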
+ 2 - 2
hivemind/p2p/p2p_daemon.py

@@ -297,11 +297,11 @@ class P2P:
         name: str,
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
         input_protobuf_type: type,
-        max_prefetch: int = 0,
+        max_prefetch: int = 5,
     ) -> None:
         """
         :param max_prefetch: Maximum number of items to prefetch from the request stream.
-          ``max_prefetch <= 0`` means unlimited (default).
+          ``max_prefetch <= 0`` means unlimited.
 
         :note:  Since the cancel messages are sent via the input stream,
           they will not be received while the prefetch buffer is full.
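A sketch of why a bounded default matters, using asyncio.Queue as a stand-in for the daemon's prefetch buffer (an analogy, not the actual implementation; note that asyncio.Queue likewise treats maxsize <= 0 as unlimited):

import asyncio

async def producer(queue: asyncio.Queue) -> None:
    for i in range(10):
        # With maxsize=5, this blocks once five unconsumed items are buffered,
        # so anything behind them (e.g. a cancel message) waits for free space
        await queue.put(i)

async def consumer(queue: asyncio.Queue) -> None:
    while True:
        await queue.get()
        await asyncio.sleep(0.01)  # simulate slow request handling
        queue.task_done()

async def main(max_prefetch: int = 5) -> None:
    queue = asyncio.Queue(maxsize=max_prefetch)
    consumer_task = asyncio.create_task(consumer(queue))
    await producer(queue)
    await queue.join()
    consumer_task.cancel()

asyncio.run(main())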