소스 검색

Added info about new params. Changed some names.

xtinkt 4 년 전
부모
커밋
97cec29e13
2개의 변경된 파일에서 22개의 추가작업 그리고 19개의 삭제
  1. 1 1
      hivemind/optim/__init__.py
  2. 21 18
      hivemind/optim/decentralized_optimizers.py

+ 1 - 1
hivemind/optim/__init__.py

@@ -1,5 +1,5 @@
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.performance_ema import PerformanceEMA
-from hivemind.optim.simple import DecentralizedOptimizer, DecentralizedSGD, DecentralizedAdam
+from hivemind.optim.decentralized_optimizers import DecentralizedOptimizer, DecentralizedSGD, DecentralizedAdam
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer

+ 21 - 18
hivemind/optim/decentralized_optimizers.py

@@ -12,7 +12,6 @@ from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.utils import get_logger, get_dht_time, ValueWithExpiration
 
 logger = get_logger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
 
 
 @dataclass(frozen=False)
@@ -27,15 +26,21 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
 
     :param opt: a pytorch optimizer configured to update model parameters.
     :param dht: a running hivemind DHT daemon connected to other peers
+    :param prefix: all DHT keys that point to optimization metadata will have this prefix
+    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
     :param average_parameters: whether to average model parameters
     :param average_gradients: whether to average gradients
+    :param max_allowed_epoch_difference: if the difference between max_epoch and local_epoch is more than this
+      value, we download state from another peer.
+    :param total_steps_in_epoch: how many total steps are required to increase local_epoch by one
     :param average_opt_statistics: if specified, average optimizer states with corresponding names in state_dict
+    :param scheduler_cls: a callable that takes the optimizer as its argument and returns a learning rate scheduler
     :param averaging_steps_period: performs averaging after this many optimizer steps
     :param averaging_time_period: if specified, optimizer will attempt to average weights at regular intervals of this
       many seconds. (averaging step will only occur if the optimizer ran `averaging_steps_period` steps in that interval)
-    :param prefix: all DHT keys that point to optimization metadata will have this prefix
-    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
+    :param report_progress_expiration: for how many seconds the locally reported progress stays valid in the DHT
     :param timeout: if DecentralizedAverager step is unable to form group in this many seconds, cancel step
+    :param verbose: if True, log verbose progress information
     :param kwargs: additional parameters passed to TrainingAverager
     :note: if you're using an optimizer with adaptive learning rates (such as Adam), make sure to specify
       necessary fields' names in `average_opt_statistics`. Otherwise you may encounter poor convergence.
@@ -45,8 +50,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     def __init__(self, opt: torch.optim.Optimizer, dht: DHT, *, prefix: str, target_group_size: int,
                  average_parameters: bool, average_gradients: bool, max_allowed_epoch_difference: int,
                  total_steps_in_epoch: int, average_opt_statistics: Sequence[str] = (),
-                 scheduler: Optional[LRSchedulerBase] = None,
-                 averaging_steps_period: int = 1, averaging_time_period: float = 0,
+                 scheduler_cls = None, averaging_steps_period: int = 1, averaging_time_period: float = 0,
                  report_progress_expiration: int = 30, timeout: Optional[float] = None,
                  verbose: bool = False, **kwargs):
         super().__init__(opt, dht)
@@ -61,9 +65,10 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                                          target_group_size=target_group_size, **kwargs)
         self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
 
-        self.scheduler = scheduler
+        if scheduler_cls:
+            self.scheduler = scheduler_cls(opt)
 
-        self.epoch = 0
+        self.local_epoch = 0
         self.report_progress_expiration = report_progress_expiration
         self.max_allowed_epoch_difference = max_allowed_epoch_difference
         self.total_steps_in_epoch = total_steps_in_epoch
@@ -85,8 +90,6 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             name=f'{self.__class__.__name__}.state_updater', daemon=True, target=self._fetch_decentralized_state)
         self.background_fetch_decentralized_state.start()
 
-
-
     def step(self, *args, **kwargs):
         if not self.is_synchronized:
             logger.warning("Peer is out of sync.")
@@ -94,16 +97,16 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             return
 
         with self.lock_scheduler_params:
-            if self.epoch < self.decentralized_state.max_epoch:
+            if self.local_epoch < self.decentralized_state.max_epoch:
                 self.local_step = 0
-                self.epoch = self.decentralized_state.max_epoch
+                self.local_epoch = self.decentralized_state.max_epoch
 
             if self.decentralized_state.total_steps >= self.total_steps_in_epoch:
-                self.epoch += 1
+                self.local_epoch += 1
                 self.local_step = 0
 
             if self.scheduler:
-                while self.epoch > self.scheduler._step_count:
+                while self.local_epoch > self.scheduler._step_count:
                     self.scheduler.step()
 
         with self.lock_parameters:
@@ -135,7 +138,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             self.zero_grad()
             self.averager.load_state_from_peers(**kwargs)
             self.local_step = 0
-            self.epoch = self.decentralized_state.max_epoch
+            self.local_epoch = self.decentralized_state.max_epoch
 
     @staticmethod
     @torch.no_grad()
@@ -169,7 +172,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
 
     @property
     def is_synchronized(self) -> bool:
-        return self.epoch + self.max_allowed_epoch_difference >= self.decentralized_state.max_epoch
+        return self.local_epoch + self.max_allowed_epoch_difference >= self.decentralized_state.max_epoch
 
     @torch.no_grad()
     def _report_progress(self):
@@ -180,7 +183,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                 break
             current_time = get_dht_time()
             with self.lock_scheduler_params:
-                local_state_info = [self.local_step, current_time, self.epoch]
+                local_state_info = [self.local_step, current_time, self.local_epoch]
             self.dht.store(key=self.report_progress_key, subkey=self.averager.endpoint, value=local_state_info,
                            expiration_time=current_time + self.report_progress_expiration, return_future=False)
 
@@ -196,7 +199,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             if not isinstance(response, dict) or len(response) == 0:
                 logger.info(f"Found no active peers: {response}")
                 with self.lock_scheduler_params:
-                    self.decentralized_state = DecentralizedState(max_epoch=self.epoch, total_steps=self.local_step)
+                    self.decentralized_state = DecentralizedState(max_epoch=self.local_epoch, total_steps=self.local_step)
                     if initial:
                         break
                     continue
@@ -205,7 +208,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             num_peers = len(valid_peer_states)
 
             with self.lock_scheduler_params:
-                global_epoch = self.epoch
+                global_epoch = self.local_epoch
                 for step, time, epoch in valid_peer_states:
                     global_epoch = max(global_epoch, epoch)