
Merge branch 'ensure-equal-group-keys' into averager-libp2p

Aleksandr Borzunov, 4 years ago
commit c524f9650c

+ 1 - 1
.github/workflows/run-tests.yml

@@ -87,6 +87,6 @@ jobs:
           pip install -e .
       - name: Test
         run: |
-          pytest --cov=hivemind tests -v
+          pytest --cov=hivemind -v tests
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v1

+ 4 - 0
hivemind/averaging/matchmaking.py

@@ -183,6 +183,7 @@ class Matchmaking:
                         expiration=expiration_time,
                         client_mode=self.client_mode,
                         gather=self.data_for_gather,
+                        group_key=self.group_key_manager.current_key,
                     )
                 ).__aiter__()
                 message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
@@ -322,11 +323,14 @@ class Matchmaking:
             or not isfinite(request.expiration)
             or request_endpoint is None
             or self.client_mode
+            or not isinstance(request.group_key, GroupKey)
         ):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
 
         elif request.schema_hash != self.schema_hash:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
+        elif request.group_key != self.group_key_manager.current_key:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_GROUP_KEY)
         elif self.potential_leaders.declared_group_key is None:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_DECLARED)
         elif self.potential_leaders.declared_expiration_time > (request.expiration or float("inf")):
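
The checks above run top to bottom: malformed requests are rejected as PROTOCOL_VIOLATION before the schema hash and group key are compared, and only then does the leader consider whether it has declared itself. A simplified, self-contained sketch of that ordering follows; the enum values, class names, and attributes here are illustrative stand-ins, not hivemind's actual protobuf API.

# Simplified, stand-alone sketch of the leader-side validation order added in
# this commit. All names below are illustrative, not hivemind's actual API.
from enum import Enum, auto


class Code(Enum):
    PROTOCOL_VIOLATION = auto()
    BAD_SCHEMA_HASH = auto()
    BAD_GROUP_KEY = auto()
    NOT_DECLARED = auto()
    ACCEPTED = auto()


def validate_join_request(request, leader) -> Code:
    """Reject followers whose group key does not match the leader's current key."""
    if not isinstance(request.group_key, str):
        return Code.PROTOCOL_VIOLATION      # malformed request
    if request.schema_hash != leader.schema_hash:
        return Code.BAD_SCHEMA_HASH         # incompatible tensor schema
    if request.group_key != leader.current_group_key:
        # the follower may be using an older key cached from a previous round
        return Code.BAD_GROUP_KEY
    if leader.declared_group_key is None:
        return Code.NOT_DECLARED            # leader has not declared itself yet
    return Code.ACCEPTED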

+ 2 - 0
hivemind/proto/averaging.proto

@@ -28,6 +28,7 @@ enum MessageCode {
   INTERNAL_ERROR = 15;       // "I messed up, we will have to stop allreduce because of that."
   CANCELLED = 16;            // "[from peer during allreduce] I no longer want to participate in AllReduce."
   GROUP_DISBANDED = 17;      // "[from leader] The group is closed. Go find another group."
+  BAD_GROUP_KEY = 18;        // "I will not accept you. My current group key differs (maybe you used my older key)."
 }
 
 message JoinRequest {
@@ -36,6 +37,7 @@ message JoinRequest {
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
   bytes gather = 4;             // optional metadata that is gathered from all peers (e.g. batch size or current loss)
   bool client_mode = 5;         // if True, the incoming averager is a client with no capacity for averaging
+  string group_key = 6;         // group key identifying an All-Reduce bucket, e.g my_averager.0b011011101
 }
 
 message MessageFromLeader {
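
On the follower side, the new field is a plain string on JoinRequest (field 6). A hedged construction sketch: the import path is assumed from the .proto file's location, fields not shown in this diff (such as the schema hash) are omitted, and the values below are placeholders.

# Hedged sketch: building a JoinRequest that carries the new group_key field.
import time

from hivemind.proto import averaging_pb2  # assumed generated-module path

request = averaging_pb2.JoinRequest(
    expiration=time.time() + 30.0,        # follower wants to begin all_reduce by this time
    gather=b"",                           # optional metadata gathered from all peers
    client_mode=False,                    # this peer can host part of the averaging
    group_key="my_averager.0b011011101",  # All-Reduce bucket key (example from the comment above)
)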

+ 16 - 0
hivemind/utils/mpfuture.py

@@ -127,11 +127,27 @@ class MPFuture(base.Future, Generic[ResultType]):
                     )
                     cls._pipe_waiter_thread.start()
 
+    @classmethod
+    def reset_backend(cls):
+        """
+        Reset the MPFuture backend. This is useful when the state may have been corrupted
+        (e.g. killing child processes may leave the locks acquired and the background thread blocked).
+
+        This method is neither thread-safe nor process-safe.
+        """
+
+        cls._initialization_lock = mp.Lock()
+        cls._update_lock = mp.Lock()
+        cls._active_pid = None
+
     @classmethod
     def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
         pid = os.getpid()
         while True:
             try:
+                if cls._pipe_waiter_thread is not threading.current_thread():
+                    break  # Backend was reset, a new background thread has started
+
                 uid, msg_type, payload = receiver_pipe.recv()
                 future = None
                 future_ref = cls._active_futures.get(uid)
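
The extra check in _process_updates_in_background lets a stale pipe-reader thread retire itself once reset_backend() has been called and a replacement thread has been registered. A stand-alone illustration of that pattern; the class and method names here are illustrative, not hivemind's.

# Stand-alone illustration of the "stale worker retires itself" pattern used above.
import threading
import time
from typing import Optional


class Backend:
    _worker: Optional[threading.Thread] = None

    @classmethod
    def start(cls) -> None:
        cls._worker = threading.Thread(target=cls._loop, daemon=True)
        cls._worker.start()

    @classmethod
    def _loop(cls) -> None:
        while True:
            if cls._worker is not threading.current_thread():
                break          # a reset registered a replacement thread; exit quietly
            time.sleep(0.1)    # stands in for receiver_pipe.recv()

    @classmethod
    def reset(cls) -> None:
        cls._worker = None     # forget the old thread; the next start() spawns a fresh one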

+ 2 - 3
tests/conftest.py

@@ -29,6 +29,5 @@ def cleanup_children():
             with suppress(psutil.NoSuchProcess):
                 child.kill()
 
-    # Killing child processes may leave the global MPFuture locks acquired, so we recreate them
-    MPFuture._initialization_lock = mp.Lock()
-    MPFuture._update_lock = mp.Lock()
+    # Killing child processes may leave the MPFuture backend broken
+    MPFuture.reset_backend()
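
For context, the fixture this hunk modifies now only kills stray child processes and delegates lock and state recovery to MPFuture.reset_backend(). A hedged reconstruction of what the full fixture might look like: only its tail is visible in this diff, so everything before the kill loop, including the import path, is an assumption.

# Hedged reconstruction of the surrounding test fixture; only its last lines
# appear in this diff, so the rest is an assumption.
from contextlib import suppress

import psutil
import pytest

from hivemind.utils.mpfuture import MPFuture  # assumed import path


@pytest.fixture(autouse=True)
def cleanup_children():
    yield  # run the test first, then clean up whatever it left behind

    for child in psutil.Process().children(recursive=True):
        with suppress(psutil.NoSuchProcess):
            child.kill()

    # Killing child processes may leave the MPFuture backend broken
    MPFuture.reset_backend()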