|
@@ -9,6 +9,8 @@ import random
|
|
from math import isfinite
|
|
from math import isfinite
|
|
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
|
|
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
|
|
|
|
|
|
|
|
+import numpy as np
|
|
|
|
+
|
|
from hivemind.averaging.control import StepControl
|
|
from hivemind.averaging.control import StepControl
|
|
from hivemind.averaging.group_info import GroupInfo
|
|
from hivemind.averaging.group_info import GroupInfo
|
|
from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
|
|
from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
|
|
@@ -44,16 +46,16 @@ class Matchmaking:
|
|
prefix: str,
|
|
prefix: str,
|
|
target_group_size: int,
|
|
target_group_size: int,
|
|
min_group_size: int,
|
|
min_group_size: int,
|
|
|
|
+ min_matchmaking_time: float,
|
|
request_timeout: float,
|
|
request_timeout: float,
|
|
client_mode: bool,
|
|
client_mode: bool,
|
|
initial_group_bits: str = "",
|
|
initial_group_bits: str = "",
|
|
- averaging_expiration: float = 15,
|
|
|
|
):
|
|
):
|
|
assert "." not in prefix, "group prefix must be a string without ."
|
|
assert "." not in prefix, "group prefix must be a string without ."
|
|
- if request_timeout is None or request_timeout >= averaging_expiration:
|
|
|
|
|
|
+ if request_timeout is None or request_timeout >= min_matchmaking_time:
|
|
logger.warning(
|
|
logger.warning(
|
|
- "It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
|
|
|
|
- "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
|
|
|
|
|
|
+ "It is recommended to use request_timeout smaller than min_matchmaking_time. Otherwise,"
|
|
|
|
+ " matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
|
|
)
|
|
)
|
|
|
|
|
|
super().__init__()
|
|
super().__init__()
|
|
@@ -68,7 +70,7 @@ class Matchmaking:
|
|
self.schema_hash = schema_hash
|
|
self.schema_hash = schema_hash
|
|
self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
|
|
self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
|
|
self.target_group_size, self.min_group_size = target_group_size, min_group_size
|
|
self.target_group_size, self.min_group_size = target_group_size, min_group_size
|
|
- self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
|
|
|
|
|
|
+ self.min_matchmaking_time, self.request_timeout = min_matchmaking_time, request_timeout
|
|
self.client_mode = client_mode
|
|
self.client_mode = client_mode
|
|
|
|
|
|
self.lock_looking_for_group = asyncio.Lock()
|
|
self.lock_looking_for_group = asyncio.Lock()
|
|
@@ -79,11 +81,11 @@ class Matchmaking:
|
|
|
|
|
|
self.current_leader: Optional[PeerID] = None # iff i am a follower, this is a link to my current leader
|
|
self.current_leader: Optional[PeerID] = None # iff i am a follower, this is a link to my current leader
|
|
self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {} # my current followers excluding myself
|
|
self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {} # my current followers excluding myself
|
|
- self.potential_leaders = PotentialLeaders(self.peer_id, averaging_expiration, target_group_size)
|
|
|
|
|
|
+ self.potential_leaders = PotentialLeaders(self.peer_id, min_matchmaking_time, target_group_size)
|
|
self.step: Optional[StepControl] = None
|
|
self.step: Optional[StepControl] = None
|
|
|
|
|
|
@contextlib.asynccontextmanager
|
|
@contextlib.asynccontextmanager
|
|
- def looking_for_group(self, step: StepControl):
|
|
|
|
|
|
+ async def looking_for_group(self, step: StepControl):
|
|
async with self.lock_looking_for_group:
|
|
async with self.lock_looking_for_group:
|
|
assert self.step is None
|
|
assert self.step is None
|
|
self.step = step
|
|
self.step = step
|
|
@@ -121,7 +123,7 @@ class Matchmaking:
|
|
async with self.looking_for_group(step):
|
|
async with self.looking_for_group(step):
|
|
request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(step))
|
|
request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(step))
|
|
try:
|
|
try:
|
|
- return await asyncio.wait_for(self.assembled_group, timeout=step.timeout)
|
|
|
|
|
|
+ return await asyncio.wait_for(self.assembled_group, timeout=step.get_timeout())
|
|
except asyncio.TimeoutError:
|
|
except asyncio.TimeoutError:
|
|
return None
|
|
return None
|
|
|
|
|
|
@@ -144,15 +146,15 @@ class Matchmaking:
|
|
self.assembled_group = asyncio.Future()
|
|
self.assembled_group = asyncio.Future()
|
|
self.was_accepted_to_group.clear()
|
|
self.was_accepted_to_group.clear()
|
|
|
|
|
|
- async def _request_join_potential_leaders(self, timeout: Optional[float]) -> GroupInfo:
|
|
|
|
|
|
+ async def _request_join_potential_leaders(self, step: StepControl) -> GroupInfo:
|
|
"""Request leaders from queue until we find the first runner. This coroutine is meant to run in background."""
|
|
"""Request leaders from queue until we find the first runner. This coroutine is meant to run in background."""
|
|
assert self.is_looking_for_group
|
|
assert self.is_looking_for_group
|
|
- async with self.potential_leaders.begin_search(self.group_key_manager, timeout, declare=not self.client_mode):
|
|
|
|
|
|
+ async with self.potential_leaders.begin_search(step, self.group_key_manager, declare=not self.client_mode):
|
|
while True:
|
|
while True:
|
|
try:
|
|
try:
|
|
next_leader = await self.potential_leaders.pop_next_leader() # throws TimeoutError on expiration
|
|
next_leader = await self.potential_leaders.pop_next_leader() # throws TimeoutError on expiration
|
|
|
|
|
|
- group = await self.request_join_group(next_leader, self.potential_leaders.request_expiration_time)
|
|
|
|
|
|
+ group = await self.request_join_group(next_leader)
|
|
if group is not None:
|
|
if group is not None:
|
|
return group
|
|
return group
|
|
|
|
|
|
@@ -173,26 +175,25 @@ class Matchmaking:
|
|
self.assembled_group.set_exception(e)
|
|
self.assembled_group.set_exception(e)
|
|
raise e
|
|
raise e
|
|
|
|
|
|
- async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
|
|
|
|
|
|
+ async def request_join_group(self, leader: PeerID) -> Optional[GroupInfo]:
|
|
"""
|
|
"""
|
|
:param leader: request this peer to be your leader for allreduce
|
|
:param leader: request this peer to be your leader for allreduce
|
|
- :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
|
|
|
|
:returns: if the leader accepted us and started AllReduce, return that AllReduce. Otherwise, return None
|
|
:returns: if the leader accepted us and started AllReduce, return that AllReduce. Otherwise, return None
|
|
:note: this function does not guarantee that your group leader is the same as :leader: parameter
|
|
:note: this function does not guarantee that your group leader is the same as :leader: parameter
|
|
The originally specified leader can disband group and redirect us to a different leader
|
|
The originally specified leader can disband group and redirect us to a different leader
|
|
"""
|
|
"""
|
|
assert self.is_looking_for_group and self.current_leader is None
|
|
assert self.is_looking_for_group and self.current_leader is None
|
|
- stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
|
|
|
|
|
|
+ stream: Optional[AsyncIterator[averaging_pb2.MessageFromLeader]] = None
|
|
try:
|
|
try:
|
|
async with self.lock_request_join_group:
|
|
async with self.lock_request_join_group:
|
|
leader_stub = self._servicer_type.get_stub(self._p2p, leader, namespace=self._prefix)
|
|
leader_stub = self._servicer_type.get_stub(self._p2p, leader, namespace=self._prefix)
|
|
-
|
|
|
|
|
|
+ request_expiration_time = self.get_request_expiration_time()
|
|
stream = await leader_stub.rpc_join_group(
|
|
stream = await leader_stub.rpc_join_group(
|
|
averaging_pb2.JoinRequest(
|
|
averaging_pb2.JoinRequest(
|
|
schema_hash=self.schema_hash,
|
|
schema_hash=self.schema_hash,
|
|
- expiration=expiration_time,
|
|
|
|
|
|
+ expiration=request_expiration_time,
|
|
client_mode=self.client_mode,
|
|
client_mode=self.client_mode,
|
|
- gather=self.control.gather_binary,
|
|
|
|
|
|
+ gather=self.step.gather_binary,
|
|
group_key=self.group_key_manager.current_key,
|
|
group_key=self.group_key_manager.current_key,
|
|
)
|
|
)
|
|
)
|
|
)
|
|
@@ -211,7 +212,7 @@ class Matchmaking:
|
|
return None
|
|
return None
|
|
|
|
|
|
async with self.potential_leaders.pause_search():
|
|
async with self.potential_leaders.pause_search():
|
|
- time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
|
|
|
|
|
|
+ time_to_expiration = max(0.0, request_expiration_time - get_dht_time())
|
|
message = await asyncio.wait_for(anext(stream), time_to_expiration + self.request_timeout)
|
|
message = await asyncio.wait_for(anext(stream), time_to_expiration + self.request_timeout)
|
|
|
|
|
|
if message.code == averaging_pb2.BEGIN_ALLREDUCE:
|
|
if message.code == averaging_pb2.BEGIN_ALLREDUCE:
|
|
@@ -225,7 +226,7 @@ class Matchmaking:
|
|
logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
|
|
logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
|
|
self.current_leader = None
|
|
self.current_leader = None
|
|
await stream.aclose()
|
|
await stream.aclose()
|
|
- return await self.request_join_group(suggested_leader, expiration_time)
|
|
|
|
|
|
+ return await self.request_join_group(suggested_leader)
|
|
logger.debug(f"{self} - leader disbanded group")
|
|
logger.debug(f"{self} - leader disbanded group")
|
|
return None
|
|
return None
|
|
|
|
|
|
@@ -244,6 +245,14 @@ class Matchmaking:
|
|
if stream is not None:
|
|
if stream is not None:
|
|
await stream.aclose()
|
|
await stream.aclose()
|
|
|
|
|
|
|
|
+ def get_request_expiration_time(self) -> float:
|
|
|
|
+ """this averager's current expiration time - used to send join requests to leaders"""
|
|
|
|
+ if isfinite(self.potential_leaders.declared_expiration_time):
|
|
|
|
+ return self.potential_leaders.declared_expiration_time
|
|
|
|
+ else:
|
|
|
|
+ scheduled_time = max(self.step.scheduled_time, get_dht_time() + self.min_matchmaking_time)
|
|
|
|
+ return min(scheduled_time, self.potential_leaders.search_end_time)
|
|
|
|
+
|
|
async def rpc_join_group(
|
|
async def rpc_join_group(
|
|
self, request: averaging_pb2.JoinRequest, context: P2PContext
|
|
self, request: averaging_pb2.JoinRequest, context: P2PContext
|
|
) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
|
|
) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
|
|
@@ -359,7 +368,7 @@ class Matchmaking:
|
|
random.shuffle(ordered_peer_ids)
|
|
random.shuffle(ordered_peer_ids)
|
|
|
|
|
|
gathered = tuple(
|
|
gathered = tuple(
|
|
- self.control.gather_binary if peer_id == self.peer_id else self.current_followers[peer_id].gather
|
|
|
|
|
|
+ self.step.gather_binary if peer_id == self.peer_id else self.current_followers[peer_id].gather
|
|
for peer_id in ordered_peer_ids
|
|
for peer_id in ordered_peer_ids
|
|
)
|
|
)
|
|
|
|
|
|
@@ -395,8 +404,8 @@ class Matchmaking:
|
|
class PotentialLeaders:
|
|
class PotentialLeaders:
|
|
"""A utility class that searches for averagers that could become our leaders"""
|
|
"""A utility class that searches for averagers that could become our leaders"""
|
|
|
|
|
|
- def __init__(self, peer_id: PeerID, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
|
|
|
|
- self.peer_id, self.averaging_expiration = peer_id, averaging_expiration
|
|
|
|
|
|
+ def __init__(self, peer_id: PeerID, min_matchmaking_time: DHTExpiration, target_group_size: Optional[int]):
|
|
|
|
+ self.peer_id, self.min_matchmaking_time = peer_id, min_matchmaking_time
|
|
self.target_group_size = target_group_size
|
|
self.target_group_size = target_group_size
|
|
self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
|
|
self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
|
|
self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
|
|
self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
|
|
@@ -411,7 +420,7 @@ class PotentialLeaders:
|
|
async def begin_search(self, step: StepControl, key_manager: GroupKeyManager, declare: bool = True):
|
|
async def begin_search(self, step: StepControl, key_manager: GroupKeyManager, declare: bool = True):
|
|
async with self.lock_search:
|
|
async with self.lock_search:
|
|
self.running.set()
|
|
self.running.set()
|
|
- self.search_end_time = get_dht_time() + step.timeout if step.timeout is not None else float("inf")
|
|
|
|
|
|
+ self.search_end_time = step.deadline if step.deadline is not None else float('inf')
|
|
update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
|
|
update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
|
|
if declare:
|
|
if declare:
|
|
declare_averager_task = asyncio.create_task(self._declare_averager_periodically(step, key_manager))
|
|
declare_averager_task = asyncio.create_task(self._declare_averager_periodically(step, key_manager))
|
|
@@ -474,20 +483,12 @@ class PotentialLeaders:
|
|
self.past_attempts.add((maybe_next_leader, entry.expiration_time))
|
|
self.past_attempts.add((maybe_next_leader, entry.expiration_time))
|
|
return maybe_next_leader
|
|
return maybe_next_leader
|
|
|
|
|
|
- @property
|
|
|
|
- def request_expiration_time(self) -> float:
|
|
|
|
- """this averager's current expiration time - used to send join requests to leaders"""
|
|
|
|
- if isfinite(self.declared_expiration_time):
|
|
|
|
- return self.declared_expiration_time
|
|
|
|
- else:
|
|
|
|
- return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
|
|
|
|
-
|
|
|
|
async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
|
|
async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
|
|
DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
|
|
DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
|
|
while get_dht_time() < self.search_end_time:
|
|
while get_dht_time() < self.search_end_time:
|
|
new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
|
|
new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
|
|
self.max_assured_time = max(
|
|
self.max_assured_time = max(
|
|
- self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
|
|
|
|
|
|
+ self.max_assured_time, get_dht_time() + self.min_matchmaking_time - DISCREPANCY
|
|
)
|
|
)
|
|
|
|
|
|
self.leader_queue.clear()
|
|
self.leader_queue.clear()
|
|
@@ -511,8 +512,9 @@ class PotentialLeaders:
|
|
try:
|
|
try:
|
|
while True:
|
|
while True:
|
|
await self.running.wait()
|
|
await self.running.wait()
|
|
- #TODO account for scheduled time here!
|
|
|
|
- new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
|
|
|
|
|
|
+ new_expiration_time = float(np.clip(step.scheduled_time,
|
|
|
|
+ a_min=get_dht_time() + self.min_matchmaking_time,
|
|
|
|
+ a_max=self.search_end_time))
|
|
self.declared_group_key = group_key = key_manager.current_key
|
|
self.declared_group_key = group_key = key_manager.current_key
|
|
self.declared_expiration_time = new_expiration_time
|
|
self.declared_expiration_time = new_expiration_time
|
|
self.declared_expiration.set()
|
|
self.declared_expiration.set()
|