Browse Source

Ensure group key equality in rpc_join_group()

Aleksandr Borzunov 4 years ago
parent
commit
868be1b756
2 changed files with 6 additions and 0 deletions
  1. 4 0
      hivemind/averaging/matchmaking.py
  2. 2 0
      hivemind/proto/averaging.proto

+ 4 - 0
hivemind/averaging/matchmaking.py

@@ -180,6 +180,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                         expiration=expiration_time,
                         client_mode=self.client_mode,
                         gather=self.data_for_gather,
+                        group_key=self.group_key_manager.current_key,
                     )
                 )
                 message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
@@ -315,11 +316,14 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             or not isinstance(request.endpoint, Endpoint)
             or len(request.endpoint) == 0
             or self.client_mode
+            or not isinstance(request.group_key, GroupKey)
         ):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
 
         elif request.schema_hash != self.schema_hash:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
+        elif request.group_key != self.group_key_manager.current_key:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_GROUP_KEY)
         elif self.potential_leaders.declared_group_key is None:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_DECLARED)
         elif self.potential_leaders.declared_expiration_time > (request.expiration or float("inf")):

+ 2 - 0
hivemind/proto/averaging.proto

@@ -28,6 +28,7 @@ enum MessageCode {
   INTERNAL_ERROR = 15;       // "I messed up, we will have to stop allreduce because of that."
   CANCELLED = 16;            // "[from peer during allreduce] I no longer want to participate in AllReduce."
   GROUP_DISBANDED = 17;      // "[from leader] The group is closed. Go find another group."
+  BAD_GROUP_KEY = 18;
 }
 
 message JoinRequest {
@@ -36,6 +37,7 @@ message JoinRequest {
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
   bytes gather = 4;             // optional metadata that is gathered from all peers (e.g. batch size or current loss)
   bool client_mode = 5;         // if True, the incoming averager is a client with no capacity for averaging
+  string group_key = 6;
 }
 
 message MessageFromLeader {