
fix some issues in pr

xtinkt, 4 years ago
parent commit 937e622b8b
1 file changed, 28 additions and 34 deletions

hivemind/optim/averaged.py  (+28, -34)

@@ -48,8 +48,8 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     """
 
     def __init__(self, opt: torch.optim.Optimizer, dht: DHT, *, prefix: str, target_group_size: int,
-                 average_parameters: bool, average_gradients: bool, max_allowed_epoch_difference: int,
-                 total_steps_in_epoch: int, average_opt_statistics: Sequence[str] = (),
+                 average_parameters: bool, average_gradients: bool, max_allowed_epoch_difference: int = 1,
+                 total_steps_in_epoch: int = 1000, average_opt_statistics: Sequence[str] = (),
                  scheduler_cls = None, averaging_steps_period: int = 1, averaging_time_period: float = 0,
                  report_progress_expiration: int = 30, timeout: Optional[float] = None,
                  verbose: bool = False, **kwargs):
@@ -63,11 +63,9 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                                          average_opt_statistics=average_opt_statistics,
                                          dht=dht, start=True, prefix=prefix,
                                          target_group_size=target_group_size, **kwargs)
-        self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
-
-        if scheduler_cls:
-            self.scheduler = scheduler_cls(opt)
 
+        self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
+        self.scheduler = None if scheduler_cls is None else scheduler_cls(opt)
         self.local_epoch = 0
         self.report_progress_expiration = report_progress_expiration
         self.max_allowed_epoch_difference = max_allowed_epoch_difference
@@ -76,8 +74,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.report_progress_event, self.fetch_decentralized_state_event = Event(), Event()
         self.lock_scheduler_params = Lock()
         self.decentralized_state = DecentralizedState(max_epoch=0, total_steps=0)
-        self.fetch_decentralized_state_event.set()
-        self._fetch_decentralized_state(initial=True)
+        self._fetch_decentralized_state()
 
         self.background_averaging_thread = Thread(
             name=f'{self.__class__.__name__}', daemon=True, target=self._average_parameters_in_background,
@@ -87,7 +84,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.background_report_progress = Thread(name=f'{self.__class__.__name__}.reporter', daemon=True, target=self._report_progress)
         self.background_report_progress.start()
         self.background_fetch_decentralized_state = Thread(
-            name=f'{self.__class__.__name__}.state_updater', daemon=True, target=self._fetch_decentralized_state)
+            name=f'{self.__class__.__name__}.state_updater', daemon=True, target=self._fetch_decentralized_state_periodically)
         self.background_fetch_decentralized_state.start()
 
     def step(self, *args, **kwargs):
@@ -188,39 +185,36 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                            expiration_time=current_time + self.report_progress_expiration, return_future=False)
 
     @torch.no_grad()
-    def _fetch_decentralized_state(self, initial: bool = False):
-        """ Read collaboration state reported by peers """
+    def _fetch_decentralized_state_periodically(self):
+        """ Read decentralized state loop """
         while not self.stop_event.is_set():
             self.fetch_decentralized_state_event.wait()
             self.fetch_decentralized_state_event.clear()
             if self.stop_event.is_set():
                 break
-            response, _expiration = self.dht.get(self.report_progress_key, latest=True) or (None, -float('inf'))
-            if not isinstance(response, dict) or len(response) == 0:
-                logger.info(f"Found no active peers: {response}")
-                with self.lock_scheduler_params:
-                    self.decentralized_state = DecentralizedState(max_epoch=self.local_epoch, total_steps=self.local_step)
-                    if initial:
-                        break
-                    continue
-
-            valid_peer_states = [peer_state.value for peer_state in response.values() if isinstance(peer_state, ValueWithExpiration)]
-            num_peers = len(valid_peer_states)
+            self._fetch_decentralized_state()
 
+    @torch.no_grad()
+    def _fetch_decentralized_state(self):
+        """ Read decentralized state reported by peers """
+        response, _expiration = self.dht.get(self.report_progress_key, latest=True) or (None, -float('inf'))
+        if not isinstance(response, dict) or len(response) == 0:
+            logger.info(f"Found no active peers: {response}")
             with self.lock_scheduler_params:
-                global_epoch = self.local_epoch
-                for step, time, epoch in valid_peer_states:
-                    global_epoch = max(global_epoch, epoch)
-
-                total_steps = 0
-                for step, time, epoch in valid_peer_states:
-                    if epoch == global_epoch:
-                        total_steps += step
+                self.decentralized_state = DecentralizedState(max_epoch=self.local_epoch, total_steps=self.local_step)
+                return
 
-                self.decentralized_state = DecentralizedState(max_epoch=global_epoch, total_steps=total_steps)
-
-                if initial:
-                    break
+        valid_peer_states = [peer_state.value for peer_state in response.values() if isinstance(peer_state, ValueWithExpiration)]
+        num_peers = len(valid_peer_states)
+        with self.lock_scheduler_params:
+            global_epoch = self.local_epoch
+            for step, time, epoch in valid_peer_states:
+                global_epoch = max(global_epoch, epoch)
+            total_steps = 0
+            for step, time, epoch in valid_peer_states:
+                if epoch == global_epoch:
+                    total_steps += step
+            self.decentralized_state = DecentralizedState(max_epoch=global_epoch, total_steps=total_steps)
 
 
 class DecentralizedSGD(DecentralizedOptimizer):
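
For reference, a minimal sketch of how the refactored constructor could be driven after this change. It relies only on the signature visible in the first hunk above; the DHT setup, toy model, run prefix, and the StepLR-based scheduler_cls are illustrative assumptions, not part of the commit.

    # Illustrative sketch only: assumes a hivemind build that contains
    # hivemind/optim/averaged.py as modified by this commit.
    from functools import partial

    import torch
    import hivemind
    from hivemind.optim.averaged import DecentralizedOptimizer

    dht = hivemind.DHT(start=True)                    # standalone DHT node, just for the example
    model = torch.nn.Linear(16, 4)                    # toy model standing in for a real one
    base_opt = torch.optim.SGD(model.parameters(), lr=0.1)

    opt = DecentralizedOptimizer(
        base_opt, dht,
        prefix='demo_run', target_group_size=2,       # 'demo_run' is a hypothetical experiment prefix
        average_parameters=True, average_gradients=False,
        # max_allowed_epoch_difference and total_steps_in_epoch now default to 1 and 1000,
        # so they no longer need to be passed explicitly
        scheduler_cls=partial(torch.optim.lr_scheduler.StepLR, step_size=100),  # called as scheduler_cls(opt) in __init__
    )

    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    opt.step()   # averaging, progress reporting and state fetching run in the background threads started above

Note that with this commit the constructor performs one synchronous call to _fetch_decentralized_state(), while the background state_updater thread now targets _fetch_decentralized_state_periodically(), so the polling loop and the single fetch are separate code paths.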