|
@@ -109,7 +109,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
"""
|
|
|
|
|
|
_matchmaking: Matchmaking
|
|
|
- _pending_group_assembled: asyncio.Event
|
|
|
+ _pending_groups_registered: asyncio.Event
|
|
|
_state_updated: asyncio.Event
|
|
|
_p2p: P2P
|
|
|
serializer = MSGPackSerializer
|
|
@@ -207,7 +207,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
reducer_timeout=reducer_timeout,
|
|
|
)
|
|
|
self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
|
|
|
- self._running_groups: Dict[GroupID, AllReduceRunner] = {} # one or more assembled groups that run all-reduce
|
|
|
+ self._running_groups: Dict[GroupID, asyncio.Future[AllReduceRunner]] = {}
|
|
|
|
|
|
self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True) # a control pipe used to communicate with daemon
|
|
|
|
|
@@ -309,8 +309,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
asyncio.create_task(self._declare_for_download_periodically())
|
|
|
|
|
|
self._state_updated = asyncio.Event()
|
|
|
- self._pending_group_assembled = asyncio.Event()
|
|
|
- self._pending_group_assembled.set()
|
|
|
+ self._pending_groups_registered = asyncio.Event()
|
|
|
+ self._pending_groups_registered.set()
|
|
|
except Exception as e:
|
|
|
# Loglevel is DEBUG since normally the exception is propagated to the caller
|
|
|
logger.debug(e, exc_info=True)
|
|
@@ -441,7 +441,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
|
|
|
while not step.done():
|
|
|
try:
|
|
|
- self._pending_group_assembled.clear()
|
|
|
+ self._pending_groups_registered.clear()
|
|
|
step.stage = AveragingStage.LOOKING_FOR_GROUP
|
|
|
matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
|
|
|
check_cancel_task = asyncio.create_task(step.wait_for_cancel())
|
|
@@ -455,20 +455,21 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
|
|
|
group_info = await matchmaking_task
|
|
|
|
|
|
- if group_info is None:
|
|
|
- raise AllreduceException("Averaging step failed: could not find a group")
|
|
|
+ async with self._register_allreduce_group(group_info):
|
|
|
+ if group_info is None:
|
|
|
+ raise AllreduceException("Averaging step failed: could not find a group")
|
|
|
|
|
|
- step.stage = AveragingStage.RUNNING_ALLREDUCE
|
|
|
+ step.stage = AveragingStage.RUNNING_ALLREDUCE
|
|
|
|
|
|
- step.set_result(
|
|
|
- await asyncio.wait_for(
|
|
|
- self._run_allreduce(
|
|
|
- group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
|
|
|
- ),
|
|
|
- timeout=self._allreduce_timeout,
|
|
|
+ step.set_result(
|
|
|
+ await asyncio.wait_for(
|
|
|
+ self._run_allreduce(
|
|
|
+ group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
|
|
|
+ ),
|
|
|
+ timeout=self._allreduce_timeout,
|
|
|
+ )
|
|
|
)
|
|
|
- )
|
|
|
- # averaging is finished, loop will now exit
|
|
|
+ # averaging is finished, loop will now exit
|
|
|
|
|
|
except (
|
|
|
AllreduceException,
|
|
@@ -503,6 +504,19 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
)
|
|
|
)
|
|
|
|
|
|
@contextlib.asynccontextmanager
async def _register_allreduce_group(self, group_id: GroupID):
    """
    Reserve a slot in ``self._running_groups`` for one all-reduce group while the ``async with`` body runs.

    On entry, a pending ``asyncio.Future`` is stored under the group's id and
    ``self._pending_groups_registered`` is set. ``_run_allreduce`` resolves that future with the
    runner once it has been created. On exit, whether normal or via an exception, the slot is
    removed and the event is set again.

    This is an *async* context manager because the step loop enters it with ``async with``.

    :param group_id: the id of the group to register. The step loop passes the whole GroupInfo
      (or None when matchmaking found no group), so an object exposing ``.group_id`` is accepted
      and unwrapped, and None registers nothing.
    """
    # The registry is looked up by the raw id (group_info.group_id / request.group_id),
    # so unwrap a GroupInfo instead of using the object itself as the key.
    key = getattr(group_id, "group_id", group_id)
    try:
        if key is not None:
            self._running_groups[key] = asyncio.Future()
        self._pending_groups_registered.set()
        yield
    finally:
        maybe_future = self._running_groups.pop(key, None)
        if maybe_future is not None and not maybe_future.done():
            # NOTE(review): the unresolved future is only reported, not cancelled; anything that
            # already awaits it keeps waiting. Confirm whether it should be cancelled here.
            logger.warning(f"All-reduce group {key} did not finish.")
        self._pending_groups_registered.set()
|
|
|
+
|
|
|
async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
|
|
|
"""Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
|
|
|
try:
|
|
@@ -531,18 +545,19 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
modes=modes,
|
|
|
**kwargs,
|
|
|
)
|
|
|
+ self._running_groups[group_info.group_id].set_result(allreduce)
|
|
|
+ # ^--- maybe this can be extracted into a method that checks if register_... context is active.
|
|
|
|
|
|
- with self.register_allreduce_group(group_info.group_id, allreduce):
|
|
|
- if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
|
|
|
- iter_results = allreduce.run()
|
|
|
- async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
|
|
|
- # all-reduce is performed asynchronously while iterating
|
|
|
- tensor.add_(update, alpha=self._averaging_alpha)
|
|
|
- self._state_updated.set()
|
|
|
+ if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
|
|
|
+ iter_results = allreduce.run()
|
|
|
+ async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
|
|
|
+ # all-reduce is performed asynchronously while iterating
|
|
|
+ tensor.add_(update, alpha=self._averaging_alpha)
|
|
|
+ self._state_updated.set()
|
|
|
|
|
|
- else:
|
|
|
- async for _ in allreduce: # trigger all-reduce by iterating
|
|
|
- raise ValueError("aux peers should not receive averaged tensors")
|
|
|
+ else:
|
|
|
+ async for _ in allreduce: # trigger all-reduce by iterating
|
|
|
+ raise ValueError("aux peers should not receive averaged tensors")
|
|
|
|
|
|
return allreduce.gathered
|
|
|
except BaseException as e:
|
|
@@ -550,17 +565,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
logger.exception(e)
|
|
|
raise MatchmakingException(f"Unable to run All-Reduce: {e}")
|
|
|
|
|
|
- @contextlib.contextmanager
|
|
|
- def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
|
|
|
- """registers a given all-reduce runner to listen for incoming connections"""
|
|
|
- try:
|
|
|
- self._running_groups[group_id] = allreduce
|
|
|
- self._pending_group_assembled.set()
|
|
|
- yield
|
|
|
- finally:
|
|
|
- self._running_groups.pop(group_id, None)
|
|
|
- self._pending_group_assembled.set()
|
|
|
-
|
|
|
@contextlib.contextmanager
|
|
|
def get_tensors(self) -> Sequence[torch.Tensor]:
|
|
|
"""
|
|
@@ -586,13 +590,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
if request.group_id not in self._running_groups:
|
|
|
# this handles a special case when leader accepted us to group AND began allreduce right away,
|
|
|
# but his response with group_id was delayed and other peers got to us first
|
|
|
- await self._pending_group_assembled.wait()
|
|
|
+ await self._pending_groups_registered.wait()
|
|
|
|
|
|
- group = self._running_groups.get(request.group_id)
|
|
|
- if group is None:
|
|
|
+ future = self._running_groups.get(request.group_id)
|
|
|
+ if future is None:
|
|
|
yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
|
|
|
return
|
|
|
|
|
|
+ group = await future
|
|
|
async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
|
|
|
yield message
|
|
|
|