@@ -17,7 +17,7 @@ from hivemind.client.averaging.allreduce import AllReduceRunner
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.key_manager import GroupKeyManager, GroupKey
 from hivemind.dht import DHT, DHTID, DHTExpiration, get_dht_time
-from hivemind.utils import get_logger, Endpoint, TensorDescriptor, MSGPackSerializer, timed_storage, TimedStorage
+from hivemind.utils import get_logger, Endpoint, TensorDescriptor, timed_storage, TimedStorage
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc
 from hivemind.utils.grpc import ChannelCache

@@ -29,19 +29,19 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
     f"""
     An internal class that is used to form groups of averages for running allreduce
     See DecentralizedAverager docstring for the detailed description of all parameters
-
+
     :note: on implementation: the current matchmaker protocol can encounter one type of (temporary) deadlock;
     This deadlock occurs when averager A requests averager B at the same time as averager B requests averager A.
     In that case, neither averager can process the other one's request because it is awaiting lock_request_join_group.
-    This deadlock only happens if averagers have outdated information on expirations (due to network delays).
+    This deadlock only happens if averagers have outdated information on expirations (due to network delays).
     While A->B->A deadlock is easy to fix, it gets much harder with more peers (e.g. A -> B -> C -> D -> A).
     Hence, instead of accounting for such deadlocks, we simply break them with request_timeout.
     """

     def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *,
-                 prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, request_timeout: float, throughput: Optional[float] = None,
-                 min_vector_size: int, **allreduce_kwargs):
+                 prefix: str, target_group_size: int, min_group_size: int, min_vector_size: int,
+                 request_timeout: float, client_mode: bool, initial_group_bits: Optional[str] = None,
+                 averaging_expiration: float = 15, throughput: Optional[float] = None, **allreduce_kwargs):
         assert '.' not in prefix, "group prefix must be a string without ."
         if request_timeout is None or request_timeout >= averaging_expiration:
             logger.warning("It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
@@ -52,6 +52,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         self.group_key_manager = GroupKeyManager(dht, endpoint, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.client_mode = client_mode
         self.throughput, self.min_vector_size, self.allreduce_kwargs = throughput, min_vector_size, allreduce_kwargs
         self.schema_hash = compute_schema_hash(self.averaged_tensors)
         self.total_size = sum(tensor.numel() for tensor in self.averaged_tensors)
@@ -80,7 +81,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             lfg_status += f" leading {len(self.current_followers)} followers,"
         schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
         return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
-               f" current key = {self.group_key_manager.current_key})"
+               f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"

     async def look_for_group(self, *, data_for_gather: bytes = b'', timeout: Optional[float] = None
                              ) -> Optional[AllReduceRunner]:
@@ -124,7 +125,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):

     async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
         """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
-        async with self.potential_leaders.begin_search(self.group_key_manager, timeout):
+        async with self.potential_leaders.begin_search(self.group_key_manager, timeout, declare=not self.client_mode):
             while True:
                 try:
                     next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
@@ -166,7 +167,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
                 endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time,
                 throughput=self.throughput if self.throughput is not None else -1.0,
-                gather=self.data_for_gather))
+                client_mode=self.client_mode, gather=self.data_for_gather))
             message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)

             if message.code == averaging_pb2.ACCEPTED:
@@ -276,7 +277,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):

         if request.ListFields() == 3 and not isinstance(request.schema_hash, bytes) or len(request.schema_hash) == 0 \
                 or not isinstance(request.expiration, DHTExpiration) or not isfinite(request.expiration) \
-                or not isinstance(request.endpoint, Endpoint) or len(request.endpoint) == 0:
+                or not isinstance(request.endpoint, Endpoint) or len(request.endpoint) == 0 or self.client_mode:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)

         elif request.schema_hash != self.schema_hash:
@@ -297,24 +298,26 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):

     async def leader_assemble_group(self) -> AllReduceRunner:
         """ Form up all current followers into a group and prepare to _run_allreduce """
-        assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
+        assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked() and not self.client_mode
         assert not self.assembled_group.done()
         group_id = DHTID.generate().to_bytes()
         ordered_group_endpoints = list(self.current_followers)
         ordered_group_endpoints.append(self.endpoint)
         random.shuffle(ordered_group_endpoints)

-        throughputs, gathered = [], []
+        averager_throughputs, gathered = [], []
         for endpoint in ordered_group_endpoints:
             if endpoint == self.endpoint:
-                throughputs.append(self.throughput)
+                averager_throughputs.append(self.throughput)
                 gathered.append(self.data_for_gather)
             else:
                 follower_info = self.current_followers[endpoint]
-                throughputs.append(follower_info.throughput if follower_info.throughput >= 0 else None)
+                throughput = follower_info.throughput if follower_info.throughput >= 0 else None
+                averager_throughput = throughput if not follower_info.client_mode else 0.0
+                averager_throughputs.append(averager_throughput)
                 gathered.append(follower_info.gather if follower_info.gather else None)

-        part_sizes = load_balance_peers(self.total_size, throughputs, self.min_vector_size)
+        part_sizes = load_balance_peers(self.total_size, averager_throughputs, self.min_vector_size)
         group_key_seed = random.randint(- 2 ** 31, 2 ** 31 - 1)

         logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.")
@@ -331,13 +334,15 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         assert not self.assembled_group.done()
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"

-        group_id, ordered_group_endpoints, part_sizes = msg.group_id, msg.ordered_group_endpoints, msg.part_sizes
+        group_id, ordered_group_endpoints, part_sizes = msg.group_id, tuple(msg.ordered_group_endpoints), msg.part_sizes
         assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
         assert len(ordered_group_endpoints) == len(part_sizes) == len(msg.gathered)
+        my_part_size = part_sizes[ordered_group_endpoints.index(self.endpoint)]
+        assert my_part_size == 0 or not self.client_mode, "Averager with client_mode=True cannot accept incoming data."

         logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
         allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-                                          ordered_group_endpoints=tuple(ordered_group_endpoints),
+                                          ordered_group_endpoints=ordered_group_endpoints,
                                           part_sizes=tuple(part_sizes), gathered=msg.gathered,
                                           group_key_seed=int(msg.group_key_seed), **self.allreduce_kwargs)
         await self.group_key_manager.update_key_on_group_assembled(allreduce_group)
@@ -346,7 +351,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):

     async def leader_disband_group(self):
         """ Kick out all followers immediately, optionally direct them to our new leader (if we found one) """
-        assert self.lock_request_join_group.locked()
+        assert self.lock_request_join_group.locked() and not self.client_mode
         self.current_followers.clear()  # this will cause rpc_join_group to kick all followers out


@@ -366,19 +371,22 @@ class PotentialLeaders:
         self.search_end_time = float('inf')

     @contextlib.asynccontextmanager
-    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float]):
+    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float], declare: bool = True):
         async with self.lock_search:
             self.running.set()
             self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
             update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
-            declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
+            if declare:
+                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
+
             try:
                 yield self
             finally:
                 if not update_queue_task.done():
                     update_queue_task.cancel()
-                if not declare_averager_task.done():
+                if declare and not declare_averager_task.done():
                     declare_averager_task.cancel()
+
                 for field in (self.past_attempts, self.leader_queue, self.running,
                               self.update_finished, self.update_triggered, self.declared_expiration):
                     field.clear()
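
Below is a small illustrative sketch (not part of the patch) of how the client-mode handling in leader_assemble_group feeds into load_balance_peers. The concrete tensor size, min_vector_size and throughput values are made up; the expectation that a 0.0 throughput maps to an empty part is an assumption based on the follower-side assertion added by this patch, not on the load_balance_peers docs.

from hivemind.client.averaging.load_balancing import load_balance_peers

# Illustrative values (not from the patch): total number of averaged elements
# and the smallest chunk worth assigning to a single peer.
total_size = 1_000_000
min_vector_size = 1024

# Throughputs as leader_assemble_group would collect them:
#   10.0 -> regular peer with a declared throughput
#   None -> regular peer whose throughput is unknown (follower_info.throughput < 0)
#   0.0  -> follower that joined with client_mode=True
averager_throughputs = [10.0, None, 0.0]

part_sizes = load_balance_peers(total_size, averager_throughputs, min_vector_size)
print(part_sizes)
# Assumption: the client-mode peer (throughput 0.0) is assigned a part size of 0,
# consistent with the new assertion that an averager with client_mode=True
# cannot accept incoming data.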