@@ -6,20 +6,20 @@ import contextlib
import random
from dataclasses import asdict
from math import isfinite
-from typing import Sequence, Optional, AsyncIterator, Set, Tuple
+from typing import Sequence, Optional, AsyncIterator, Set, Tuple, Dict
import asyncio

-import torch
import grpc
+import torch

import hivemind
-from hivemind.client.averaging.allreduce import AllReduceRunner, GroupID
+from hivemind.client.averaging.allreduce import AllReduceRunner
+from hivemind.client.averaging.load_balancing import load_balance_peers
from hivemind.dht import DHTID, DHTExpiration, get_dht_time, GroupKey
from hivemind.utils import get_logger, Endpoint, TensorDescriptor, MSGPackSerializer, TimedStorage
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
+from hivemind.proto import averaging_pb2, averaging_pb2_grpc
from hivemind.utils.grpc import ChannelCache

-
logger = get_logger(__name__)

@@ -34,12 +34,12 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
    This deadlock only happens if averagers have outdated information on expirations (due to network delays).
    While A->B->A deadlock is easy to fix, it gets much harder with more peers (e.g. A -> B -> C -> D -> A).
    Hence, instead of accounting for such deadlocks, we simply break them with request_timeout.
-
    """

    def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *,
                 prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, request_timeout: float, **allreduce_kwargs):
+                 averaging_expiration: float = 15, request_timeout: float, throughput: Optional[float] = None,
+                 min_vector_size: int, **allreduce_kwargs):
        assert '.' not in prefix, "group prefix must be a string without ."
        if request_timeout is None or request_timeout >= averaging_expiration:
            logger.warning("It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
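For intuition, here is a self-contained toy sketch (plain asyncio, none of the RPC machinery from this file) of why a bounded request_timeout is enough to break such wait cycles: each side gives up after the timeout instead of waiting on the other indefinitely.

import asyncio

async def request_join(leader_reply: asyncio.Future, request_timeout: float) -> bool:
    # toy stand-in for one peer waiting on a prospective leader: never wait longer than request_timeout
    try:
        await asyncio.wait_for(asyncio.shield(leader_reply), timeout=request_timeout)
        return True
    except asyncio.TimeoutError:
        return False  # give up on this leader and move on to the next candidate

async def main():
    loop = asyncio.get_running_loop()
    reply_from_a, reply_from_b = loop.create_future(), loop.create_future()
    # A waits for B while B waits for A: with bounded waits, both simply time out and retry
    print(await asyncio.gather(request_join(reply_from_b, 3.0), request_join(reply_from_a, 3.0)))

asyncio.run(main())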
@@ -50,8 +50,10 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
        self.prefix, self.group_bits = prefix, initial_group_bits
        self.target_group_size, self.min_group_size = target_group_size, min_group_size
        self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.throughput, self.min_vector_size = throughput, min_vector_size
        self.allreduce_kwargs = allreduce_kwargs
        self.schema_hash = compute_schema_hash(self.averaged_tensors)
+        self.total_size = sum(tensor.numel() for tensor in self.averaged_tensors)

        self.lock_looking_for_group = asyncio.Lock()
        self.lock_request_join_group = asyncio.Lock()
@@ -60,8 +62,9 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
        self.assembled_group = asyncio.Future()

        self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
-        self.current_followers: Set[Endpoint] = set()  # iff i am a leader, this contains my followers excluding myself
+        self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
        self.potential_leaders = PotentialLeaders(endpoint, dht, averaging_expiration, target_group_size)
+        self.data_for_gather: bytes = None

    @property
    def is_looking_for_group(self):
@@ -82,8 +85,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
        return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
               f" current key = {self.current_group_key})"

-    async def look_for_group(self, *, timeout: Optional[float] = None) -> Optional[AllReduceRunner]:
+    async def look_for_group(self, *, data_for_gather: bytes = b'', timeout: Optional[float] = None
+                             ) -> Optional[AllReduceRunner]:
        """
+        :param data_for_gather: optionally send this data to all peers in the next group and gather it from every groupmate
+        :param timeout: maximum time that may be spent looking for group (does not include allreduce itself)
        :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
        Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
        """
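As a usage illustration only (the helper below and its metadata keys are hypothetical, not part of this patch): a caller serializes whatever small payload it wants to exchange, passes it as data_for_gather, and reads one gathered blob per groupmate from the resulting AllReduceRunner, in the same order as ordered_group_endpoints.

async def exchange_metadata(matchmaking: Matchmaking, timeout: float = 30.0) -> list:
    # hypothetical helper: the keys below are made-up example metadata
    local_metadata = MSGPackSerializer.dumps(dict(step=1337, samples_accumulated=256))
    allreduce_group = await matchmaking.look_for_group(data_for_gather=local_metadata, timeout=timeout)
    if allreduce_group is None:
        return []  # no group was assembled before the timeout
    # one blob per peer, ordered like allreduce_group.ordered_group_endpoints
    return [MSGPackSerializer.loads(blob) for blob in allreduce_group.gathered]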
@@ -91,6 +97,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
            logger.info("Another look_for_group is already in progress. The current run will be scheduled after"
                        " the existing group is either assembled or disbanded.")
        async with self.lock_looking_for_group:
+            self.data_for_gather = data_for_gather
            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
            try:
                return await asyncio.wait_for(self.assembled_group, timeout=timeout)
@@ -116,6 +123,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
            # note: the code above ensures that we send all followers away before creating new future
            self.assembled_group = asyncio.Future()
            self.was_accepted_to_group.clear()
+            self.data_for_gather = None

    async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
        """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
@@ -161,7 +169,9 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
            async with self.lock_request_join_group:
                leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
                call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
-                    endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time))
+                    endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time,
+                    throughput=self.throughput if self.throughput is not None else -1.0,
+                    gather=self.data_for_gather))
                message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)

                if message.code == averaging_pb2.ACCEPTED:
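A small aside on the -1.0 above: since the protobuf field cannot carry None, an unknown throughput travels as a negative sentinel and is mapped back to None on the leader side (see leader_assemble_group below). A minimal round-trip sketch of that convention, with invented values:

from typing import Optional

def encode_throughput(throughput: Optional[float]) -> float:
    # value the follower places into JoinRequest.throughput
    return throughput if throughput is not None else -1.0

def decode_throughput(wire_value: float) -> Optional[float]:
    # value the leader recovers before calling load_balance_peers
    return wire_value if wire_value >= 0 else None

assert decode_throughput(encode_throughput(None)) is None
assert decode_throughput(encode_throughput(12.5)) == 12.5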
@@ -182,8 +192,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):

                if message.code == averaging_pb2.BEGIN_ALLREDUCE:
                    async with self.lock_request_join_group:
-                        return await self.follower_assemble_group(
-                            leader, message.group_id, message.ordered_group_endpoints)
+                        return await self.follower_assemble_group(leader, message)

            if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
                if message.suggested_leader and message.suggested_leader != self.endpoint:
@@ -218,7 +227,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                    yield reason_to_reject
                    return

-                self.current_followers.add(request.endpoint)
+                self.current_followers[request.endpoint] = request
                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)

                if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
@@ -253,14 +262,15 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                allreduce_group = self.assembled_group.result()
                yield averaging_pb2.MessageFromLeader(
                    code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
-                    ordered_group_endpoints=allreduce_group.ordered_group_endpoints)
+                    ordered_group_endpoints=allreduce_group.ordered_group_endpoints,
+                    part_sizes=allreduce_group.part_sizes, gathered=allreduce_group.gathered)

        except Exception as e:
            logger.exception(e)
            yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)

        finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
-            self.current_followers.discard(request.endpoint)
+            self.current_followers.pop(request.endpoint, None)
            self.follower_was_discarded.set()

    def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> Optional[averaging_pb2.MessageFromLeader]:
@@ -297,22 +307,40 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
        ordered_group_endpoints = list(self.current_followers)
        ordered_group_endpoints.append(self.endpoint)
        random.shuffle(ordered_group_endpoints)
+
+        throughputs, gathered = [], []
+        for endpoint in ordered_group_endpoints:
+            if endpoint == self.endpoint:
+                throughputs.append(self.throughput)
+                gathered.append(self.data_for_gather)
+            else:
+                follower_info = self.current_followers[endpoint]
+                throughputs.append(follower_info.throughput if follower_info.throughput >= 0 else None)
+                gathered.append(follower_info.gather if follower_info.gather else None)
+
+        part_sizes = load_balance_peers(self.total_size, throughputs, self.min_vector_size)
+
        logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.")
        allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-                                          ordered_group_endpoints=ordered_group_endpoints, **self.allreduce_kwargs)
+                                          ordered_group_endpoints=ordered_group_endpoints, part_sizes=part_sizes,
+                                          gathered=gathered, **self.allreduce_kwargs)
        self.assembled_group.set_result(allreduce_group)
        return allreduce_group

-    async def follower_assemble_group(self, leader: Endpoint, group_id: GroupID,
-                                      ordered_group_endpoints: Sequence[Endpoint]) -> AllReduceRunner:
+    async def follower_assemble_group(self, leader: Endpoint, msg: averaging_pb2.MessageFromLeader) -> AllReduceRunner:
        """ Prepare to run allreduce using a list of peers provided by our leader """
        assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
        assert not self.assembled_group.done()
-        logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
        assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
+
+        group_id, ordered_group_endpoints, part_sizes = msg.group_id, msg.ordered_group_endpoints, msg.part_sizes
        assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
+        assert len(ordered_group_endpoints) == len(part_sizes) == len(msg.gathered)
+
+        logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
        allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-                                          ordered_group_endpoints=ordered_group_endpoints, **self.allreduce_kwargs)
+                                          ordered_group_endpoints=tuple(ordered_group_endpoints),
+                                          part_sizes=tuple(part_sizes), gathered=msg.gathered, **self.allreduce_kwargs)
        self.assembled_group.set_result(allreduce_group)
        return allreduce_group
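To make the new leader-side flow concrete, a standalone sketch of how the collected throughputs feed load_balance_peers; the vector size and throughput numbers are invented, and 65536 stands in for min_vector_size, so treat this as an illustration rather than recommended values.

from typing import Optional, Sequence
from hivemind.client.averaging.load_balancing import load_balance_peers

total_size = 1_000_000  # total number of averaged tensor elements, as in self.total_size
throughputs: Sequence[Optional[float]] = [100.0, None, 300.0]  # None = peer did not report its throughput

# same positional call as in leader_assemble_group above
part_sizes = load_balance_peers(total_size, throughputs, 65536)
assert sum(part_sizes) == total_size  # the parts partition the full averaged vector
print(part_sizes)  # peers reporting higher throughput are assigned larger parts to reduce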