
remove the concept of step_tolerance

justheuristic 3 years ago
parent
commit
fb9b0d7167
2 changed files with 4 additions and 6 deletions
  1. examples/albert/arguments.py (+1 -1)
  2. hivemind/optim/collaborative.py (+3 -5)

+ 1 - 1
examples/albert/arguments.py

@@ -49,7 +49,7 @@ class AveragerArguments:
         default=5.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
     )
     averaging_timeout: float = field(
-        default=30.0, metadata={"help": "Give up on averaging step after this many seconds"}
+        default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
     )
     min_refresh_period: float = field(
         default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}

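The bumped averaging_timeout default only changes what happens when the flag is not set explicitly. As a point of reference, here is a minimal sketch of how the field is exposed on the command line, assuming transformers.HfArgumentParser-style parsing as the albert example scripts use; the snippet is illustrative and not part of this commit.

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class AveragerArguments:
    # Same field as in the diff above: the default is now one minute
    averaging_timeout: float = field(
        default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
    )


# Running without flags uses the new 60-second default;
# `--averaging_timeout 30` restores the previous behaviour for a single run.
parser = HfArgumentParser(AveragerArguments)
(averager_args,) = parser.parse_args_into_dataclasses()
print(averager_args.averaging_timeout)
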
+ 3 - 5
hivemind/optim/collaborative.py

@@ -80,7 +80,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
       refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
     :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
-    :param step_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
     :param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
     :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
     :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
@@ -116,7 +115,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         metadata_expiration: float = 60.0,
         averaging_timeout: Optional[float] = None,
         load_state_timeout: float = 600.0,
-        step_tolerance: int = 1,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
         client_mode: bool = False,
@@ -143,7 +141,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.load_state_timeout = load_state_timeout
         self.metadata_expiration = metadata_expiration
         self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
-        self.client_mode, self.step_tolerance = client_mode, step_tolerance
+        self.client_mode = client_mode
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.averager = self._make_averager(**kwargs)
 
@@ -181,7 +179,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     @property
     def is_synchronized(self) -> bool:
-        return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
+        return self.local_step >= self.collaboration_state.optimizer_step
 
     def is_alive(self) -> bool:
         return self.averager.is_alive()
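
With step_tolerance removed, a peer now counts as synchronized only when its local step has caught up to the collaboration-wide optimizer step; before this commit it could trail by up to step_tolerance steps and still be treated as in sync. The standalone sketch below contrasts the two checks; the class and field names are simplified stand-ins, not the actual hivemind classes.

from dataclasses import dataclass


@dataclass
class CollaborationState:
    optimizer_step: int


class PeerSketch:
    def __init__(self, local_step: int, collaboration_state: CollaborationState):
        self.local_step = local_step
        self.collaboration_state = collaboration_state

    @property
    def is_synchronized(self) -> bool:
        # After this commit: any lag behind the global optimizer step means out of sync
        return self.local_step >= self.collaboration_state.optimizer_step

    def is_synchronized_with_tolerance(self, step_tolerance: int = 1) -> bool:
        # Before this commit: a peer could trail by up to `step_tolerance` steps
        return self.local_step >= self.collaboration_state.optimizer_step - step_tolerance


peer = PeerSketch(local_step=41, collaboration_state=CollaborationState(optimizer_step=42))
assert not peer.is_synchronized                # now considered out of sync
assert peer.is_synchronized_with_tolerance(1)  # previously still counted as in sync
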
@@ -448,7 +446,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         for state in valid_peer_states:
             total_samples_per_second += state.samples_per_second
-            if state.step == self.local_step:
+            if state.step >= global_optimizer_step:
                 total_samples_accumulated += state.samples_accumulated
                 estimated_current_samples += (
                     state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
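
The last hunk changes which peers contribute to the accumulated-sample estimate: rather than counting only peers whose reported step equals this peer's local step, the loop now sums progress over every peer at or ahead of the global optimizer step. A runnable sketch of that rule follows; PeerState and estimate_progress are hypothetical stand-ins carrying only the fields the loop touches.

import time
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class PeerState:
    step: int
    samples_accumulated: int
    samples_per_second: float
    time: float


def estimate_progress(
    valid_peer_states: List[PeerState], global_optimizer_step: int, current_time: float
) -> Tuple[float, int, float]:
    total_samples_per_second = 0.0
    total_samples_accumulated = 0
    estimated_current_samples = 0.0

    for state in valid_peer_states:
        total_samples_per_second += state.samples_per_second
        if state.step >= global_optimizer_step:  # the new condition from this commit
            total_samples_accumulated += state.samples_accumulated
            estimated_current_samples += (
                state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
            )

    return total_samples_per_second, total_samples_accumulated, estimated_current_samples


now = time.time()
peers = [
    PeerState(step=42, samples_accumulated=128, samples_per_second=64.0, time=now - 1.0),
    PeerState(step=41, samples_accumulated=256, samples_per_second=32.0, time=now - 2.0),  # lags behind: excluded
]
print(estimate_progress(peers, global_optimizer_step=42, current_time=now))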