|
@@ -80,7 +80,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
|
|
:note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
|
|
:note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
|
|
refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
|
|
refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
|
|
:param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
|
|
:param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
|
|
- :param step_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
|
|
|
|
:param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
|
|
:param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
|
|
:param averaging_expiration: peer's requests for averaging will be valid for this many seconds
|
|
:param averaging_expiration: peer's requests for averaging will be valid for this many seconds
|
|
:param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
|
|
:param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
|
|
@@ -116,7 +115,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
|
|
metadata_expiration: float = 60.0,
|
|
metadata_expiration: float = 60.0,
|
|
averaging_timeout: Optional[float] = None,
|
|
averaging_timeout: Optional[float] = None,
|
|
load_state_timeout: float = 600.0,
|
|
load_state_timeout: float = 600.0,
|
|
- step_tolerance: int = 1,
|
|
|
|
reuse_grad_buffers: bool = False,
|
|
reuse_grad_buffers: bool = False,
|
|
accumulate_grads_on: Optional[torch.device] = None,
|
|
accumulate_grads_on: Optional[torch.device] = None,
|
|
client_mode: bool = False,
|
|
client_mode: bool = False,
|
|
@@ -143,7 +141,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
|
|
self.load_state_timeout = load_state_timeout
|
|
self.load_state_timeout = load_state_timeout
|
|
self.metadata_expiration = metadata_expiration
|
|
self.metadata_expiration = metadata_expiration
|
|
self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
|
|
self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
|
|
- self.client_mode, self.step_tolerance = client_mode, step_tolerance
|
|
|
|
|
|
+ self.client_mode = client_mode
|
|
self.status_loglevel = logging.INFO if verbose else logging.DEBUG
|
|
self.status_loglevel = logging.INFO if verbose else logging.DEBUG
|
|
self.averager = self._make_averager(**kwargs)
|
|
self.averager = self._make_averager(**kwargs)
|
|
|
|
|
|
@@ -181,7 +179,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
|
|
|
|
|
|
@property
|
|
@property
|
|
def is_synchronized(self) -> bool:
|
|
def is_synchronized(self) -> bool:
|
|
- return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
|
|
|
|
|
|
+ return self.local_step >= self.collaboration_state.optimizer_step
|
|
|
|
|
|
def is_alive(self) -> bool:
|
|
def is_alive(self) -> bool:
|
|
return self.averager.is_alive()
|
|
return self.averager.is_alive()
|
|
@@ -448,7 +446,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
|
|
|
|
|
|
for state in valid_peer_states:
|
|
for state in valid_peer_states:
|
|
total_samples_per_second += state.samples_per_second
|
|
total_samples_per_second += state.samples_per_second
|
|
- if state.step == self.local_step:
|
|
|
|
|
|
+ if state.step >= global_optimizer_step:
|
|
total_samples_accumulated += state.samples_accumulated
|
|
total_samples_accumulated += state.samples_accumulated
|
|
estimated_current_samples += (
|
|
estimated_current_samples += (
|
|
state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
|
|
state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
|