|
@@ -3,26 +3,25 @@
|
|
from __future__ import annotations
|
|
from __future__ import annotations
|
|
|
|
|
|
import contextlib
|
|
import contextlib
|
|
|
|
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
|
|
import random
|
|
import random
|
|
from math import isfinite
|
|
from math import isfinite
|
|
from typing import Optional, AsyncIterator, Set, Tuple, Dict
|
|
from typing import Optional, AsyncIterator, Set, Tuple, Dict
|
|
import concurrent.futures
|
|
import concurrent.futures
|
|
import asyncio
|
|
import asyncio
|
|
|
|
|
|
-import grpc
|
|
|
|
-import grpc._cython.cygrpc
|
|
|
|
-
|
|
|
|
from hivemind.averaging.group_info import GroupInfo
|
|
from hivemind.averaging.group_info import GroupInfo
|
|
from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
|
|
from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
|
|
from hivemind.dht import DHT, DHTID, DHTExpiration
|
|
from hivemind.dht import DHT, DHTID, DHTExpiration
|
|
-from hivemind.utils import get_logger, Endpoint, timed_storage, TimedStorage, get_dht_time
|
|
|
|
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc
|
|
|
|
-from hivemind.utils.grpc import ChannelCache
|
|
|
|
|
|
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID as Endpoint
|
|
|
|
+from hivemind.utils import get_logger, timed_storage, TimedStorage, get_dht_time
|
|
|
|
+from hivemind.utils.asyncio import anext
|
|
|
|
+from hivemind.proto import averaging_pb2
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
-class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
|
|
|
|
|
|
+class Matchmaking:
|
|
f"""
|
|
f"""
|
|
An internal class that is used to form groups of averages for running allreduce
|
|
An internal class that is used to form groups of averages for running allreduce
|
|
See DecentralizedAverager docstring for the detailed description of all parameters
|
|
See DecentralizedAverager docstring for the detailed description of all parameters
|
|
@@ -37,7 +36,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
|
|
|
|
|
|
def __init__(
|
|
def __init__(
|
|
self,
|
|
self,
|
|
- endpoint: Endpoint,
|
|
|
|
|
|
+ p2p: P2P,
|
|
schema_hash: bytes,
|
|
schema_hash: bytes,
|
|
dht: DHT,
|
|
dht: DHT,
|
|
*,
|
|
*,
|
|
@@ -57,8 +56,10 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
|
|
)
|
|
)
|
|
|
|
|
|
super().__init__()
|
|
super().__init__()
|
|
- self.endpoint, self.schema_hash = endpoint, schema_hash
|
|
|
|
- self.group_key_manager = GroupKeyManager(dht, endpoint, prefix, initial_group_bits, target_group_size)
|
|
|
|
|
|
+ self._p2p = p2p
|
|
|
|
+ self.endpoint = p2p.id
|
|
|
|
+ self.schema_hash = schema_hash
|
|
|
|
+ self.group_key_manager = GroupKeyManager(dht, self.endpoint, prefix, initial_group_bits, target_group_size)
|
|
self.target_group_size, self.min_group_size = target_group_size, min_group_size
|
|
self.target_group_size, self.min_group_size = target_group_size, min_group_size
|
|
self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
|
|
self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
|
|
self.client_mode = client_mode
|
|
self.client_mode = client_mode
|
|
@@ -71,7 +72,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
|
|
|
|
|
|
self.current_leader: Optional[Endpoint] = None # iff i am a follower, this is a link to my current leader
|
|
self.current_leader: Optional[Endpoint] = None # iff i am a follower, this is a link to my current leader
|
|
self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {} # my current followers excluding myself
|
|
self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {} # my current followers excluding myself
|
|
- self.potential_leaders = PotentialLeaders(endpoint, averaging_expiration, target_group_size)
|
|
|
|
|
|
+ self.potential_leaders = PotentialLeaders(self.endpoint, averaging_expiration, target_group_size)
|
|
self.data_for_gather: Optional[bytes] = None
|
|
self.data_for_gather: Optional[bytes] = None
|
|
|
|
|
|
@property
|
|
@property
|
|
@@ -169,20 +170,23 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
|
|
The originally specified leader can disband group and redirect us to a different leader
|
|
The originally specified leader can disband group and redirect us to a different leader
|
|
"""
|
|
"""
|
|
assert self.is_looking_for_group and self.current_leader is None
|
|
assert self.is_looking_for_group and self.current_leader is None
|
|
- call: Optional[grpc.aio.UnaryStreamCall] = None
|
|
|
|
|
|
+ stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
|
|
try:
|
|
try:
|
|
async with self.lock_request_join_group:
|
|
async with self.lock_request_join_group:
|
|
- leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
|
|
|
|
- call = leader_stub.rpc_join_group(
|
|
|
|
|
|
+ from hivemind.averaging.averager import DecentralizedAverager
|
|
|
|
+
|
|
|
|
+ leader_stub = DecentralizedAverager.get_stub(self._p2p, leader)
|
|
|
|
+
|
|
|
|
+ stream = leader_stub.rpc_join_group(
|
|
averaging_pb2.JoinRequest(
|
|
averaging_pb2.JoinRequest(
|
|
- endpoint=self.endpoint,
|
|
|
|
|
|
+ endpoint=self.endpoint.to_base58(),
|
|
schema_hash=self.schema_hash,
|
|
schema_hash=self.schema_hash,
|
|
expiration=expiration_time,
|
|
expiration=expiration_time,
|
|
client_mode=self.client_mode,
|
|
client_mode=self.client_mode,
|
|
gather=self.data_for_gather,
|
|
gather=self.data_for_gather,
|
|
)
|
|
)
|
|
- )
|
|
|
|
- message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
|
|
|
|
|
|
+ ).__aiter__()
|
|
|
|
+ message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
|
|
|
|
|
|
if message.code == averaging_pb2.ACCEPTED:
|
|
if message.code == averaging_pb2.ACCEPTED:
|
|
logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
|
|
logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
|
|
@@ -198,7 +202,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
|
|
|
|
|
|
async with self.potential_leaders.pause_search():
|
|
async with self.potential_leaders.pause_search():
|
|
time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
|
|
time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
|
|
- message = await asyncio.wait_for(call.read(), time_to_expiration + self.request_timeout)
|
|
|
|
|
|
+ message = await asyncio.wait_for(anext(stream), time_to_expiration + self.request_timeout)
|
|
|
|
|
|
if message.code == averaging_pb2.BEGIN_ALLREDUCE:
|
|
if message.code == averaging_pb2.BEGIN_ALLREDUCE:
|
|
async with self.lock_request_join_group:
|
|
async with self.lock_request_join_group:
|
|
@@ -208,7 +212,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
|
|
if message.suggested_leader and message.suggested_leader != self.endpoint:
|
|
if message.suggested_leader and message.suggested_leader != self.endpoint:
|
|
logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
|
|
logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
|
|
self.current_leader = None
|
|
self.current_leader = None
|
|
- call.cancel()
|
|
|
|
|
|
+ await stream.aclose()
|
|
return await self.request_join_group(message.suggested_leader, expiration_time)
|
|
return await self.request_join_group(message.suggested_leader, expiration_time)
|
|
else:
|
|
else:
|
|
logger.debug(f"{self} - leader disbanded group")
|
|
logger.debug(f"{self} - leader disbanded group")
|
|
@@ -218,23 +222,22 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
|
|
return None
|
|
return None
|
|
except asyncio.TimeoutError:
|
|
except asyncio.TimeoutError:
|
|
logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
|
|
logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
|
|
- if call is not None:
|
|
|
|
- call.cancel()
|
|
|
|
return None
|
|
return None
|
|
- except (grpc.RpcError, grpc.aio.AioRpcError, grpc._cython.cygrpc.InternalError, StopAsyncIteration) as e:
|
|
|
|
|
|
+ except (P2PHandlerError, StopAsyncIteration) as e:
|
|
logger.error(f"{self} - failed to request potential leader {leader}: {e}")
|
|
logger.error(f"{self} - failed to request potential leader {leader}: {e}")
|
|
return None
|
|
return None
|
|
|
|
|
|
finally:
|
|
finally:
|
|
self.was_accepted_to_group.clear()
|
|
self.was_accepted_to_group.clear()
|
|
self.current_leader = None
|
|
self.current_leader = None
|
|
- if call is not None:
|
|
|
|
- await call.code()
|
|
|
|
|
|
+ if stream is not None:
|
|
|
|
+ await stream.aclose()
|
|
|
|
|
|
async def rpc_join_group(
|
|
async def rpc_join_group(
|
|
- self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
|
|
|
|
|
|
+ self, request: averaging_pb2.JoinRequest, _: P2PContext
|
|
) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
|
|
) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
|
|
"""accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
|
|
"""accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
|
|
|
|
+ request_endpoint = PeerID.from_base58(request.endpoint)
|
|
try:
|
|
try:
|
|
async with self.lock_request_join_group:
|
|
async with self.lock_request_join_group:
|
|
reason_to_reject = self._check_reasons_to_reject(request)
|
|
reason_to_reject = self._check_reasons_to_reject(request)
|
|
@@ -242,7 +245,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
|
|
yield reason_to_reject
|
|
yield reason_to_reject
|
|
return
|
|
return
|
|
|
|
|
|
- self.current_followers[request.endpoint] = request
|
|
|
|
|
|
+ self.current_followers[request_endpoint] = request
|
|
yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
|
|
yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
|
|
|
|
|
|
if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
|
|
if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
|
|
@@ -270,7 +273,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
|
|
self.was_accepted_to_group.is_set()
|
|
self.was_accepted_to_group.is_set()
|
|
or not self.assembled_group.done()
|
|
or not self.assembled_group.done()
|
|
or self.assembled_group.cancelled()
|
|
or self.assembled_group.cancelled()
|
|
- or request.endpoint not in self.assembled_group.result()
|
|
|
|
|
|
+ or request_endpoint not in self.assembled_group.result()
|
|
):
|
|
):
|
|
if self.current_leader is not None:
|
|
if self.current_leader is not None:
|
|
# outcome 3: found by a leader with higher priority, send our followers to him
|
|
# outcome 3: found by a leader with higher priority, send our followers to him
|
|
@@ -296,7 +299,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
|
|
yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
|
|
yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
|
|
|
|
|
|
finally: # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
|
|
finally: # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
|
|
- self.current_followers.pop(request.endpoint, None)
|
|
|
|
|
|
+ self.current_followers.pop(request_endpoint, None)
|
|
self.follower_was_discarded.set()
|
|
self.follower_was_discarded.set()
|
|
|
|
|
|
def _check_reasons_to_reject(
|
|
def _check_reasons_to_reject(
|