@@ -22,13 +22,7 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
-from hivemind.compression import (
-    CompressionBase,
-    CompressionInfo,
-    NoCompression,
-    deserialize_torch_tensor,
-    serialize_torch_tensor,
-)
+from hivemind.compression import CompressionBase, CompressionInfo, NoCompression, deserialize_torch_tensor
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
@@ -36,7 +30,6 @@ from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
     achain,
-    afirst,
     aiter_with_timeout,
     anext,
     as_aiter,
@@ -109,7 +102,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     """

     _matchmaking: Matchmaking
-    _pending_group_assembled: asyncio.Event
+    _pending_groups_registered: asyncio.Event
     _state_updated: asyncio.Event
     _p2p: P2P
     serializer = MSGPackSerializer
@@ -207,7 +200,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             reducer_timeout=reducer_timeout,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
-        self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
+        self._running_groups: Dict[GroupID, asyncio.Future[AllReduceRunner]] = {}

         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon

@@ -309,8 +302,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             asyncio.create_task(self._declare_for_download_periodically())

             self._state_updated = asyncio.Event()
-            self._pending_group_assembled = asyncio.Event()
-            self._pending_group_assembled.set()
+            self._pending_groups_registered = asyncio.Event()
+            self._pending_groups_registered.set()
         except Exception as e:
             # Loglevel is DEBUG since normally the exception is propagated to the caller
             logger.debug(e, exc_info=True)
@@ -441,7 +434,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):

         while not step.done():
             try:
-                self._pending_group_assembled.clear()
+                self._pending_groups_registered.clear()
                 step.stage = AveragingStage.LOOKING_FOR_GROUP
                 matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
                 check_cancel_task = asyncio.create_task(step.wait_for_cancel())
@@ -458,17 +451,21 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 if group_info is None:
                     raise AllreduceException("Averaging step failed: could not find a group")

-                step.stage = AveragingStage.RUNNING_ALLREDUCE
-
-                step.set_result(
-                    await asyncio.wait_for(
-                        self._run_allreduce(
-                            group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
-                        ),
-                        timeout=self._allreduce_timeout,
+                with self._register_allreduce_group(group_info):
+                    step.stage = AveragingStage.RUNNING_ALLREDUCE
+
+                    step.set_result(
+                        await asyncio.wait_for(
+                            self._aggregate_with_group(
+                                group_info,
+                                tensor_infos=self.tensor_infos,
+                                weight=step.weight,
+                                **self.allreduce_kwargs,
+                            ),
+                            timeout=self._allreduce_timeout,
+                        )
                     )
-                )
-                # averaging is finished, loop will now exit
+                    # averaging is finished, loop will now exit

             except (
                 AllreduceException,
@@ -503,8 +500,21 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 )
             )

-    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
-        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
+    @contextlib.contextmanager
+    def _register_allreduce_group(self, group_info: GroupInfo):
+        """Register a given group for one or more all-reduce rounds"""
+        try:
+            self._running_groups[group_info.group_id] = asyncio.Future()
+            self._pending_groups_registered.set()
+            yield
+        finally:
+            maybe_future = self._running_groups.pop(group_info.group_id, None)
+            if maybe_future is not None and not maybe_future.done():
+                logger.warning(f"All-reduce group {group_info.group_id} did not finish.")
+            self._pending_groups_registered.set()
+
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run aggregation in a given group and update tensors in place, return gathered metadata"""
         try:
             bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
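Reviewer note: the context manager above reserves an `asyncio.Future` for the group id before any `AllReduceRunner` exists; the future is resolved once the runner is constructed (next hunk), and the RPC handler (last hunk) awaits it rather than failing when a peer's request arrives early. Below is a minimal, self-contained sketch of this future-plus-event pattern; the names (`register_group`, `early_rpc_handler`) and the placeholder "runner" string are illustrative, not hivemind APIs.

```python
import asyncio
import contextlib
from typing import Dict


async def main():
    # Illustrative stand-ins for the averager's _running_groups and _pending_groups_registered.
    running_groups: Dict[bytes, asyncio.Future] = {}
    groups_registered = asyncio.Event()

    @contextlib.contextmanager
    def register_group(group_id: bytes):
        """Reserve a slot for the group before its runner exists; always clean up afterwards."""
        try:
            running_groups[group_id] = asyncio.Future()
            groups_registered.set()
            yield
        finally:
            maybe_future = running_groups.pop(group_id, None)
            if maybe_future is not None and not maybe_future.done():
                print(f"group {group_id!r} was registered but never received a runner")
            groups_registered.set()

    async def early_rpc_handler(group_id: bytes):
        # A request may arrive before the runner is built: instead of answering
        # BAD_GROUP_ID, wait for registration and then for the runner itself.
        if group_id not in running_groups:
            await groups_registered.wait()
        future = running_groups.get(group_id)
        if future is None:
            print("unknown group")
            return
        runner = await future
        print("handler received:", runner)

    group_id = b"demo-group"
    with register_group(group_id):
        handler = asyncio.create_task(early_rpc_handler(group_id))
        await asyncio.sleep(0.1)  # pretend the runner takes a while to construct
        running_groups[group_id].set_result("AllReduceRunner placeholder")
        await handler


asyncio.run(main())
```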
@@ -519,47 +529,39 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             )

             async with enter_asynchronously(self.get_tensors()) as local_tensors:
-                allreduce = AllReduceRunner(
-                    p2p=self._p2p,
-                    servicer_type=type(self),
-                    prefix=self.prefix,
-                    group_id=group_info.group_id,
-                    tensors=local_tensors,
-                    ordered_peer_ids=group_info.peer_ids,
-                    peer_fractions=peer_fractions,
-                    gathered=user_gathered,
-                    modes=modes,
-                    **kwargs,
-                )
-
-                with self.register_allreduce_group(group_info.group_id, allreduce):
-                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        iter_results = allreduce.run()
-                        async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
-                            # all-reduce is performed asynchronously while iterating
-                            tensor.add_(update, alpha=self._averaging_alpha)
-                        self._state_updated.set()
-
-                    else:
-                        async for _ in allreduce:  # trigger all-reduce by iterating
-                            raise ValueError("aux peers should not receive averaged tensors")
-
-                return allreduce.gathered
+                await self._run_allreduce_inplace_(local_tensors, group_info, peer_fractions=peer_fractions, **kwargs)
+                return user_gathered
         except BaseException as e:
             if isinstance(e, Exception):
                 logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")

-    @contextlib.contextmanager
-    def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
-        """registers a given all-reduce runner to listen for incoming connections"""
-        try:
-            self._running_groups[group_id] = allreduce
-            self._pending_group_assembled.set()
-            yield
-        finally:
-            self._running_groups.pop(group_id, None)
-            self._pending_group_assembled.set()
+    async def _run_allreduce_inplace_(
+        self, tensors: Sequence[torch.Tensor], group_info: GroupInfo, group_id: Optional[bytes] = None, **kwargs
+    ):
+        """Run one allreduce process to average tensors in place. May be called more than once per aggregation."""
+        group_id = group_info.group_id if group_id is None else group_id
+
+        runner = AllReduceRunner(
+            p2p=self._p2p,
+            servicer_type=type(self),
+            prefix=self.prefix,
+            tensors=tensors,
+            group_id=group_id,
+            ordered_peer_ids=group_info.peer_ids,
+            **kwargs,
+        )
+        assert group_id in self._running_groups, f"Group id {group_id} was not registered in _register_allreduce_group"
+        self._running_groups[group_id].set_result(runner)
+
+        if runner.modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+            async for tensor, update in azip(as_aiter(*tensors), runner):
+                # all-reduce is performed asynchronously while iterating
+                tensor.add_(update, alpha=self._averaging_alpha)
+            self.last_updated = get_dht_time()
+            self._state_updated.set()
+        else:
+            async for _ in runner:  # trigger all-reduce by iterating
+                raise ValueError("aux peers should not receive averaged tensors")

     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
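Reviewer note: `_run_allreduce_inplace_` keeps the old in-place update, `tensor.add_(update, alpha=self._averaging_alpha)`. Assuming the runner yields deltas (group average minus the local value), as the weighted in-place add suggests, each tensor ends up at `(1 - alpha) * local + alpha * group_average`. A standalone torch sketch of that arithmetic, independent of hivemind:

```python
import torch

averaging_alpha = 0.5  # same role as DecentralizedAverager's averaging_alpha

local = torch.tensor([1.0, 2.0, 3.0])
group_average = torch.tensor([2.0, 2.0, 2.0])

# Assumed form of the yielded update: the difference between the group average
# and the local tensor (a delta), so that alpha interpolates between the two.
update = group_average - local
local.add_(update, alpha=averaging_alpha)

print(local)  # tensor([1.5000, 2.0000, 2.5000])
expected = (1 - averaging_alpha) * torch.tensor([1.0, 2.0, 3.0]) + averaging_alpha * group_average
assert torch.allclose(local, expected)
```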
@@ -586,13 +588,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if request.group_id not in self._running_groups:
             # this handles a special case when leader accepted us to group AND began allreduce right away,
             # but his response with group_id was delayed and other peers got to us first
-            await self._pending_group_assembled.wait()
+            await self._pending_groups_registered.wait()

-        group = self._running_groups.get(request.group_id)
-        if group is None:
+        future = self._running_groups.get(request.group_id)
+        if future is None:
             yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
             return

+        group = await future
         async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
             yield message