
ignore_stale_updates

justheuristic, 3 years ago
commit 4dddc75d16
1 file changed, 7 insertions(+), 2 deletions(-)

+ 7 - 2
hivemind/optim/collaborative.py

@@ -81,6 +81,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
       refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
     :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
     :param step_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
+    :param staleness_timeout: peers that reported gradients this many seconds ago or earlier do not count
+      toward progress for the current step (but do count toward other statistics, such as the collaboration size)
     :param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
     :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
     :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
@@ -116,6 +118,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         metadata_expiration: float = 60.0,
         averaging_timeout: Optional[float] = None,
         load_state_timeout: float = 600.0,
+        staleness_timeout: float = 30.0,
         step_tolerance: int = 1,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
@@ -139,6 +142,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             default_refresh_period,
         )
         self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
+        self.staleness_timeout = staleness_timeout
         self.averaging_timeout = averaging_timeout
         self.load_state_timeout = load_state_timeout
         self.metadata_expiration = metadata_expiration
@@ -447,10 +451,11 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):

         for state in valid_peer_states:
             total_samples_per_second += state.samples_per_second
-            if state.step >= global_optimizer_step - self.step_tolerance:
+            delta_time = current_time - state.time
+            if state.step >= global_optimizer_step - self.step_tolerance and delta_time < self.staleness_timeout:
                 total_samples_accumulated += state.samples_accumulated
                 estimated_current_samples += (
-                    state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
+                    state.samples_accumulated + max(0, delta_time) * state.samples_per_second
                 )
             # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
             # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.
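
Below is a minimal, self-contained sketch (not hivemind code) of the staleness rule this diff adds: a peer's samples count toward the current step only if its last report is both within step_tolerance steps of the global step and fresher than staleness_timeout seconds, while throughput is still counted for every peer. The PeerState dataclass and count_progress function are hypothetical illustrations, not the library's API.

import time
from dataclasses import dataclass

@dataclass
class PeerState:
    step: int                  # the peer's last reported optimizer step
    samples_accumulated: int   # samples accumulated since that step
    samples_per_second: float  # the peer's reported throughput
    time: float                # timestamp of the peer's last report

def count_progress(peer_states, global_step, step_tolerance=1, staleness_timeout=30.0):
    """Tally progress, ignoring peers whose last report is older than staleness_timeout."""
    current_time = time.time()
    total_samples_per_second = 0.0
    total_samples_accumulated = 0
    estimated_current_samples = 0.0
    for state in peer_states:
        # performance is counted for every peer, as in the diff above
        total_samples_per_second += state.samples_per_second
        delta_time = current_time - state.time
        # progress is counted only for in-sync, non-stale peers
        if state.step >= global_step - step_tolerance and delta_time < staleness_timeout:
            total_samples_accumulated += state.samples_accumulated
            estimated_current_samples += (
                state.samples_accumulated + max(0.0, delta_time) * state.samples_per_second
            )
    return total_samples_per_second, total_samples_accumulated, estimated_current_samples

if __name__ == "__main__":
    now = time.time()
    peers = [
        PeerState(step=100, samples_accumulated=64, samples_per_second=32.0, time=now - 5),    # fresh: counted
        PeerState(step=100, samples_accumulated=128, samples_per_second=32.0, time=now - 120), # stale: ignored
        PeerState(step=97, samples_accumulated=32, samples_per_second=16.0, time=now - 5),     # out of sync: ignored
    ]
    print(count_progress(peers, global_step=100))

With these example peers, only the first peer's 64 samples count toward the current step, while all three peers contribute to the throughput estimate, mirroring the comment in the diff that outdated peers are expected to synchronize and begin contributing shortly.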