
Fix server hanging in certain cases when connection is lost (#247)

* allow DecentralizedAverager to recover from stream cancels 
* add explicit logging for errors such as CancelledError
justheuristic · 4 years ago · commit ddb5389e66
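
For context, the core of the fix is the retry loop inside DecentralizedAverager.step. The snippet below is a simplified, hypothetical rendering of that pattern, not the actual hivemind implementation (attempt_allreduce and the timing helper are placeholders): a cancelled gRPC stream now triggers a retry instead of crashing the background worker, and once retries are exhausted the error is surfaced to the caller rather than being swallowed.

    # Simplified sketch of the retry behaviour this commit introduces; names are
    # placeholders, not the hivemind API.
    import asyncio
    import logging
    import time

    logger = logging.getLogger(__name__)

    async def run_allreduce_with_retries(attempt_allreduce, *, allow_retries=True, timeout=None):
        start_time = time.monotonic()
        while True:
            try:
                return await attempt_allreduce()
            except asyncio.CancelledError as e:
                # a lost connection cancels the gRPC stream; this is now caught explicitly
                elapsed = time.monotonic() - start_time
                if not allow_retries or (timeout is not None and timeout < elapsed):
                    logger.exception(f"Averager caught {repr(e)}")
                    raise  # surface the failure instead of hanging or returning nothing
                logger.warning(f"Averager caught {repr(e)}, retrying")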

+ 1 - 1
examples/albert/run_trainer.py

@@ -38,7 +38,7 @@ class CollaborationArguments:
     trainer_uuid: str = uuid.uuid4().hex  # this peer's name - used when publishing metadata to DHT, default = random
 
     # optional tweaks
-    target_group_size: int = 64  # maximum group size for all-reduce
+    target_group_size: int = 256  # maximum group size for all-reduce
     metadata_expiration: float = 30  # peer's metadata will be removed if not updated in this many seconds
     statistics_expiration: float = 600  # statistics will be removed if not updated in this many seconds
     dht_listen_on: str = '[::]:*'  # network interface used for incoming DHT communication. Default: all ipv6
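
The default target_group_size quadruples from 64 to 256, so all-reduce groups in the albert example can now grow to 256 peers. If smaller groups are still wanted, the value can presumably be overridden when the arguments are built; the line below is only an illustration and assumes the remaining CollaborationArguments fields have defaults or are supplied elsewhere.

    # Illustrative only: restoring the previous group size by overriding the field.
    args = CollaborationArguments(target_group_size=64)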

+ 9 - 6
hivemind/client/averaging/__init__.py

@@ -257,21 +257,24 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 # averaging is finished, exit the loop
                 future.set_result(allreduce_runner.gathered)
 
-            except (AllreduceException, MatchmakingException, AssertionError,
-                    asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
+            except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
+                    asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
                 time_elapsed = get_dht_time() - start_time
                 if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                    logger.warning(f"Averager caught {e}")
-                    future.set_result(None)
+                    logger.exception(f"Averager caught {repr(e)}")
+                    future.set_exception(e)
                 else:
-                    logger.warning(f"Averager caught {e}, retrying")
+                    logger.warning(f"Averager caught {repr(e)}, retrying")
 
-            except Exception as e:
+            except BaseException as e:
                 future.set_exception(e)
                 raise
             finally:
                 _ = self._running_groups.pop(group_id, None)
                 self._pending_group_assembled.set()
+                if not future.done():
+                    future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
+                                                      " Please report this to hivemind issues."))
 
     async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
         """ Use a group description found by Matchmaking to form AllreduceRunner """

+ 3 - 3
hivemind/optim/collaborative.py

@@ -83,7 +83,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                  batch_size_per_step: Optional[int] = None, scheduler: Optional[LRSchedulerBase] = None,
                  min_refresh_period: float = 0.5, max_refresh_period: float = 30, default_refresh_period: float = 3,
                  expected_drift_peers: float = 3, expected_drift_rate: float = 0.2, performance_ema_alpha: float = 0.1,
-                 metadata_expiration: float = 30.0, averaging_timeout: Optional[float] = None, step_tolerance: int = 1,
+                 metadata_expiration: float = 60.0, averaging_timeout: Optional[float] = None, step_tolerance: int = 1,
                  reuse_grad_buffers: bool = False, accumulate_grads_on: Optional[torch.device] = None,
                  client_mode: bool = False, verbose: bool = False, **kwargs):
         super().__init__(opt, dht)
@@ -193,8 +193,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
                     if group_info:
                         logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
-                except Exception as e:
-                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {e}.")
+                except BaseException as e:
+                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
             else:
                 logger.log(self.status_loglevel, f"Skipped averaging: collaboration consists of "
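
Widening the except clauses to BaseException (and adding the explicit asyncio.CancelledError entry above) matches the Python 3.8+ exception hierarchy, where CancelledError no longer derives from Exception, so a plain `except Exception` would miss a cancelled stream entirely:

    # On Python 3.8 and newer, CancelledError bypasses `except Exception` handlers.
    import asyncio
    assert issubclass(asyncio.CancelledError, BaseException)
    assert not issubclass(asyncio.CancelledError, Exception)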