Improve hivemind.optim.experimental and averager stability (#421)

Change default timings based on training with 70 peers (north EU + west US)

Co-authored-by: justheuristic <justheuristic@gmail.com>
Alexander Borzunov, 3 years ago
Commit a904cfd58b
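
For context, the main user-facing change below is the new allreduce_timeout argument of the experimental Optimizer, which bounds a single all-reduce attempt separately from the whole averaging step. A minimal usage sketch (the surrounding arguments — dht, run_id, target_batch_size, optimizer — follow the usual hivemind.Optimizer signature and are not part of this commit; the numbers are illustrative):

    import torch
    import hivemind
    from hivemind.optim.experimental.optimizer import Optimizer

    # Sketch only: everything except the three timing arguments is assumed
    # from the standard Optimizer API rather than taken from this diff.
    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(16, 1)
    inner_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    opt = Optimizer(
        dht=dht,
        run_id="demo_run",
        target_batch_size=4096,     # global batch accumulated across all peers per epoch
        batch_size_per_step=32,     # samples processed locally per opt.step()
        optimizer=inner_optimizer,
        matchmaking_time=15.0,      # time budget for finding an averaging group
        averaging_timeout=60.0,     # cancel the whole averaging step if it hangs this long
        allreduce_timeout=30.0,     # new: budget for one all-reduce attempt within that step
        verbose=True,
    )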

+ 5 - 2
hivemind/averaging/allreduce.py

@@ -275,8 +275,11 @@ class AllReduceRunner(ServicerBase):
             yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)

     async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
-        error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-        await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+        try:
+            error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
+            await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+        except Exception as e:
+            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}")

     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""

+ 4 - 1
hivemind/averaging/averager.py

@@ -31,6 +31,7 @@ from hivemind.compression import (
 )
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
@@ -470,6 +471,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.CancelledError,
                     asyncio.InvalidStateError,
                     P2PHandlerError,
+                    DispatchFailure,
+                    ControlFailure,
                 ) as e:
                     if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
                         if not step.cancelled():
@@ -535,7 +538,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
                             # all-reduce is performed asynchronously while iterating
                             tensor.add_(update, alpha=self._averaging_alpha)
-                            self._state_updated.set()
+                        self._state_updated.set()

                     else:
                         async for _ in allreduce:  # trigger all-reduce by iterating

+ 8 - 2
hivemind/averaging/matchmaking.py

@@ -227,7 +227,10 @@ class Matchmaking:
                     if suggested_leader != self.peer_id:
                         logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
                         self.current_leader = None
-                        await stream.aclose()
+                        try:
+                            await stream.aclose()
+                        except RuntimeError as e:
+                            logger.debug(e, exc_info=True)
                         return await self._request_join_group(suggested_leader)
                 logger.debug(f"{self} - leader disbanded group")
                 return None
@@ -245,7 +248,10 @@ class Matchmaking:
             self.was_accepted_to_group.clear()
             self.current_leader = None
             if stream is not None:
-                await stream.aclose()
+                try:
+                    await stream.aclose()
+                except RuntimeError as e:
+                    logger.debug(e, exc_info=True)

     def get_request_expiration_time(self) -> float:
         """Returns the averager's current expiration time, which is used to send join requests to leaders"""

+ 35 - 27
hivemind/optim/experimental/optimizer.py

@@ -121,6 +121,7 @@ class Optimizer(torch.optim.Optimizer):
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled automatically.
       Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
       Do not set this timeout too high, as it may cause your optimizer to hang after some types of network errors.
+    :param allreduce_timeout: timeout for a single attempt to run all-reduce, default: equal to averaging_timeout.
     :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers.
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
       This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
@@ -173,6 +174,7 @@ class Optimizer(torch.optim.Optimizer):
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         matchmaking_time: Optional[float] = 15.0,
         averaging_timeout: Optional[float] = 60.0,
+        allreduce_timeout: Optional[float] = None,
         load_state_timeout: float = 600.0,
         reuse_grad_buffers: bool = False,
         offload_optimizer: Optional[bool] = None,
@@ -197,6 +199,7 @@ class Optimizer(torch.optim.Optimizer):
         client_mode = client_mode if client_mode is None else dht.client_mode
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
         offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
+        allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
@@ -225,8 +228,8 @@ class Optimizer(torch.optim.Optimizer):
         self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step

-        self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
-        self.shutdown_timeout = shutdown_timeout
+        self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
+        self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout

         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.scheduled_grads: Optional[StepControl] = None
@@ -271,7 +274,7 @@ class Optimizer(torch.optim.Optimizer):
             dht=self.dht,
             prefix=f"{self.run_id}_state_averager",
             min_matchmaking_time=self.matchmaking_time,
-            allreduce_timeout=self.averaging_timeout,
+            allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             offload_optimizer=self.offload_optimizer,
             custom_gradients=self.offload_optimizer,
@@ -289,7 +292,7 @@ class Optimizer(torch.optim.Optimizer):
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
             min_matchmaking_time=self.matchmaking_time,
-            allreduce_timeout=self.averaging_timeout,
+            allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             client_mode=self.client_mode,
             auxiliary=self.auxiliary,
@@ -508,8 +511,8 @@ class Optimizer(torch.optim.Optimizer):
                 logger.exception(e)

         if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
-            logger.log(self.status_loglevel, f"Cancelled pre-scheduled gradient averaging round")
-            self.scheduled_grads.cancel()
+            logger.log(self.status_loglevel, f"Tagging along for a pre-scheduled gradient averaging round")
+            self._tag_along_with_zero_weight(self.scheduled_grads)
             self.scheduled_grads = None
         return began_averaging_gradients

@@ -542,6 +545,7 @@

     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
+        return
         next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
         if next_epoch % self.average_state_every != 0:
             return  # averaging is not performed at this epoch
@@ -643,7 +647,11 @@

         If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
         """
-        self._finish_background_averaging()
+        # note: we tag along for the next all-reduce because the run may have already started and cancelling it
+        # will cause peers to restart matchmaking and may  stall the entire collaboration for a few seconds.
+        if self.scheduled_grads is not None and not self.scheduled_grads.done():
+            self._tag_along_with_zero_weight(self.scheduled_grads)
+            self.scheduled_grads = None
         self.state_averager.step(wait_for_delayed_updates=True)

         with self.tracker.pause_updates():
@@ -671,25 +679,6 @@ class Optimizer(torch.optim.Optimizer):
                 if not self.client_mode:
                     self.grad_averager.state_sharing_priority = self.local_epoch

-    def _finish_background_averaging(self):
-        for scheduled_round in self.scheduled_grads, self.scheduled_state:
-            if scheduled_round is not None:
-                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
-                    scheduled_round.cancel()
-                if not scheduled_round.triggered:
-                    scheduled_round.weight = 0
-                    scheduled_round.allow_allreduce()
-        for scheduled_round in self.scheduled_grads, self.scheduled_state:
-            if scheduled_round is not None and not scheduled_round.done():
-                try:
-                    time_to_deadline = scheduled_round.deadline - get_dht_time()
-                    scheduled_round.result(timeout=max(0.0, time_to_deadline))
-                except BaseException as e:
-                    logger.log(self.status_loglevel, f"Caught {e} while averaging gradients")
-                if not scheduled_round.done():
-                    scheduled_round.cancel()
-        self.scheduled_grads = self.scheduled_state = None
-
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()
         state_dict["state"]["local_epoch"] = self.local_epoch
@@ -729,11 +718,30 @@ class Optimizer(torch.optim.Optimizer):
     def __repr__(self):
         return f"{self.__class__.__name__}(prefix={self.run_id}, epoch={self.local_epoch})"

+    def _tag_along_with_zero_weight(self, control: StepControl):
+        """Wait for a running averaging round to finish with zero weight."""
+        if not control.triggered:
+            control.weight = 0
+            control.allow_allreduce()
+        if not control.done():
+            try:
+                control.result(self.averaging_timeout)
+            except BaseException as e:
+                logger.exception(e)
+                if not control.done():
+                    control.cancel()
+
     def shutdown(self):
         logger.log(self.status_loglevel, "Sending goodbye to peers...")
         self.tracker.shutdown(self.shutdown_timeout)
         self.state_averager.step(wait_for_delayed_updates=True)
-        self._finish_background_averaging()
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
+                else:
+                    self._tag_along_with_zero_weight(scheduled_round)
+
         logger.log(self.status_loglevel, "Shutting down averagers...")
         self.state_averager.shutdown()
         if self.use_gradient_averaging:
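
The recurring pattern in the optimizer changes above — finishing a pre-scheduled round with zero weight instead of cancelling it — is worth spelling out, since (as the new inline comment above notes) cancelling would force the already-matched peers back into matchmaking. An illustrative restatement, where control stands for a hivemind StepControl such as self.scheduled_grads:

    # Illustrative restatement of the _tag_along_with_zero_weight helper added above;
    # `control` is a StepControl for a round whose group may already be formed.
    def tag_along_with_zero_weight(control, timeout: float) -> None:
        if not control.triggered:
            control.weight = 0          # contribute nothing to the averaged result
            control.allow_allreduce()   # ...but still participate, so the group is not disbanded
        if not control.done():
            try:
                control.result(timeout)  # block until the round completes or times out
            except BaseException:
                if not control.done():
                    control.cancel()     # cancel only if the round failed anyway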

+ 3 - 3
hivemind/optim/experimental/progress_tracker.py

@@ -83,12 +83,12 @@ class ProgressTracker(threading.Thread):
         *,
         client_mode: Optional[bool] = None,
         min_refresh_period: float = 0.5,
-        max_refresh_period: float = 10,
+        max_refresh_period: float = 30,
         default_refresh_period: float = 3,
         expected_drift_peers: float = 3,
         expected_drift_rate: float = 0.2,
         performance_ema_alpha: float = 0.1,
-        metadata_expiration: float = 30.0,
+        metadata_expiration: float = 60.0,
         status_loglevel: int = logging.DEBUG,
         private_key: Optional[RSAPrivateKey] = None,
         daemon: bool = True,
@@ -198,7 +198,7 @@ class ProgressTracker(threading.Thread):
         store_task = None
         try:
             while not self.shutdown_triggered.is_set():
-                wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
+                wait_timeout = max(0.0, last_report_time - get_dht_time() + self.metadata_expiration / 2)
                 logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
                 await asyncio.get_event_loop().run_in_executor(None, self.should_report_progress.wait, wait_timeout)
                 if self.should_report_progress.is_set():
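
Taken together, the retuned ProgressTracker defaults keep the re-reporting period at roughly 30 seconds while doubling how long a progress record stays valid, so a peer that is briefly slow to report no longer sits right at the edge of expiration. A worked sketch of the new wait computation with hypothetical DHT timestamps:

    # Old defaults: metadata_expiration = 30.0 and wait until the record was about to expire.
    # New defaults: metadata_expiration = 60.0 and wait only half the expiration window.
    metadata_expiration = 60.0
    last_report_time, now = 1000.0, 1000.0        # hypothetical: a report was just published

    wait_timeout = max(0.0, last_report_time - now + metadata_expiration / 2)
    assert wait_timeout == 30.0                          # re-report after ~30 s, as before...
    assert metadata_expiration - wait_timeout == 30.0    # ...but with ~30 s of slack left (was 0 s)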