Browse source code

tweak default parameters based on 128-peer test

justheuristic 3 years ago
parent
commit
0edeea35e0

+ 7 - 4
hivemind/optim/experimental/optimizer.py

@@ -121,6 +121,7 @@ class Optimizer(torch.optim.Optimizer):
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled automatically.
       Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
       Do not set this timeout too high, as it may cause your optimizer to hang after some types of network errors.
+    :param allreduce_timeout: timeout for a single attempt to run all-reduce, default: equal to averaging_timeout.
     :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers.
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
       This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
@@ -173,6 +174,7 @@ class Optimizer(torch.optim.Optimizer):
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         matchmaking_time: Optional[float] = 15.0,
         averaging_timeout: Optional[float] = 60.0,
+        allreduce_timeout: Optional[float] = None,
         load_state_timeout: float = 600.0,
         reuse_grad_buffers: bool = False,
         offload_optimizer: Optional[bool] = None,
@@ -197,6 +199,7 @@ class Optimizer(torch.optim.Optimizer):
         client_mode = client_mode if client_mode is None else dht.client_mode
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
         offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
+        allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
@@ -225,8 +228,8 @@ class Optimizer(torch.optim.Optimizer):
         self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
 
-        self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
-        self.shutdown_timeout = shutdown_timeout
+        self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
+        self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout
 
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.scheduled_grads: Optional[StepControl] = None
@@ -271,7 +274,7 @@ class Optimizer(torch.optim.Optimizer):
             dht=self.dht,
             prefix=f"{self.run_id}_state_averager",
             min_matchmaking_time=self.matchmaking_time,
-            allreduce_timeout=self.averaging_timeout,
+            allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             offload_optimizer=self.offload_optimizer,
             custom_gradients=self.offload_optimizer,
@@ -289,7 +292,7 @@ class Optimizer(torch.optim.Optimizer):
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
             min_matchmaking_time=self.matchmaking_time,
-            allreduce_timeout=self.averaging_timeout,
+            allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             client_mode=self.client_mode,
             auxiliary=self.auxiliary,

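Below is a minimal sketch of how the new allreduce_timeout argument could be passed to the experimental Optimizer. Only matchmaking_time, averaging_timeout, and allreduce_timeout come from this diff; the DHT setup, run_id, batch sizes, optimizer factory, and the 45-second value are illustrative assumptions, not part of the commit.

    import torch
    import hivemind
    from hivemind.optim.experimental.optimizer import Optimizer

    # Illustrative model and DHT setup; sizes and names are assumptions.
    model = torch.nn.Linear(64, 10)
    dht = hivemind.DHT(start=True)

    opt = Optimizer(
        dht=dht,
        run_id="demo_run",                   # hypothetical experiment id
        params=model.parameters(),
        optimizer=lambda params: torch.optim.Adam(params),
        target_batch_size=4096,              # illustrative value
        batch_size_per_step=32,              # illustrative value
        matchmaking_time=15.0,
        averaging_timeout=60.0,              # cancel a hung averaging step after this long
        allreduce_timeout=45.0,              # new: per-attempt all-reduce limit
    )

If allreduce_timeout is left as None, it falls back to averaging_timeout, matching the new default line added to the constructor above; the value is then forwarded to both the state averager and the gradient averager.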
+ 3 - 3
hivemind/optim/experimental/progress_tracker.py

@@ -83,12 +83,12 @@ class ProgressTracker(threading.Thread):
         *,
         client_mode: Optional[bool] = None,
         min_refresh_period: float = 0.5,
-        max_refresh_period: float = 10,
+        max_refresh_period: float = 30,
         default_refresh_period: float = 3,
         expected_drift_peers: float = 3,
         expected_drift_rate: float = 0.2,
         performance_ema_alpha: float = 0.1,
-        metadata_expiration: float = 30.0,
+        metadata_expiration: float = 60.0,
         status_loglevel: int = logging.DEBUG,
         private_key: Optional[RSAPrivateKey] = None,
         daemon: bool = True,
@@ -198,7 +198,7 @@ class ProgressTracker(threading.Thread):
         store_task = None
         try:
             while not self.shutdown_triggered.is_set():
-                wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
+                wait_timeout = max(0.0, last_report_time - get_dht_time() + self.metadata_expiration / 2)
                 logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
                 await asyncio.get_event_loop().run_in_executor(None, self.should_report_progress.wait, wait_timeout)
                 if self.should_report_progress.is_set():
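With the new wait_timeout, a peer re-reports progress roughly every metadata_expiration / 2 seconds instead of sleeping until its previous DHT entry is about to expire. A small worked sketch of the arithmetic, assuming hivemind's get_dht_time helper and the new 60-second default; timestamps are illustrative:

    from hivemind.utils import get_dht_time

    metadata_expiration = 60.0           # new default from this commit
    last_report_time = get_dht_time()    # pretend we just reported progress

    # Old behaviour: sleep until the previous report is about to expire (~60 s later).
    old_wait = max(0.0, last_report_time + metadata_expiration - get_dht_time())

    # New behaviour: wake up at half the expiration period (~30 s later),
    # refreshing the DHT entry well before it can expire.
    new_wait = max(0.0, last_report_time - get_dht_time() + metadata_expiration / 2)

    print(f"old wait ~ {old_wait:.0f} s, new wait ~ {new_wait:.0f} s")

Combined with the larger max_refresh_period and metadata_expiration defaults, this keeps progress records fresh while reducing how often idle peers write to the DHT.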