@@ -6,7 +6,7 @@ import contextlib
 import random
 from dataclasses import asdict
 from math import isfinite
-from typing import Sequence, Optional, AsyncIterator, Set
+from typing import Sequence, Optional, AsyncIterator, Set, Tuple
 import asyncio

 import torch
@@ -27,29 +27,42 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
     f"""
     An internal class that is used to form groups of averages for running allreduce
     See DecentralizedAverager docstring for the detailed description of all parameters
+
+    :note: on implementation: the current matchmaker protocol can encounter one type of (temporary) deadlock;
+      This deadlock occurs when averager A requests averager B at the same time as averager B requests averager A.
+      In that case, neither averager can process the other one's request because it is awaiting lock_request_join_group.
+      This deadlock only happens if averagers have outdated information on expirations (due to network delays).
+      While A->B->A deadlock is easy to fix, it gets much harder with more peers (e.g. A -> B -> C -> D -> A).
+      Hence, instead of accounting for such deadlocks, we simply break them with request_timeout.
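+      (example: if A and B request each other simultaneously, both requests are dropped after at most request_timeout seconds and each peer moves on to its next candidate leader)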
+
     """

     def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *,
                  prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, compression_type: runtime_pb2.CompressionType = runtime_pb2.NONE):
+                 averaging_expiration: float = 15, request_timeout: float, **allreduce_kwargs):
         assert '.' not in prefix, "group prefix must be a string without ."
+        if request_timeout is None or request_timeout >= averaging_expiration:
+            logger.warning("It is recommended to use request_timeout smaller than averaging_expiration. Otherwise, "
+                           "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring.")

         super().__init__()
         self.dht, self.endpoint, self.averaged_tensors = dht, endpoint, tuple(averaged_tensors)
         self.prefix, self.group_bits = prefix, initial_group_bits
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
-        self.averaging_expiration, self.compression_type = averaging_expiration, compression_type
-
+        self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.allreduce_kwargs = allreduce_kwargs
         self.schema_hash = compute_schema_hash(self.averaged_tensors)

         self.lock_looking_for_group = asyncio.Lock()
         self.lock_request_join_group = asyncio.Lock()
-        self.cond_notify_followers = asyncio.Condition()
+        self.follower_was_discarded = asyncio.Event()
+        self.was_accepted_to_group = asyncio.Event()
         self.assembled_group = asyncio.Future()

         self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Set[Endpoint] = set()  # iff i am a leader, this contains my followers excluding myself
-        self.potential_leaders = PotentialLeaders(self.endpoint, self.dht, self.averaging_expiration)
+        self.potential_leaders = PotentialLeaders(endpoint, dht, averaging_expiration, target_group_size)

     @property
     def is_looking_for_group(self):
@@ -70,7 +83,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
                f" current key = {self.current_group_key})"

-    async def look_for_group(self, *, timeout: Optional[float] = None) -> AllReduceRunner:
+    async def look_for_group(self, *, timeout: Optional[float] = None) -> Optional[AllReduceRunner]:
         """
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
@@ -82,48 +95,59 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
         try:
             return await asyncio.wait_for(self.assembled_group, timeout=timeout)
-        except Exception as e:
+        except asyncio.TimeoutError:
+            return None
+
+        except BaseException as e:
             if len(self.current_followers) > 0:
                 async with self.lock_request_join_group:
                     await self.leader_disband_group()
-            self.assembled_group.set_exception(e)
+            if not self.assembled_group.done():
+                self.assembled_group.set_exception(e)
             raise

         finally:
             if not request_leaders_task.done():
                 request_leaders_task.cancel()
-            if self.assembled_group.done():
-                self.assembled_group = asyncio.Future()
+            if not self.assembled_group.done():
+                self.assembled_group.cancel()
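+                # cancelling assembled_group wakes up concurrent rpc_join_group calls so that they dismiss their followers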
+            while len(self.current_followers) > 0:
+                await self.follower_was_discarded.wait()
+                self.follower_was_discarded.clear()
+            # note: the code above ensures that we send all followers away before creating new future
+            self.assembled_group = asyncio.Future()
+            self.was_accepted_to_group.clear()

     async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
         """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
-        end_time = get_dht_time() + timeout if timeout is not None else float('inf')
         async with self.potential_leaders.begin_search(self.current_group_key, timeout):
             # TODO update group_bits on success! reduce number of bits on not enough peers.
             # TODO after allreduce finishes, we may need to ask leader to notify lower keys about this
             # (so as to fix possible network partitioning if some peers operate on a much smaller nbits)
             while True:
                 try:
-                    time_to_expiration = self.potential_leaders.declared_expiration_time - get_dht_time()
-                    next_best_leader = await asyncio.wait_for(
-                        self.potential_leaders.pop_next_leader(),
-                        timeout=time_to_expiration if isfinite(time_to_expiration) else None)
-
-                    request_expiration_time = min(self.potential_leaders.declared_expiration_time,
-                                                  end_time, get_dht_time() + self.averaging_expiration)
-                    group = await self.request_join_group(next_best_leader, request_expiration_time)
+                    next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
+
+                    group = await self.request_join_group(next_leader, self.potential_leaders.request_expiration_time)
                     if group is not None:
                         return group

                 except asyncio.TimeoutError:
                     async with self.lock_request_join_group:
-                        if len(self.current_followers) >= self.min_group_size:
+                        if self.assembled_group.done():
+                            return self.assembled_group.result()
+                        elif len(self.current_followers) + 1 >= self.min_group_size:
                             # the time is up, we have a *good enough* group. run allreduce as is.
                             return await self.leader_assemble_group()
-                        else:
+                        elif len(self.current_followers) > 0:
                             await self.leader_disband_group()
                         # TODO maybe adjust grid size
-                            continue
+                        continue
+                except Exception as e:
+                    if not self.assembled_group.done():
+                        self.assembled_group.set_exception(e)
+                    raise e

     async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[AllReduceRunner]:
         """
@@ -134,87 +158,105 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         The originally specified leader can disband group and redirect us to a different leader
         """
         assert self.is_looking_for_group and self.current_leader is None
-        call: Optional[grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]] = None
+        call: Optional[grpc.aio.UnaryStreamCall] = None
         try:
             async with self.lock_request_join_group:
                 leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
                 call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
                     endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time))
+                message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
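+                # ^-- the leader must respond within request_timeout; this is what breaks the mutual-request deadlock described in the class docstring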

-            message = await call.read()
-            if message.code != averaging_pb2.ACCEPTED:
-                code = averaging_pb2.MessageCode.Name(message.code)
-                logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
-                return None
+                if message.code == averaging_pb2.ACCEPTED:
+                    logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
+                    self.current_leader = leader
+                    self.was_accepted_to_group.set()
+                    if len(self.current_followers) > 0:
+                        await self.leader_disband_group()

-            # else: we were accepted
-            logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
-            self.current_leader = leader
-            if len(self.current_followers) > 0:
-                await self.leader_disband_group()
+            if message.code != averaging_pb2.ACCEPTED:
+                code = averaging_pb2.MessageCode.Name(message.code)
+                logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
+                return None

             async with self.potential_leaders.pause_search():
-                message = await call.read()
+                time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
+                message = await asyncio.wait_for(call.read(), time_to_expiration + self.request_timeout)
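+                # ^-- the leader must reach a decision by expiration_time; request_timeout is added as slack for network delays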

-                if message.code == averaging_pb2.BEGIN_ALLREDUCE:
-                    async with self.lock_request_join_group:
-                        return await self.follower_assemble_group(leader, message.group_id, message.ordered_group_endpoints)
-                elif message.code == averaging_pb2.GROUP_DISBANDED and bool(message.suggested_leader):
-                    logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
-                    return await self.request_join_group(message.suggested_leader, expiration_time)
+                if message.code == averaging_pb2.BEGIN_ALLREDUCE:
+                    async with self.lock_request_join_group:
+                        return await self.follower_assemble_group(
+                            leader, message.group_id, message.ordered_group_endpoints)
+
+                if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
+                    if message.suggested_leader and message.suggested_leader != self.endpoint:
+                        logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
+                        self.current_leader = None
+                        call.cancel()
+                        return await self.request_join_group(message.suggested_leader, expiration_time)
+                    else:
+                        logger.debug(f"{self} - leader disbanded group")
+                        return None

-                else:
-                    logger.debug(f"{self} - leader sent {averaging_pb2.MessageCode.Name(message.code)}, leaving group")
-                    return None
+                logger.debug(f"{self} - unexpected message from leader: {averaging_pb2.MessageCode.Name(message.code)}")
+                return None
+        except asyncio.TimeoutError:
+            logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
+            if call is not None:
+                call.cancel()
+            return None
         finally:
+            self.was_accepted_to_group.clear()
             self.current_leader = None
             if call is not None:
-                call.cancel()
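+                # note: awaiting call.code() lets the call finish gracefully (it was either completed or cancelled above)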
+                await call.code()

     async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
                              ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """ accept or reject a join request from another averager; if accepted, run him through allreduce steps """
         try:
-            reason_to_reject = self._check_reasons_to_reject(request)
-            if reason_to_reject is not None:
-                yield reason_to_reject
-                return
-
-            current_group = self.assembled_group  # copy current assembled_group to avoid overwriting
             async with self.lock_request_join_group:
+                reason_to_reject = self._check_reasons_to_reject(request)
+                if reason_to_reject is not None:
+                    yield reason_to_reject
+                    return
+
                 self.current_followers.add(request.endpoint)
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)

-                if len(self.current_followers) + 1 >= self.target_group_size:
+                if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
                     # outcome 1: we have assembled a full group and are ready for allreduce
                     await self.leader_assemble_group()

-            if not current_group.done():
-                try:
-                    async with self.cond_notify_followers:
-                        # wait for the group to be assembled or disbanded
-                        timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
-                        await asyncio.wait_for(self.cond_notify_followers.wait(), timeout=timeout)
-                except asyncio.TimeoutError:
-                    async with self.lock_request_join_group:
+            # wait for the group to be assembled or disbanded
+            timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
+            await asyncio.wait({self.assembled_group, self.was_accepted_to_group.wait()},
+                               return_when=asyncio.FIRST_COMPLETED, timeout=timeout)
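+            # ^-- resumes when this averager assembles a group, is accepted by another leader, or the declared expiration time runs out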
+            if not self.assembled_group.done() and not self.was_accepted_to_group.is_set():
+                async with self.lock_request_join_group:
+                    if self.assembled_group.done():
+                        pass  # this covers a rare case when the group is assembled while the event loop was busy.
+                    elif len(self.current_followers) + 1 >= self.min_group_size and self.is_looking_for_group:
                         # outcome 2: the time is up, run allreduce with what we have or disband
-                        if len(self.current_followers) + 1 >= self.min_group_size and self.is_looking_for_group:
-                            await self.leader_assemble_group()
-                        else:
-                            await self.leader_disband_group()
-
-            if self.current_leader is not None:
-                # outcome 3: found by a leader with higher priority, send our followers to him
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED,
-                                                      suggested_leader=self.current_leader)
-                return
+                        await self.leader_assemble_group()
+                    else:
+                        await self.leader_disband_group()

-            if request.endpoint not in self.current_followers:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
-                return
+            if self.was_accepted_to_group.is_set() or not self.assembled_group.done() \
+                    or self.assembled_group.cancelled() or request.endpoint not in self.assembled_group.result():
+                if self.current_leader is not None:
+                    # outcome 3: found by a leader with higher priority, send our followers to him
+                    yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED,
+                                                          suggested_leader=self.current_leader)
+                    return
+                else:
+                    yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
+                    return

-            # finally, run allreduce
-            allreduce_group = current_group.result()
+            allreduce_group = self.assembled_group.result()
             yield averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
                 ordered_group_endpoints=allreduce_group.ordered_group_endpoints)
@@ -225,10 +267,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):

         finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
             self.current_followers.discard(request.endpoint)
+            self.follower_was_discarded.set()

-    def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> averaging_pb2.MessageFromLeader:
+    def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> Optional[averaging_pb2.MessageFromLeader]:
         """ :returns: if accepted, return None, otherwise return a reason for rejection """
-        if not self.is_looking_for_group:
+        if not self.is_looking_for_group or self.assembled_group.done():
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)

         if request.ListFields() == 3 and not isinstance(request.schema_hash, bytes) or len(request.schema_hash) == 0 \
@@ -243,8 +286,8 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         elif self.potential_leaders.declared_expiration_time > (request.expiration or float('inf')):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
         elif self.current_leader is not None:
-            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_A_LEADER,
-                                                   suggested_leader=self.current_leader)
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader
+                                                   )  # note: this suggested leader is currently ignored
         elif request.endpoint == self.endpoint or request.endpoint in self.current_followers:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
         elif len(self.current_followers) + 1 >= self.target_group_size:
@@ -255,68 +298,71 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
     async def leader_assemble_group(self) -> AllReduceRunner:
         """ Form up all current followers into a group and prepare to _run_allreduce """
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
+        assert not self.assembled_group.done()
         group_id = DHTID.generate().to_bytes()
         ordered_group_endpoints = list(self.current_followers)
         ordered_group_endpoints.append(self.endpoint)
         random.shuffle(ordered_group_endpoints)
-        logger.debug(f"{self.endpoint} - leader started allreduce with {len(ordered_group_endpoints)} followers.")
-        allreduce_group = AllReduceRunner(
-            group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-            ordered_group_endpoints=ordered_group_endpoints, compression_type=self.compression_type)
+        logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.")
+        allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
+                                          ordered_group_endpoints=ordered_group_endpoints, **self.allreduce_kwargs)
         self.assembled_group.set_result(allreduce_group)
-        async with self.cond_notify_followers:
-            self.cond_notify_followers.notify_all()
         return allreduce_group

     async def follower_assemble_group(self, leader: Endpoint, group_id: GroupID,
                                       ordered_group_endpoints: Sequence[Endpoint]) -> AllReduceRunner:
         """ Prepare to run allreduce using a list of peers provided by our leader """
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
+        assert not self.assembled_group.done()
         logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
         assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
-        allreduce_group = AllReduceRunner(
-            group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-            ordered_group_endpoints=ordered_group_endpoints, compression_type=self.compression_type)
+        allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
+                                          ordered_group_endpoints=ordered_group_endpoints, **self.allreduce_kwargs)
         self.assembled_group.set_result(allreduce_group)
-        async with self.cond_notify_followers:
-            self.cond_notify_followers.notify_all()
         return allreduce_group

     async def leader_disband_group(self):
         """ Kick out all followers immediately, optionally direct them to our new leader (if we found one) """
         assert self.lock_request_join_group.locked()
         self.current_followers.clear()  # this will cause rpc_join_group to kick all followers out
-        async with self.cond_notify_followers:
-            self.cond_notify_followers.notify_all()


 class PotentialLeaders:
     """ An utility class that searches for averagers that could become our leaders """
-    def __init__(self, endpoint: Endpoint, dht: hivemind.DHT, averaging_expiration: DHTExpiration):
+
+    def __init__(self, endpoint: Endpoint, dht: hivemind.DHT, averaging_expiration: DHTExpiration,
+                 target_group_size: Optional[int]):
         self.endpoint, self.dht, self.averaging_expiration = endpoint, dht, averaging_expiration
+        self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
+        self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
         self.leader_queue = TimedStorage[Endpoint, DHTExpiration]()
-        self.max_assured_time = float('-inf')
+        self.past_attempts: Set[Tuple[Endpoint, DHTExpiration]] = set()
         self.declared_expiration_time = float('inf')
         self.declared_group_key: Optional[GroupKey] = None
+        self.max_assured_time = float('-inf')
         self.search_end_time = float('inf')

     @contextlib.asynccontextmanager
     async def begin_search(self, group_key: GroupKey, timeout: Optional[float]):
-        assert not self.running.is_set(), "already running"
-        self.running.set()
-        self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
-        update_queue_task = asyncio.create_task(self._update_queue_periodically(group_key))
-        declare_averager_task = asyncio.create_task(self._declare_averager_periodically(group_key))
-        try:
-            yield self
-        finally:
-            update_queue_task.cancel()
-            declare_averager_task.cancel()
-            self.running.clear()
-            self.update_triggered.clear()
-            self.update_finished.clear()
+        async with self.lock_search:
+            self.running.set()
+            self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
+            update_queue_task = asyncio.create_task(self._update_queue_periodically(group_key))
+            declare_averager_task = asyncio.create_task(self._declare_averager_periodically(group_key))
+            try:
+                yield self
+            finally:
+                if not update_queue_task.done():
+                    update_queue_task.cancel()
+                if not declare_averager_task.done():
+                    declare_averager_task.cancel()
+                for field in (self.past_attempts, self.leader_queue, self.running,
                              self.update_finished, self.update_triggered, self.declared_expiration):
+                    field.clear()
+                self.max_assured_time = float('-inf')
+                self.search_end_time = float('inf')

     @contextlib.asynccontextmanager
     async def pause_search(self):
@@ -332,19 +378,35 @@ class PotentialLeaders:

     async def pop_next_leader(self) -> Endpoint:
         """ Remove and return the next most suitable leader or throw an exception if reached timeout """
-        assert self.running, "Not running search at the moment"
-        maybe_next_leader, entry = self.leader_queue.top()
-
-        next_entry_time = entry.expiration_time if maybe_next_leader is not None else get_dht_time()
-        if self.max_assured_time < next_entry_time < self.search_end_time:
-            self.update_triggered.set()
+        assert self.running.is_set(), "Not running search at the moment"
+        while True:
+            maybe_next_leader, entry = self.leader_queue.top()
+
+            if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
+                self.update_triggered.set()
+
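+            # no suitable leader in the queue: wait until the queue is refreshed or we re-declare ourselves with a new expiration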
+            if maybe_next_leader is None or entry.expiration_time >= self.declared_expiration_time:
+                await asyncio.wait({self.update_finished.wait(), self.declared_expiration.wait()},
+                                   return_when=asyncio.FIRST_COMPLETED)
+                self.declared_expiration.clear()
+                if self.update_finished.is_set():
+                    self.update_finished.clear()
+                    continue
+                else:
+                    raise asyncio.TimeoutError("pop_next_leader was invalidated: re-declared averager in background")

-        if maybe_next_leader is None:
-            await self.update_finished.wait()
-            return await self.pop_next_leader()
+            del self.leader_queue[maybe_next_leader]
+            self.past_attempts.add((maybe_next_leader, entry.expiration_time))
+            return maybe_next_leader

-        del self.leader_queue[maybe_next_leader]
-        return maybe_next_leader
+    @property
+    def request_expiration_time(self) -> float:
+        """ this averager's current expiration time - used to send join requests to leaders """
+        if isfinite(self.declared_expiration_time):
+            return self.declared_expiration_time
+        else:
+            return min(get_dht_time() + self.averaging_expiration, self.search_end_time)

     async def _update_queue_periodically(self, group_key: GroupKey):
         DISCREPANCY = hivemind.utils.timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
@@ -352,14 +414,15 @@ class PotentialLeaders:
             new_peers = await self.dht.get_averagers(group_key, only_active=True, return_future=True)
             self.max_assured_time = max(self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY)

+            self.leader_queue.clear()
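+            # rebuild the queue from scratch, skipping ourselves and any leaders we have already attempted this search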
             for peer, peer_expiration_time in new_peers:
-                if peer == self.endpoint:
+                if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
                     continue
                 self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
                 self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)

-            if len(self.leader_queue) > 0:
-                self.update_finished.set()
+            self.update_finished.set()

             await asyncio.wait(
                 {self.running.wait(), self.update_triggered.wait()}, return_when=asyncio.ALL_COMPLETED,
@@ -367,28 +430,32 @@ class PotentialLeaders:
             self.update_triggered.clear()

     async def _declare_averager_periodically(self, group_key: GroupKey):
-        try:
-            while True:
-                new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
-                self.declared_group_key, self.declared_expiration_time = group_key, new_expiration_time
-                stored_ok = await self.dht.declare_averager(group_key, self.endpoint, new_expiration_time,
-                                                            looking_for_group=True, return_future=True)
-                if stored_ok:
+        async with self.lock_declare:
+            try:
+                while True:
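+                    # hold off while the search is paused (e.g. while this averager awaits a response from a leader)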
+                    await self.running.wait()
+
+                    new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
+                    self.declared_group_key, self.declared_expiration_time = group_key, new_expiration_time
+                    self.declared_expiration.set()
+                    await self.dht.declare_averager(group_key, self.endpoint, new_expiration_time,
+                                                    looking_for_group=True, return_future=True)
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
-                else:
-                    logger.warning(f"Failed to subscribe to group {group_key} : store rejected by DHT peers")
-        finally:
-            if self.declared_group_key is not None:
-                previous_declared_key, previous_expiration_time = self.declared_group_key, self.declared_expiration_time
-                self.declared_group_key, self.declared_expiration_time = None, float('inf')
-                self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float('-inf')
-                await self.dht.declare_averager(previous_declared_key, self.endpoint, previous_expiration_time,
-                                                looking_for_group=False, return_future=True)
+            except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
+                logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+            finally:
+                if self.declared_group_key is not None:
+                    prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
+                    self.declared_group_key, self.declared_expiration_time = None, float('inf')
+                    self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float('-inf')
+                    await self.dht.declare_averager(prev_declared_key, self.endpoint, prev_expiration_time,
+                                                    looking_for_group=False, return_future=True)


 def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
     """ A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """
     schema_dicts = [{field_name: str(field_value)
-                      for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
+                     for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
                     for tensor in tensors]
     return DHTID.generate(source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()