Эх сурвалжийг харах

tweak default parameters based on 128peers test

justheuristic 3 жил өмнө
parent
commit
0edeea35e0

+ 7 - 4
hivemind/optim/experimental/optimizer.py

@@ -121,6 +121,7 @@ class Optimizer(torch.optim.Optimizer):
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled automatically.
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled automatically.
       Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
       Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
       Do not set this timeout too high, as it may cause your optimizer to hang after some types of network errors.
       Do not set this timeout too high, as it may cause your optimizer to hang after some types of network errors.
+    :param allreduce_timeout: timeout for a single attempt to run all-reduce, default: equal to averaging_timeout.
     :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers.
     :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers.
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
       This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
       This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
@@ -173,6 +174,7 @@ class Optimizer(torch.optim.Optimizer):
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         matchmaking_time: Optional[float] = 15.0,
         matchmaking_time: Optional[float] = 15.0,
         averaging_timeout: Optional[float] = 60.0,
         averaging_timeout: Optional[float] = 60.0,
+        allreduce_timeout: Optional[float] = None,
         load_state_timeout: float = 600.0,
         load_state_timeout: float = 600.0,
         reuse_grad_buffers: bool = False,
         reuse_grad_buffers: bool = False,
         offload_optimizer: Optional[bool] = None,
         offload_optimizer: Optional[bool] = None,
@@ -197,6 +199,7 @@ class Optimizer(torch.optim.Optimizer):
         client_mode = client_mode if client_mode is None else dht.client_mode
         client_mode = client_mode if client_mode is None else dht.client_mode
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
         offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
         offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
+        allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
@@ -225,8 +228,8 @@ class Optimizer(torch.optim.Optimizer):
         self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
 
 
-        self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
-        self.shutdown_timeout = shutdown_timeout
+        self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
+        self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout
 
 
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.scheduled_grads: Optional[StepControl] = None
         self.scheduled_grads: Optional[StepControl] = None
@@ -271,7 +274,7 @@ class Optimizer(torch.optim.Optimizer):
             dht=self.dht,
             dht=self.dht,
             prefix=f"{self.run_id}_state_averager",
             prefix=f"{self.run_id}_state_averager",
             min_matchmaking_time=self.matchmaking_time,
             min_matchmaking_time=self.matchmaking_time,
-            allreduce_timeout=self.averaging_timeout,
+            allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             shutdown_timeout=self.shutdown_timeout,
             offload_optimizer=self.offload_optimizer,
             offload_optimizer=self.offload_optimizer,
             custom_gradients=self.offload_optimizer,
             custom_gradients=self.offload_optimizer,
@@ -289,7 +292,7 @@ class Optimizer(torch.optim.Optimizer):
             prefix=f"{self.run_id}_grad_averager",
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
             parameters=self.state_averager.main_parameters,
             min_matchmaking_time=self.matchmaking_time,
             min_matchmaking_time=self.matchmaking_time,
-            allreduce_timeout=self.averaging_timeout,
+            allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             shutdown_timeout=self.shutdown_timeout,
             client_mode=self.client_mode,
             client_mode=self.client_mode,
             auxiliary=self.auxiliary,
             auxiliary=self.auxiliary,

+ 3 - 3
hivemind/optim/experimental/progress_tracker.py

@@ -83,12 +83,12 @@ class ProgressTracker(threading.Thread):
         *,
         *,
         client_mode: Optional[bool] = None,
         client_mode: Optional[bool] = None,
         min_refresh_period: float = 0.5,
         min_refresh_period: float = 0.5,
-        max_refresh_period: float = 10,
+        max_refresh_period: float = 30,
         default_refresh_period: float = 3,
         default_refresh_period: float = 3,
         expected_drift_peers: float = 3,
         expected_drift_peers: float = 3,
         expected_drift_rate: float = 0.2,
         expected_drift_rate: float = 0.2,
         performance_ema_alpha: float = 0.1,
         performance_ema_alpha: float = 0.1,
-        metadata_expiration: float = 30.0,
+        metadata_expiration: float = 60.0,
         status_loglevel: int = logging.DEBUG,
         status_loglevel: int = logging.DEBUG,
         private_key: Optional[RSAPrivateKey] = None,
         private_key: Optional[RSAPrivateKey] = None,
         daemon: bool = True,
         daemon: bool = True,
@@ -198,7 +198,7 @@ class ProgressTracker(threading.Thread):
         store_task = None
         store_task = None
         try:
         try:
             while not self.shutdown_triggered.is_set():
             while not self.shutdown_triggered.is_set():
-                wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
+                wait_timeout = max(0.0, last_report_time - get_dht_time() + self.metadata_expiration / 2)
                 logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
                 logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
                 await asyncio.get_event_loop().run_in_executor(None, self.should_report_progress.wait, wait_timeout)
                 await asyncio.get_event_loop().run_in_executor(None, self.should_report_progress.wait, wait_timeout)
                 if self.should_report_progress.is_set():
                 if self.should_report_progress.is_set():