
Improve hivemind.optim.experimental and averager stability (#421)

Change default timings based on training with 70 peers (north EU + west US)

Co-authored-by: justheuristic <justheuristic@gmail.com>
Alexander Borzunov, 3 years ago
commit a904cfd58b
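
The headline user-facing change is the new allreduce_timeout knob plus relaxed progress-tracker timings. A minimal usage sketch of where that knob surfaces (quickstart-style setup; run_id, batch sizes, and the Adam factory below are illustrative, not taken from this commit):

    import torch
    from hivemind import DHT
    from hivemind.optim.experimental.optimizer import Optimizer

    dht = DHT(start=True)
    model = torch.nn.Linear(64, 10)

    opt = Optimizer(
        dht=dht,
        run_id="demo_run",               # illustrative experiment name, shared by all peers
        params=model.parameters(),
        optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
        target_batch_size=4096,          # illustrative global batch size per epoch
        batch_size_per_step=32,          # illustrative per-peer batch size
        matchmaking_time=15.0,           # defaults shown in this diff
        averaging_timeout=60.0,
        allreduce_timeout=60.0,          # new parameter; falls back to averaging_timeout if omitted
        verbose=True,
    )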

+ 5 - 2
hivemind/averaging/allreduce.py

@@ -275,8 +275,11 @@ class AllReduceRunner(ServicerBase):
             yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
 
     async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
-        error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-        await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+        try:
+            error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
+            await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+        except Exception as e:
+            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}")
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""

+ 4 - 1
hivemind/averaging/averager.py

@@ -31,6 +31,7 @@ from hivemind.compression import (
 )
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
@@ -470,6 +471,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.CancelledError,
                     asyncio.InvalidStateError,
                     P2PHandlerError,
+                    DispatchFailure,
+                    ControlFailure,
                 ) as e:
                     if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
                         if not step.cancelled():
@@ -535,7 +538,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
                             # all-reduce is performed asynchronously while iterating
                             tensor.add_(update, alpha=self._averaging_alpha)
-                            self._state_updated.set()
+                        self._state_updated.set()
 
                     else:
                         async for _ in allreduce:  # trigger all-reduce by iterating
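
Two things happen in this file. DispatchFailure and ControlFailure come from the p2p daemon bindings and indicate a failed request to the libp2p daemon; listing them here treats such failures as transient, so the step is retried until its deadline instead of being aborted. The dedent of self._state_updated.set() makes the event fire once, after every local tensor has received its averaged update, rather than after the first partial update. A generic sketch of the retry-until-deadline shape (simplified; the real loop also checks step.done() and allow_retries):

    import asyncio
    import time

    async def attempt_until_deadline(run_once, deadline, transient_errors, retry_delay=1.0):
        # Retry a single averaging attempt on errors considered transient,
        # giving up only once the step's deadline has passed.
        while True:
            try:
                return await run_once()
            except transient_errors:
                if time.monotonic() >= deadline:
                    raise  # out of time: surface the last error to the caller
                await asyncio.sleep(retry_delay)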

+ 8 - 2
hivemind/averaging/matchmaking.py

@@ -227,7 +227,10 @@ class Matchmaking:
                     if suggested_leader != self.peer_id:
                         logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
                         self.current_leader = None
-                        await stream.aclose()
+                        try:
+                            await stream.aclose()
+                        except RuntimeError as e:
+                            logger.debug(e, exc_info=True)
                         return await self._request_join_group(suggested_leader)
                 logger.debug(f"{self} - leader disbanded group")
                 return None
@@ -245,7 +248,10 @@ class Matchmaking:
             self.was_accepted_to_group.clear()
             self.current_leader = None
             if stream is not None:
-                await stream.aclose()
+                try:
+                    await stream.aclose()
+                except RuntimeError as e:
+                    logger.debug(e, exc_info=True)
 
     def get_request_expiration_time(self) -> float:
         """Returns the averager's current expiration time, which is used to send join requests to leaders"""

+ 35 - 27
hivemind/optim/experimental/optimizer.py

@@ -121,6 +121,7 @@ class Optimizer(torch.optim.Optimizer):
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled automatically.
       Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
       Do not set this timeout too high, as it may cause your optimizer to hang after some types of network errors.
+    :param allreduce_timeout: timeout for a single attempt to run all-reduce, default: equal to averaging_timeout.
     :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers.
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
       This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
@@ -173,6 +174,7 @@ class Optimizer(torch.optim.Optimizer):
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         matchmaking_time: Optional[float] = 15.0,
         averaging_timeout: Optional[float] = 60.0,
+        allreduce_timeout: Optional[float] = None,
         load_state_timeout: float = 600.0,
         reuse_grad_buffers: bool = False,
         offload_optimizer: Optional[bool] = None,
@@ -197,6 +199,7 @@ class Optimizer(torch.optim.Optimizer):
        client_mode = client_mode if client_mode is not None else dht.client_mode
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
         offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
+        allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
@@ -225,8 +228,8 @@ class Optimizer(torch.optim.Optimizer):
         self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
 
-        self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
-        self.shutdown_timeout = shutdown_timeout
+        self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
+        self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout
 
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.scheduled_grads: Optional[StepControl] = None
@@ -271,7 +274,7 @@ class Optimizer(torch.optim.Optimizer):
             dht=self.dht,
             prefix=f"{self.run_id}_state_averager",
             min_matchmaking_time=self.matchmaking_time,
-            allreduce_timeout=self.averaging_timeout,
+            allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             offload_optimizer=self.offload_optimizer,
             custom_gradients=self.offload_optimizer,
@@ -289,7 +292,7 @@ class Optimizer(torch.optim.Optimizer):
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
             min_matchmaking_time=self.matchmaking_time,
-            allreduce_timeout=self.averaging_timeout,
+            allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             client_mode=self.client_mode,
             auxiliary=self.auxiliary,
@@ -508,8 +511,8 @@ class Optimizer(torch.optim.Optimizer):
                 logger.exception(e)
 
         if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
-            logger.log(self.status_loglevel, f"Cancelled pre-scheduled gradient averaging round")
-            self.scheduled_grads.cancel()
+            logger.log(self.status_loglevel, f"Tagging along for a pre-scheduled gradient averaging round")
+            self._tag_along_with_zero_weight(self.scheduled_grads)
             self.scheduled_grads = None
         return began_averaging_gradients
 
@@ -542,6 +545,7 @@ class Optimizer(torch.optim.Optimizer):
 
     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
+        return
         next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
         if next_epoch % self.average_state_every != 0:
             return  # averaging is not performed at this epoch
@@ -643,7 +647,11 @@ class Optimizer(torch.optim.Optimizer):
 
         If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
         """
-        self._finish_background_averaging()
+        # note: we tag along for the next all-reduce because the run may have already started and cancelling it
+        # will cause peers to restart matchmaking and may stall the entire collaboration for a few seconds.
+        if self.scheduled_grads is not None and not self.scheduled_grads.done():
+            self._tag_along_with_zero_weight(self.scheduled_grads)
+            self.scheduled_grads = None
         self.state_averager.step(wait_for_delayed_updates=True)
 
         with self.tracker.pause_updates():
@@ -671,25 +679,6 @@ class Optimizer(torch.optim.Optimizer):
                 if not self.client_mode:
                     self.grad_averager.state_sharing_priority = self.local_epoch
 
-    def _finish_background_averaging(self):
-        for scheduled_round in self.scheduled_grads, self.scheduled_state:
-            if scheduled_round is not None:
-                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
-                    scheduled_round.cancel()
-                if not scheduled_round.triggered:
-                    scheduled_round.weight = 0
-                    scheduled_round.allow_allreduce()
-        for scheduled_round in self.scheduled_grads, self.scheduled_state:
-            if scheduled_round is not None and not scheduled_round.done():
-                try:
-                    time_to_deadline = scheduled_round.deadline - get_dht_time()
-                    scheduled_round.result(timeout=max(0.0, time_to_deadline))
-                except BaseException as e:
-                    logger.log(self.status_loglevel, f"Caught {e} while averaging gradients")
-                if not scheduled_round.done():
-                    scheduled_round.cancel()
-        self.scheduled_grads = self.scheduled_state = None
-
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()
         state_dict["state"]["local_epoch"] = self.local_epoch
@@ -729,11 +718,30 @@ class Optimizer(torch.optim.Optimizer):
     def __repr__(self):
         return f"{self.__class__.__name__}(prefix={self.run_id}, epoch={self.local_epoch})"
 
+    def _tag_along_with_zero_weight(self, control: StepControl):
+        """Wait for a running averaging round to finish with zero weight."""
+        if not control.triggered:
+            control.weight = 0
+            control.allow_allreduce()
+        if not control.done():
+            try:
+                control.result(self.averaging_timeout)
+            except BaseException as e:
+                logger.exception(e)
+                if not control.done():
+                    control.cancel()
+
     def shutdown(self):
         logger.log(self.status_loglevel, "Sending goodbye to peers...")
         self.tracker.shutdown(self.shutdown_timeout)
         self.state_averager.step(wait_for_delayed_updates=True)
-        self._finish_background_averaging()
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
+                else:
+                    self._tag_along_with_zero_weight(scheduled_round)
+
         logger.log(self.status_loglevel, "Shutting down averagers...")
         self.state_averager.shutdown()
         if self.use_gradient_averaging:
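
The behavioral change that most of this file serves: instead of cancelling a pre-scheduled averaging round (which disbands the group and sends every other peer back into matchmaking), the optimizer now joins the round with zero weight, contributing nothing while letting the rest of the group finish. The helper reads roughly as follows when used on its own (a sketch mirroring _tag_along_with_zero_weight above; control is a hivemind StepControl for a round this peer no longer needs):

    def finish_round_without_contributing(control, timeout=60.0):
        # Join the already-scheduled all-reduce, but with zero weight.
        if not control.triggered:
            control.weight = 0           # contribute nothing to the average
            control.allow_allreduce()    # ...yet still show up, so the group is not disbanded
        if not control.done():
            try:
                control.result(timeout)  # wait for the rest of the group to finish
            except BaseException:
                if not control.done():
                    control.cancel()     # cancel only as a last resort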

+ 3 - 3
hivemind/optim/experimental/progress_tracker.py

@@ -83,12 +83,12 @@ class ProgressTracker(threading.Thread):
         *,
         client_mode: Optional[bool] = None,
         min_refresh_period: float = 0.5,
-        max_refresh_period: float = 10,
+        max_refresh_period: float = 30,
         default_refresh_period: float = 3,
         expected_drift_peers: float = 3,
         expected_drift_rate: float = 0.2,
         performance_ema_alpha: float = 0.1,
-        metadata_expiration: float = 30.0,
+        metadata_expiration: float = 60.0,
         status_loglevel: int = logging.DEBUG,
         private_key: Optional[RSAPrivateKey] = None,
         daemon: bool = True,
@@ -198,7 +198,7 @@ class ProgressTracker(threading.Thread):
         store_task = None
         try:
             while not self.shutdown_triggered.is_set():
-                wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
+                wait_timeout = max(0.0, last_report_time - get_dht_time() + self.metadata_expiration / 2)
                 logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
                 await asyncio.get_event_loop().run_in_executor(None, self.should_report_progress.wait, wait_timeout)
                 if self.should_report_progress.is_set():
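
Net effect of the progress-tracker changes: metadata_expiration doubles to 60 s and the re-report interval is now tied to half of it, so a peer refreshes its progress record roughly every 30 s, a 2x safety margin before the record expires (previously the refresh deadline coincided with the expiration, leaving no slack for a slow report). A quick check of the new wait_timeout with illustrative timestamps:

    metadata_expiration = 60.0   # new default
    last_report_time = 1000.0    # illustrative DHT time of the previous report
    current_dht_time = 1010.0    # pretend 10 seconds have passed since then

    wait_timeout = max(0.0, last_report_time - current_dht_time + metadata_expiration / 2)
    print(wait_timeout)  # 20.0 -> the next report lands 30 s after the previous one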