@@ -81,8 +81,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
       refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
     :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
     :param step_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
-    :param staleness_timeout: peers that reported gradients this many seconds ago or earlier do not count
-      toward progress for the current step (but do count toward other statistics, such as the collaboraiton size)
     :param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
     :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
     :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
@@ -118,7 +116,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         metadata_expiration: float = 60.0,
         averaging_timeout: Optional[float] = None,
         load_state_timeout: float = 600.0,
-        staleness_timeout: float = 30.0,
         step_tolerance: int = 1,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
@@ -142,7 +139,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             default_refresh_period,
         )
         self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
-        self.staleness_timeout = staleness_timeout
         self.averaging_timeout = averaging_timeout
         self.load_state_timeout = load_state_timeout
         self.metadata_expiration = metadata_expiration
@@ -450,9 +446,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0

         for state in valid_peer_states:
-            if current_time - state.time > self.staleness_timeout:
-                logger.debug(f"Ignoring record {state} because it is too old: {current_time - state.time} seconds.")
-                continue
             total_samples_per_second += state.samples_per_second
             if state.step >= global_optimizer_step - self.step_tolerance:
                 total_samples_accumulated += state.samples_accumulated