Sfoglia il codice sorgente

Fix deadlocks in DecentralizedAverager and MPFuture (#331)

This PR does the following:

1. Fix a possible deadlock in DecentralizedAverager.rpc_join_group().
2. Fix a possible deadlock related to corrupted MPFuture state after killing child processes.
3. Add -v flag to pytest in CI.

Co-authored-by: justheuristic <justheuristic@gmail.com>
Alexander Borzunov 4 anni fa
parent
commit
0d6728475f

+ 3 - 3
.github/workflows/run-tests.yml

@@ -33,7 +33,7 @@ jobs:
       - name: Test
         run: |
           cd tests
-          pytest --durations=0 --durations-min=1.0
+          pytest --durations=0 --durations-min=1.0 -v
 
   build_and_test_p2pd:
     runs-on: ubuntu-latest
@@ -60,7 +60,7 @@ jobs:
       - name: Test
         run: |
           cd tests
-          pytest -k "p2p" 
+          pytest -k "p2p" -v
 
   codecov_in_develop_mode:
 
@@ -87,6 +87,6 @@ jobs:
           pip install -e .
       - name: Test
         run: |
-          pytest --cov=hivemind tests
+          pytest --cov=hivemind -v tests
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v1

+ 4 - 0
hivemind/averaging/matchmaking.py

@@ -180,6 +180,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                         expiration=expiration_time,
                         client_mode=self.client_mode,
                         gather=self.data_for_gather,
+                        group_key=self.group_key_manager.current_key,
                     )
                 )
                 message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
@@ -315,11 +316,14 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             or not isinstance(request.endpoint, Endpoint)
             or len(request.endpoint) == 0
             or self.client_mode
+            or not isinstance(request.group_key, GroupKey)
         ):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
 
         elif request.schema_hash != self.schema_hash:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
+        elif request.group_key != self.group_key_manager.current_key:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_GROUP_KEY)
         elif self.potential_leaders.declared_group_key is None:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_DECLARED)
         elif self.potential_leaders.declared_expiration_time > (request.expiration or float("inf")):

+ 2 - 0
hivemind/proto/averaging.proto

@@ -28,6 +28,7 @@ enum MessageCode {
   INTERNAL_ERROR = 15;       // "I messed up, we will have to stop allreduce because of that."
   CANCELLED = 16;            // "[from peer during allreduce] I no longer want to participate in AllReduce."
   GROUP_DISBANDED = 17;      // "[from leader] The group is closed. Go find another group."
+  BAD_GROUP_KEY = 18;        // "I will not accept you. My current group key differs (maybe you used my older key)."
 }
 
 message JoinRequest {
@@ -36,6 +37,7 @@ message JoinRequest {
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
   bytes gather = 4;             // optional metadata that is gathered from all peers (e.g. batch size or current loss)
   bool client_mode = 5;         // if True, the incoming averager is a client with no capacity for averaging
+  string group_key = 6;         // group key identifying an All-Reduce bucket, e.g my_averager.0b011011101
 }
 
 message MessageFromLeader {

+ 16 - 1
hivemind/utils/mpfuture.py

@@ -4,7 +4,6 @@ import asyncio
 import concurrent.futures._base as base
 from contextlib import nullcontext, suppress
 import multiprocessing as mp
-import multiprocessing.connection
 import os
 import threading
 import uuid
@@ -127,11 +126,27 @@ class MPFuture(base.Future, Generic[ResultType]):
                     )
                     cls._pipe_waiter_thread.start()
 
+    @classmethod
+    def reset_backend(cls):
+        """
+        Reset the MPFuture backend. This is useful when the state may have been corrupted
+        (e.g. killing child processes may leave the locks acquired and the background thread blocked).
+
+        This method is neither thread-safe nor process-safe.
+        """
+
+        cls._initialization_lock = mp.Lock()
+        cls._update_lock = mp.Lock()
+        cls._active_pid = None
+
     @classmethod
     def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
         pid = os.getpid()
         while True:
             try:
+                if cls._pipe_waiter_thread is not threading.current_thread():
+                    break  # Backend was reset, a new background thread has started
+
                 uid, msg_type, payload = receiver_pipe.recv()
                 future = None
                 future_ref = cls._active_futures.get(uid)

+ 5 - 1
tests/conftest.py

@@ -4,7 +4,8 @@ from contextlib import suppress
 import psutil
 import pytest
 
-from hivemind.utils import get_logger
+from hivemind.utils.logging import get_logger
+from hivemind.utils.mpfuture import MPFuture
 
 
 logger = get_logger(__name__)
@@ -26,3 +27,6 @@ def cleanup_children():
         for child in children:
             with suppress(psutil.NoSuchProcess):
                 child.kill()
+
+    # Broken code or killing of child processes may leave the MPFuture backend corrupted
+    MPFuture.reset_backend()