@@ -2,27 +2,25 @@
 
 from __future__ import annotations
 
+import asyncio
+import concurrent.futures
 import contextlib
 import random
 from math import isfinite
-from typing import Optional, AsyncIterator, Set, Tuple, Dict
-import concurrent.futures
-import asyncio
-
-import grpc
-import grpc._cython.cygrpc
+from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 
 from hivemind.averaging.group_info import GroupInfo
-from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
+from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID, DHTExpiration
-from hivemind.utils import get_logger, Endpoint, timed_storage, TimedStorage, get_dht_time
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc
-from hivemind.utils.grpc import ChannelCache
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.proto import averaging_pb2
+from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
+from hivemind.utils.asyncio import anext
 
 logger = get_logger(__name__)
 
 
-class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
+class Matchmaking:
     f"""
     An internal class that is used to form groups of averages for running allreduce
     See DecentralizedAverager docstring for the detailed description of all parameters
@@ -37,10 +35,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
     def __init__(
         self,
-        endpoint: Endpoint,
+        p2p: P2P,
         schema_hash: bytes,
         dht: DHT,
         *,
+        servicer_type: Type[ServicerBase],
         prefix: str,
         target_group_size: int,
         min_group_size: int,
@@ -57,8 +56,16 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             )
 
         super().__init__()
-        self.endpoint, self.schema_hash = endpoint, schema_hash
-        self.group_key_manager = GroupKeyManager(dht, endpoint, prefix, initial_group_bits, target_group_size)
+        self._p2p = p2p
+
+        if not issubclass(servicer_type, ServicerBase):
+            raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
+        self._servicer_type = servicer_type
+        self._prefix = prefix
+
+        self.peer_id = p2p.peer_id
+        self.schema_hash = schema_hash
+        self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
         self.client_mode = client_mode
@@ -69,9 +76,9 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         self.was_accepted_to_group = asyncio.Event()
         self.assembled_group = asyncio.Future()
 
-        self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
-        self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(endpoint, averaging_expiration, target_group_size)
+        self.current_leader: Optional[PeerID] = None  # iff i am a follower, this is a link to my current leader
+        self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
+        self.potential_leaders = PotentialLeaders(self.peer_id, averaging_expiration, target_group_size)
         self.data_for_gather: Optional[bytes] = None
 
     @property
@@ -87,7 +94,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             lfg_status += f" leading {len(self.current_followers)} followers,"
         schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
         return (
-            f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}"
+            f"{self.__class__.__name__}(peer_id={self.peer_id}, schema={schema_hash_repr}, {lfg_status}"
             f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
         )
 
@@ -160,7 +167,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             self.assembled_group.set_exception(e)
             raise e
 
-    async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
+    async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
         """
         :param leader: request this peer to be your leader for allreduce
         :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
@@ -169,23 +176,24 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         The originally specified leader can disband group and redirect us to a different leader
         """
         assert self.is_looking_for_group and self.current_leader is None
-        call: Optional[grpc.aio.UnaryStreamCall] = None
+        stream: Optional[AsyncIterator[averaging_pb2.MessageFromLeader]] = None
         try:
             async with self.lock_request_join_group:
-                leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
-                call = leader_stub.rpc_join_group(
+                leader_stub = self._servicer_type.get_stub(self._p2p, leader, namespace=self._prefix)
+
+                stream = leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
-                        endpoint=self.endpoint,
                         schema_hash=self.schema_hash,
                         expiration=expiration_time,
                         client_mode=self.client_mode,
                         gather=self.data_for_gather,
+                        group_key=self.group_key_manager.current_key,
                     )
-                )
-                message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
+                ).__aiter__()
+                message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
 
             if message.code == averaging_pb2.ACCEPTED:
-                logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
+                logger.debug(f"{self.peer_id} - joining the group of {leader}; waiting for peers")
                 self.current_leader = leader
                 self.was_accepted_to_group.set()
                 if len(self.current_followers) > 0:
@@ -193,56 +201,55 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
             if message.code != averaging_pb2.ACCEPTED:
                 code = averaging_pb2.MessageCode.Name(message.code)
-                logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
+                logger.debug(f"{self.peer_id} - requested {leader} to be my leader, but got rejected with {code}")
                 return None
 
             async with self.potential_leaders.pause_search():
                 time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
-                message = await asyncio.wait_for(call.read(), time_to_expiration + self.request_timeout)
+                message = await asyncio.wait_for(anext(stream), time_to_expiration + self.request_timeout)
 
                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
                     async with self.lock_request_join_group:
                         return await self.follower_assemble_group(leader, message)
 
             if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
-                if message.suggested_leader and message.suggested_leader != self.endpoint:
-                    logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
-                    self.current_leader = None
-                    call.cancel()
-                    return await self.request_join_group(message.suggested_leader, expiration_time)
-                else:
-                    logger.debug(f"{self} - leader disbanded group")
-                    return None
+                if message.suggested_leader:
+                    suggested_leader = PeerID(message.suggested_leader)
+                    if suggested_leader != self.peer_id:
+                        logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
+                        self.current_leader = None
+                        await stream.aclose()
+                        return await self.request_join_group(suggested_leader, expiration_time)
+                logger.debug(f"{self} - leader disbanded group")
+                return None
 
             logger.debug(f"{self} - unexpected message from leader: {averaging_pb2.MessageCode.Name(message.code)}")
             return None
         except asyncio.TimeoutError:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
-            if call is not None:
-                call.cancel()
             return None
-        except (grpc.RpcError, grpc.aio.AioRpcError, grpc._cython.cygrpc.InternalError, StopAsyncIteration) as e:
+        except (P2PHandlerError, StopAsyncIteration) as e:
             logger.error(f"{self} - failed to request potential leader {leader}: {e}")
             return None
 
         finally:
             self.was_accepted_to_group.clear()
             self.current_leader = None
-            if call is not None:
-                await call.code()
+            if stream is not None:
+                await stream.aclose()
 
     async def rpc_join_group(
-        self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+        self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
         try:
             async with self.lock_request_join_group:
-                reason_to_reject = self._check_reasons_to_reject(request)
+                reason_to_reject = self._check_reasons_to_reject(request, context)
                 if reason_to_reject is not None:
                     yield reason_to_reject
                     return
 
-                self.current_followers[request.endpoint] = request
+                self.current_followers[context.remote_id] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
                 if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
@@ -270,12 +277,12 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                     self.was_accepted_to_group.is_set()
                     or not self.assembled_group.done()
                     or self.assembled_group.cancelled()
-                    or request.endpoint not in self.assembled_group.result()
+                    or context.remote_id not in self.assembled_group.result()
                 ):
                     if self.current_leader is not None:
                         # outcome 3: found by a leader with higher priority, send our followers to him
                         yield averaging_pb2.MessageFromLeader(
-                            code=averaging_pb2.GROUP_DISBANDED, suggested_leader=self.current_leader
+                            code=averaging_pb2.GROUP_DISBANDED, suggested_leader=self.current_leader.to_bytes()
                         )
                         return
                     else:
@@ -286,7 +293,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 yield averaging_pb2.MessageFromLeader(
                     code=averaging_pb2.BEGIN_ALLREDUCE,
                     group_id=group_info.group_id,
-                    ordered_group_endpoints=group_info.endpoints,
+                    ordered_peer_ids=[item.to_bytes() for item in group_info.peer_ids],
                     gathered=group_info.gathered,
                 )
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
@@ -296,11 +303,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
 
         finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
-            self.current_followers.pop(request.endpoint, None)
+            self.current_followers.pop(context.remote_id, None)
             self.follower_was_discarded.set()
 
     def _check_reasons_to_reject(
-        self, request: averaging_pb2.JoinRequest
+        self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> Optional[averaging_pb2.MessageFromLeader]:
         """:returns: if accepted, return None, otherwise return a reason for rejection"""
         if not self.is_looking_for_group or self.assembled_group.done():
@@ -312,24 +319,25 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             or len(request.schema_hash) == 0
             or not isinstance(request.expiration, DHTExpiration)
             or not isfinite(request.expiration)
-            or not isinstance(request.endpoint, Endpoint)
-            or len(request.endpoint) == 0
             or self.client_mode
+            or not isinstance(request.group_key, GroupKey)
         ):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
 
        elif request.schema_hash != self.schema_hash:
            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
+        elif request.group_key != self.group_key_manager.current_key:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_GROUP_KEY)
        elif self.potential_leaders.declared_group_key is None:
            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_DECLARED)
        elif self.potential_leaders.declared_expiration_time > (request.expiration or float("inf")):
            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
        elif self.current_leader is not None:
            return averaging_pb2.MessageFromLeader(
-                code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader
-            )  # note: this suggested leader is currently ignored
-        elif request.endpoint == self.endpoint or request.endpoint in self.current_followers:
-            return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
+                code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader.to_bytes()
+            )
+        elif context.remote_id == self.peer_id or context.remote_id in self.current_followers:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_PEER_ID)
        elif len(self.current_followers) + 1 >= self.target_group_size:
            return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
        else:
@@ -339,34 +347,35 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         """Form up all current followers into a group and gather metadata"""
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked() and not self.client_mode
         assert not self.assembled_group.done()
-        group_id = DHTID.generate().to_bytes()  # note: both groupd_id and the order of endpoints must be random
-        ordered_group_endpoints = list(self.current_followers)
-        ordered_group_endpoints.append(self.endpoint)
-        random.shuffle(ordered_group_endpoints)
+        group_id = DHTID.generate().to_bytes()  # note: both group_id and the order of peer_ids must be random
+        ordered_peer_ids = list(self.current_followers)
+        ordered_peer_ids.append(self.peer_id)
+        random.shuffle(ordered_peer_ids)
 
         gathered = tuple(
-            self.data_for_gather if endpoint == self.endpoint else self.current_followers[endpoint].gather
-            for endpoint in ordered_group_endpoints
+            self.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
+            for peer_id in ordered_peer_ids
         )
 
-        logger.debug(f"{self.endpoint} - assembled group of {len(ordered_group_endpoints)} peers.")
-        group_info = GroupInfo(group_id, tuple(ordered_group_endpoints), gathered)
+        logger.debug(f"{self.peer_id} - assembled group of {len(ordered_peer_ids)} peers.")
+        group_info = GroupInfo(group_id, tuple(ordered_peer_ids), gathered)
         await self.group_key_manager.update_key_on_group_assembled(group_info, is_leader=True)
         self.assembled_group.set_result(group_info)
         return group_info
 
-    async def follower_assemble_group(self, leader: Endpoint, msg: averaging_pb2.MessageFromLeader) -> GroupInfo:
+    async def follower_assemble_group(self, leader: PeerID, msg: averaging_pb2.MessageFromLeader) -> GroupInfo:
         """Form a group from using peers and metadata provided by our leader"""
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
         assert not self.assembled_group.done()
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
 
-        group_id, ordered_group_endpoints = msg.group_id, msg.ordered_group_endpoints
-        assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
-        assert len(ordered_group_endpoints) == len(msg.gathered)
+        group_id = msg.group_id
+        ordered_peer_ids = [PeerID(item) for item in msg.ordered_peer_ids]
+        assert self.peer_id in ordered_peer_ids, "Leader sent us group_peer_ids that does not contain us!"
+        assert len(ordered_peer_ids) == len(msg.gathered)
 
-        logger.debug(f"{self.endpoint} - follower assembled group with leader {leader}.")
-        group_info = GroupInfo(group_id, tuple(ordered_group_endpoints), tuple(msg.gathered))
+        logger.debug(f"{self.peer_id} - follower assembled group with leader {leader}.")
+        group_info = GroupInfo(group_id, tuple(ordered_peer_ids), tuple(msg.gathered))
         await self.group_key_manager.update_key_on_group_assembled(group_info)
         self.assembled_group.set_result(group_info)
         return group_info
@@ -380,13 +389,13 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 class PotentialLeaders:
     """An utility class that searches for averagers that could become our leaders"""
 
-    def __init__(self, endpoint: Endpoint, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
-        self.endpoint, self.averaging_expiration = endpoint, averaging_expiration
+    def __init__(self, peer_id: PeerID, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
+        self.peer_id, self.averaging_expiration = peer_id, averaging_expiration
         self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
-        self.leader_queue = TimedStorage[Endpoint, DHTExpiration]()
-        self.past_attempts: Set[Tuple[Endpoint, DHTExpiration]] = set()
+        self.leader_queue = TimedStorage[PeerID, DHTExpiration]()
+        self.past_attempts: Set[Tuple[PeerID, DHTExpiration]] = set()
         self.declared_expiration_time = float("inf")
         self.declared_group_key: Optional[GroupKey] = None
         self.max_assured_time = float("-inf")
@@ -433,7 +442,7 @@ class PotentialLeaders:
         else:
             self.running.clear()
 
-    async def pop_next_leader(self) -> Endpoint:
+    async def pop_next_leader(self) -> PeerID:
         """Remove and return the next most suitable leader or throw an exception if reached timeout"""
         assert self.running.is_set(), "Not running search at the moment"
         while True:
@@ -442,9 +451,9 @@ class PotentialLeaders:
             if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
                 self.update_triggered.set()
 
-            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader) > (
+            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader.to_bytes()) > (
                 self.declared_expiration_time,
-                self.endpoint,
+                self.peer_id.to_bytes(),
             ):
                 await asyncio.wait(
                     {self.update_finished.wait(), self.declared_expiration.wait()}, return_when=asyncio.FIRST_COMPLETED
@@ -479,7 +488,7 @@ class PotentialLeaders:
 
             self.leader_queue.clear()
             for peer, peer_expiration_time in new_peers:
-                if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
+                if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
                     continue
                 self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
                 self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
@@ -495,7 +504,7 @@ class PotentialLeaders:
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
             return  # note: this is a compatibility layer for python3.7
         except Exception as e:
-            logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+            logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             raise
 
     async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
@@ -508,21 +517,21 @@ class PotentialLeaders:
                     self.declared_group_key = group_key = key_manager.current_key
                     self.declared_expiration_time = new_expiration_time
                     self.declared_expiration.set()
-                    await key_manager.declare_averager(group_key, self.endpoint, expiration_time=new_expiration_time)
+                    await key_manager.declare_averager(group_key, self.peer_id, expiration_time=new_expiration_time)
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                     if self.running.is_set() and len(self.leader_queue) == 0:
                         await key_manager.update_key_on_not_enough_peers()
            except (concurrent.futures.CancelledError, asyncio.CancelledError):
                pass  # note: this is a compatibility layer for python3.7
            except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
-                logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+                logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
            finally:
                if self.declared_group_key is not None:
                    prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
                    self.declared_group_key, self.declared_expiration_time = None, float("inf")
-                    self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float("-inf")
+                    self.leader_queue, self.max_assured_time = TimedStorage[PeerID, DHTExpiration](), float("-inf")
                    await key_manager.declare_averager(
-                        prev_declared_key, self.endpoint, prev_expiration_time, looking_for_group=False
+                        prev_declared_key, self.peer_id, prev_expiration_time, looking_for_group=False
                    )
 
 
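
Reviewer note: throughout this patch, peer identities cross the protobuf boundary as raw bytes (`suggested_leader=self.current_leader.to_bytes()`, `PeerID(message.suggested_leader)`, `ordered_peer_ids=[item.to_bytes() for item in group_info.peer_ids]`). Below is a minimal sketch of that round-trip, not part of the patch; it assumes only the `PeerID(raw_bytes)` constructor and the `.to_bytes()` accessor already used above, and the byte value in the example is hypothetical.

# Sketch: how PeerID values travel through bytes-typed protobuf fields in this patch.
from hivemind.p2p import PeerID

def encode_peer(peer_id: PeerID) -> bytes:
    # sender side: fields like MessageFromLeader.suggested_leader and
    # MessageFromLeader.ordered_peer_ids carry the raw PeerID bytes
    return peer_id.to_bytes()

def decode_peer(raw: bytes) -> PeerID:
    # receiver side: reconstruct a PeerID that compares equal to the original,
    # as follower_assemble_group and the GROUP_DISBANDED handler do
    return PeerID(raw)

original = PeerID(b"\x12\x20" + b"\x00" * 32)  # hypothetical raw peer id bytes
assert decode_peer(encode_peer(original)) == original  # the round-trip is lossless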