@@ -12,14 +12,15 @@ import asyncio
 
 import grpc
 import torch
 
-import hivemind
 from hivemind.client.averaging.allreduce import AllReduceRunner
 from hivemind.client.averaging.load_balancing import load_balance_peers
-from hivemind.dht import DHTID, DHTExpiration, get_dht_time, GroupKey
-from hivemind.utils import get_logger, Endpoint, TensorDescriptor, MSGPackSerializer, TimedStorage
+from hivemind.client.averaging.key_manager import GroupKeyManager, GroupKey
+from hivemind.dht import DHT, DHTID, DHTExpiration, get_dht_time
+from hivemind.utils import get_logger, Endpoint, TensorDescriptor, MSGPackSerializer, timed_storage, TimedStorage
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc
 from hivemind.utils.grpc import ChannelCache
 
+
 logger = get_logger(__name__)
 
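Reviewer note: the new `GroupKeyManager` import takes over the `prefix`/`group_bits` bookkeeping that previously lived inside `Matchmaking` itself. To make the later hunks easier to read in isolation, here is a rough interface sketch inferred purely from the call sites in this diff; the type aliases, default arguments and method bodies below are assumptions, not the contents of `hivemind/client/averaging/key_manager.py`.

```python
# Hypothetical interface stub inferred from how this diff uses GroupKeyManager.
# The real class lives in hivemind/client/averaging/key_manager.py and differs.
from typing import List, Optional, Tuple

Endpoint = str         # simplified stand-ins for the hivemind type aliases
GroupKey = str
DHTExpiration = float


class GroupKeyManagerSketch:
    def __init__(self, dht, endpoint: Endpoint, prefix: str,
                 initial_group_bits: Optional[str], target_group_size: int):
        self.dht, self.endpoint, self.prefix = dht, endpoint, prefix
        self.group_bits = initial_group_bits or ''
        self.target_group_size = target_group_size

    @property
    def current_key(self) -> GroupKey:
        # same format the removed Matchmaking.current_group_key property produced
        return f"{self.prefix}.0b{self.group_bits}"

    async def declare_averager(self, group_key: GroupKey, endpoint: Endpoint,
                               expiration_time: DHTExpiration,
                               looking_for_group: bool = True) -> bool:
        # assumed to wrap dht.declare_averager(..., return_future=True), as in the removed call sites
        raise NotImplementedError

    async def get_averagers(self, group_key: GroupKey,
                            only_active: bool) -> List[Tuple[Endpoint, DHTExpiration]]:
        # assumed to wrap dht.get_averagers(..., return_future=True)
        raise NotImplementedError

    async def update_key_on_group_assembled(self, allreduce_group, is_leader: bool = False):
        # re-derives self.group_bits from the shared group_key_seed (see later hunks)
        raise NotImplementedError

    async def update_key_on_not_enough_peers(self):
        # called when a search round expires with an empty leader queue
        raise NotImplementedError
```

Only the constructor argument order and the method names/keywords are taken from this diff; everything else is placeholder.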
@@ -36,7 +37,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
     Hence, instead of accounting for such deadlocks, we simply break them with request_timeout.
     """
 
-    def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *,
+    def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *,
                  prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
                  averaging_expiration: float = 15, request_timeout: float, throughput: Optional[float] = None,
                  min_vector_size: int, **allreduce_kwargs):
@@ -46,12 +47,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                            "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring.")
 
         super().__init__()
-        self.dht, self.endpoint, self.averaged_tensors = dht, endpoint, tuple(averaged_tensors)
-        self.prefix, self.group_bits = prefix, initial_group_bits
+        self.endpoint, self.averaged_tensors = endpoint, tuple(averaged_tensors)
+        self.group_key_manager = GroupKeyManager(dht, endpoint, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
-        self.throughput, self.min_vector_size = throughput, min_vector_size
-        self.allreduce_kwargs = allreduce_kwargs
+        self.throughput, self.min_vector_size, self.allreduce_kwargs = throughput, min_vector_size, allreduce_kwargs
         self.schema_hash = compute_schema_hash(self.averaged_tensors)
         self.total_size = sum(tensor.numel() for tensor in self.averaged_tensors)
 
@@ -63,17 +63,13 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
         self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(endpoint, dht, averaging_expiration, target_group_size)
-        self.data_for_gather: bytes = None
+        self.potential_leaders = PotentialLeaders(endpoint, averaging_expiration, target_group_size)
+        self.data_for_gather: Optional[bytes] = None
 
     @property
     def is_looking_for_group(self):
         return self.lock_looking_for_group.locked()
 
-    @property
-    def current_group_key(self) -> GroupKey:
-        return f"{self.prefix}.0b{self.group_bits}"
-
     def __repr__(self):
         lfg_status = "looking for group," if self.is_looking_for_group else "not looking for group,"
         if self.is_looking_for_group:
@@ -83,12 +79,12 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             lfg_status += f" leading {len(self.current_followers)} followers,"
         schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
         return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
-               f" current key = {self.current_group_key})"
+               f" current key = {self.group_key_manager.current_key})"
 
     async def look_for_group(self, *, data_for_gather: bytes = b'', timeout: Optional[float] = None
                              ) -> Optional[AllReduceRunner]:
         """
-        :param gather: optionally send this data to all peers in the next group and gather it from every groupmate
+        :param data_for_gather: optionally send this data to all peers in the next group and gather it from groupmates
         :param timeout: maximum time that may be spent looking for group (does not include allreduce itself)
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
@@ -127,10 +123,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
     async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
         """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
-        async with self.potential_leaders.begin_search(self.current_group_key, timeout):
-            # TODO update group_bits on success! reduce number of bits on not enough peers.
-            # TODO after allreduce finishes, we may need to ask leader to notify lower keys about this
-            # (so as to fix possible network partitioning if some peers operate on a much smaller nbits)
+        async with self.potential_leaders.begin_search(self.group_key_manager, timeout):
             while True:
                 try:
                     next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
@@ -148,7 +141,6 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                             return await self.leader_assemble_group()
                         elif len(self.current_followers) > 0:
                             await self.leader_disband_group()
-                            # TODO maybe adjust grid size
                             continue
                 except Exception as e:
                     if not self.assembled_group.done():
@@ -262,8 +254,8 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             allreduce_group = self.assembled_group.result()
             yield averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
-                ordered_group_endpoints=allreduce_group.ordered_group_endpoints,
-                part_sizes=allreduce_group.part_sizes, gathered=allreduce_group.gathered)
+                ordered_group_endpoints=allreduce_group.ordered_group_endpoints, part_sizes=allreduce_group.part_sizes,
+                gathered=allreduce_group.gathered, group_key_seed=allreduce_group.group_key_seed)
 
         except Exception as e:
             logger.exception(e)
@@ -319,11 +311,13 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             gathered.append(follower_info.gather if follower_info.gather else None)
 
         part_sizes = load_balance_peers(self.total_size, throughputs, self.min_vector_size)
+        group_key_seed = random.randint(- 2 ** 31, 2 ** 31 - 1)
 
         logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.")
         allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
                                           ordered_group_endpoints=ordered_group_endpoints, part_sizes=part_sizes,
-                                          gathered=gathered, **self.allreduce_kwargs)
+                                          gathered=gathered, group_key_seed=group_key_seed, **self.allreduce_kwargs)
+        await self.group_key_manager.update_key_on_group_assembled(allreduce_group, is_leader=True)
         self.assembled_group.set_result(allreduce_group)
         return allreduce_group
 
@@ -340,7 +334,9 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
         allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
                                           ordered_group_endpoints=tuple(ordered_group_endpoints),
-                                          part_sizes=tuple(part_sizes), gathered=msg.gathered, **self.allreduce_kwargs)
+                                          part_sizes=tuple(part_sizes), gathered=msg.gathered,
+                                          group_key_seed=int(msg.group_key_seed), **self.allreduce_kwargs)
+        await self.group_key_manager.update_key_on_group_assembled(allreduce_group)
         self.assembled_group.set_result(allreduce_group)
         return allreduce_group
 
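Reviewer note: the two hunks above thread a single `group_key_seed` through the whole group. The leader draws a value in the signed 32-bit range (presumably matching an `int32` field added to `MessageFromLeader`), ships it with `BEGIN_ALLREDUCE`, and every peer, leader and follower alike, hands it to `update_key_on_group_assembled`. Sharing one seed lets all group members make the same pseudorandom re-keying decision without any extra round trips. The snippet below only illustrates that determinism property; it is not the actual key-update rule.

```python
import random


def shared_decision(group_key_seed: int, ordered_group_endpoints: list) -> dict:
    """Peers that call this with the same seed and the same endpoint order
    obtain the same pseudorandom assignment without communicating."""
    rng = random.Random(group_key_seed)  # identical PRNG state on every peer
    shuffled = list(ordered_group_endpoints)
    rng.shuffle(shuffled)
    return {peer: index for index, peer in enumerate(shuffled)}


# the leader draws the seed exactly as in the hunk above
seed = random.randint(- 2 ** 31, 2 ** 31 - 1)  # fits a signed 32-bit proto field
peers = ["192.0.2.1:1337", "192.0.2.2:1337", "192.0.2.3:1337"]  # hypothetical endpoints
assert shared_decision(seed, peers) == shared_decision(seed, peers)
```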
@@ -353,9 +349,8 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 class PotentialLeaders:
     """ An utility class that searches for averagers that could become our leaders """
 
-    def __init__(self, endpoint: Endpoint, dht: hivemind.DHT, averaging_expiration: DHTExpiration,
-                 target_group_size: Optional[int]):
-        self.endpoint, self.dht, self.averaging_expiration = endpoint, dht, averaging_expiration
+    def __init__(self, endpoint: Endpoint, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
+        self.endpoint, self.averaging_expiration = endpoint, averaging_expiration
         self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
@@ -367,12 +362,12 @@ class PotentialLeaders:
         self.search_end_time = float('inf')
 
     @contextlib.asynccontextmanager
-    async def begin_search(self, group_key: GroupKey, timeout: Optional[float]):
+    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float]):
         async with self.lock_search:
             self.running.set()
             self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
-            update_queue_task = asyncio.create_task(self._update_queue_periodically(group_key))
-            declare_averager_task = asyncio.create_task(self._declare_averager_periodically(group_key))
+            update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
+            declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
             try:
                 yield self
             finally:
@@ -429,38 +424,46 @@ class PotentialLeaders:
         else:
             return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
 
-    async def _update_queue_periodically(self, group_key: GroupKey):
-        DISCREPANCY = hivemind.utils.timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
-        while get_dht_time() < self.search_end_time:
-            new_peers = await self.dht.get_averagers(group_key, only_active=True, return_future=True)
-            self.max_assured_time = max(self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY)
-
-            self.leader_queue.clear()
-            for peer, peer_expiration_time in new_peers:
-                if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
-                    continue
-                self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
-                self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
+    async def _update_queue_periodically(self, key_manager: GroupKeyManager):
+        try:
+            DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
+            while get_dht_time() < self.search_end_time:
+                new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
+                self.max_assured_time = max(self.max_assured_time,
+                                            get_dht_time() + self.averaging_expiration - DISCREPANCY)
+
+                self.leader_queue.clear()
+                for peer, peer_expiration_time in new_peers:
+                    if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
+                        continue
+                    self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
+                    self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
 
-            self.update_finished.set()
+                self.update_finished.set()
 
-            await asyncio.wait(
-                {self.running.wait(), self.update_triggered.wait()}, return_when=asyncio.ALL_COMPLETED,
-                timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
-            self.update_triggered.clear()
+                await asyncio.wait(
+                    {self.running.wait(), self.update_triggered.wait()}, return_when=asyncio.ALL_COMPLETED,
+                    timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
+                self.update_triggered.clear()
+        except Exception as e:
+            logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+            raise
 
-    async def _declare_averager_periodically(self, group_key: GroupKey):
+    async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
         async with self.lock_declare:
             try:
                 while True:
                     await self.running.wait()
 
                     new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
-                    self.declared_group_key, self.declared_expiration_time = group_key, new_expiration_time
+                    self.declared_group_key = group_key = key_manager.current_key
+                    self.declared_expiration_time = new_expiration_time
                     self.declared_expiration.set()
-                    await self.dht.declare_averager(group_key, self.endpoint, new_expiration_time,
-                                                    looking_for_group=True, return_future=True)
+                    await key_manager.declare_averager(group_key, self.endpoint, expiration_time=new_expiration_time)
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
+                    if self.running.is_set() and len(self.leader_queue) == 0:
+                        await key_manager.update_key_on_not_enough_peers()
+
             except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
                 logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
             finally:
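Reviewer note: the TODO removed in `_request_join_potential_leaders` ("reduce number of bits on not enough peers") is what `update_key_on_not_enough_peers` is now expected to handle. When a declaration expires and the leader queue is still empty, the averager asks the key manager to fall back to a coarser key so it can meet peers from neighbouring buckets; since the declare loop re-reads `key_manager.current_key` on every iteration, the next declaration automatically uses the updated key. The actual policy lives in `key_manager.py`; the sketch below is only one hypothetical way such a fallback could look.

```python
# Hypothetical fallback policy -- NOT the actual GroupKeyManager implementation.
class BitFallbackSketch:
    def __init__(self, prefix: str, group_bits: str):
        self.prefix, self.group_bits = prefix, group_bits

    @property
    def current_key(self) -> str:
        # same "{prefix}.0b{bits}" format used by the removed current_group_key property
        return f"{self.prefix}.0b{self.group_bits}"

    async def update_key_on_not_enough_peers(self) -> None:
        # Dropping the trailing bit halves the number of distinct keys, so
        # averagers from two adjacent buckets start seeing each other.
        if self.group_bits:
            self.group_bits = self.group_bits[:-1]


# e.g. "myprefix.0b0110" -> "myprefix.0b011" -> "myprefix.0b01" -> ...
```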
@@ -468,8 +471,8 @@ class PotentialLeaders:
                     prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
                     self.declared_group_key, self.declared_expiration_time = None, float('inf')
                     self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float('-inf')
-                    await self.dht.declare_averager(prev_declared_key, self.endpoint, prev_expiration_time,
-                                                    looking_for_group=False, return_future=True)
+                    await key_manager.declare_averager(prev_declared_key, self.endpoint, prev_expiration_time,
+                                                       looking_for_group=False)
 
 
 def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
@@ -478,3 +481,7 @@ def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
                      for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
                     for tensor in tensors]
     return DHTID.generate(source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()
+
+
+class MatchmakingException(Exception):
+    """ An internal exception that marks undesired edge cases during averaging """
|