@@ -48,8 +48,8 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     """

     def __init__(self, opt: torch.optim.Optimizer, dht: DHT, *, prefix: str, target_group_size: int,
-                 average_parameters: bool, average_gradients: bool, max_allowed_epoch_difference: int,
-                 total_steps_in_epoch: int, average_opt_statistics: Sequence[str] = (),
+                 average_parameters: bool, average_gradients: bool, max_allowed_epoch_difference: int = 1,
+                 total_steps_in_epoch: int = 1000, average_opt_statistics: Sequence[str] = (),
                  scheduler_cls = None, averaging_steps_period: int = 1, averaging_time_period: float = 0,
                  report_progress_expiration: int = 30, timeout: Optional[float] = None,
                  verbose: bool = False, **kwargs):
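The two previously required arguments now have defaults (`max_allowed_epoch_difference=1`, `total_steps_in_epoch=1000`), so existing call sites keep working and simple setups get shorter. A minimal usage sketch, assuming a running hivemind `DHT` and that `DecentralizedOptimizer` from this module is in scope; the model, prefix, and group size below are illustrative, not part of this diff:

```python
import torch
from hivemind import DHT

# Hypothetical call site: only the arguments without defaults are required now.
model = torch.nn.Linear(10, 2)
dht = DHT(start=True)
opt = DecentralizedOptimizer(
    torch.optim.SGD(model.parameters(), lr=0.1), dht,
    prefix="my_experiment", target_group_size=16,
    average_parameters=True, average_gradients=False,
    # max_allowed_epoch_difference=1 and total_steps_in_epoch=1000 now apply implicitly
)
```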
@@ -63,11 +63,9 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                                                average_opt_statistics=average_opt_statistics,
                                                dht=dht, start=True, prefix=prefix,
                                                target_group_size=target_group_size, **kwargs)
-        self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
-
-        if scheduler_cls:
-            self.scheduler = scheduler_cls(opt)

+        self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
+        self.scheduler = None if scheduler_cls is None else scheduler_cls(opt)
         self.local_epoch = 0
         self.report_progress_expiration = report_progress_expiration
         self.max_allowed_epoch_difference = max_allowed_epoch_difference
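Besides compacting the assignments, this hunk fixes a latent bug: when `scheduler_cls` was `None` (the default), `self.scheduler` was never assigned, so any later access raised `AttributeError`. The conditional expression guarantees the attribute always exists. A standalone sketch of the difference, using toy classes that are illustrative only:

```python
class Before:
    def __init__(self, scheduler_cls=None):
        if scheduler_cls:
            self.scheduler = scheduler_cls()  # skipped entirely when scheduler_cls is None

class After:
    def __init__(self, scheduler_cls=None):
        self.scheduler = None if scheduler_cls is None else scheduler_cls()  # always assigned

print(hasattr(Before(), "scheduler"))  # False: later reads raise AttributeError
print(After().scheduler)               # None: callers can safely test `if self.scheduler:`
```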
@@ -76,8 +74,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.report_progress_event, self.fetch_decentralized_state_event = Event(), Event()
         self.lock_scheduler_params = Lock()
         self.decentralized_state = DecentralizedState(max_epoch=0, total_steps=0)
-        self.fetch_decentralized_state_event.set()
-        self._fetch_decentralized_state(initial=True)
+        self._fetch_decentralized_state()

         self.background_averaging_thread = Thread(
             name=f'{self.__class__.__name__}', daemon=True, target=self._average_parameters_in_background,
@@ -87,7 +84,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.background_report_progress = Thread(name=f'{self.__class__.__name__}.reporter', daemon=True, target=self._report_progress)
         self.background_report_progress.start()
         self.background_fetch_decentralized_state = Thread(
-            name=f'{self.__class__.__name__}.state_updater', daemon=True, target=self._fetch_decentralized_state)
+            name=f'{self.__class__.__name__}.state_updater', daemon=True, target=self._fetch_decentralized_state_periodically)
         self.background_fetch_decentralized_state.start()

     def step(self, *args, **kwargs):
@@ -188,39 +185,36 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                        expiration_time=current_time + self.report_progress_expiration, return_future=False)

     @torch.no_grad()
-    def _fetch_decentralized_state(self, initial: bool = False):
-        """ Read collaboration state reported by peers """
+    def _fetch_decentralized_state_periodically(self):
+        """ Periodically read decentralized state reported by peers """
         while not self.stop_event.is_set():
             self.fetch_decentralized_state_event.wait()
             self.fetch_decentralized_state_event.clear()
             if self.stop_event.is_set():
                 break
-            response, _expiration = self.dht.get(self.report_progress_key, latest=True) or (None, -float('inf'))
-            if not isinstance(response, dict) or len(response) == 0:
-                logger.info(f"Found no active peers: {response}")
-                with self.lock_scheduler_params:
-                    self.decentralized_state = DecentralizedState(max_epoch=self.local_epoch, total_steps=self.local_step)
-                if initial:
-                    break
-                continue
-
-            valid_peer_states = [peer_state.value for peer_state in response.values() if isinstance(peer_state, ValueWithExpiration)]
-            num_peers = len(valid_peer_states)
+            self._fetch_decentralized_state()

+    @torch.no_grad()
+    def _fetch_decentralized_state(self):
+        """ Read decentralized state reported by peers """
+        response, _expiration = self.dht.get(self.report_progress_key, latest=True) or (None, -float('inf'))
+        if not isinstance(response, dict) or len(response) == 0:
+            logger.info(f"Found no active peers: {response}")
             with self.lock_scheduler_params:
-                global_epoch = self.local_epoch
-                for step, time, epoch in valid_peer_states:
-                    global_epoch = max(global_epoch, epoch)
-
-                total_steps = 0
-                for step, time, epoch in valid_peer_states:
-                    if epoch == global_epoch:
-                        total_steps += step
+                self.decentralized_state = DecentralizedState(max_epoch=self.local_epoch, total_steps=self.local_step)
+            return

-                self.decentralized_state = DecentralizedState(max_epoch=global_epoch, total_steps=total_steps)
-
-                if initial:
-                    break
+        valid_peer_states = [peer_state.value for peer_state in response.values() if isinstance(peer_state, ValueWithExpiration)]
+        num_peers = len(valid_peer_states)
+        with self.lock_scheduler_params:
+            global_epoch = self.local_epoch
+            for step, time, epoch in valid_peer_states:
+                global_epoch = max(global_epoch, epoch)
+            total_steps = 0
+            for step, time, epoch in valid_peer_states:
+                if epoch == global_epoch:
+                    total_steps += step
+            self.decentralized_state = DecentralizedState(max_epoch=global_epoch, total_steps=total_steps)


 class DecentralizedSGD(DecentralizedOptimizer):
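This refactor splits the old dual-purpose method into an event-driven loop (`_fetch_decentralized_state_periodically`, run by the background thread) and a single-shot fetch (`_fetch_decentralized_state`, also called synchronously from `__init__`), which eliminates the `initial` flag and its early `break`s. A generic, self-contained sketch of the loop/single-shot split; the class and names below are illustrative, not the PR's code:

```python
from threading import Event, Thread

class StateFetcher:
    """Toy version of the pattern: one method fetches, another loops."""

    def __init__(self):
        self.stop_event, self.update_event = Event(), Event()
        self.state = None
        self.fetch_once()  # synchronous initial fetch, like the new __init__
        self.worker = Thread(target=self.fetch_periodically, daemon=True)
        self.worker.start()

    def fetch_periodically(self):
        # Mirrors _fetch_decentralized_state_periodically: sleep until asked, then fetch once.
        while not self.stop_event.is_set():
            self.update_event.wait()
            self.update_event.clear()
            if self.stop_event.is_set():
                break
            self.fetch_once()

    def fetch_once(self):
        # Mirrors _fetch_decentralized_state: no loop control mixed into the read logic.
        self.state = "fetched"

fetcher = StateFetcher()
fetcher.update_event.set()  # request a refresh; set stop_event and update_event to shut down
```

Separating the two also means a refresh can be forced at any time via `fetch_decentralized_state_event.set()` without duplicating the DHT-read logic.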