@@ -215,10 +215,20 @@ class AllReduceRunner(ServicerBase):
             try:
                 done_sending = asyncio.Event()
                 inputs_aiter = attach_event_on_finished(self._generate_input_for_peer(peer_index), done_sending)
-                stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(inputs_aiter)
+                stream = await asyncio.wait_for(
+                    self._get_peer_stub(peer_id).rpc_aggregate_part(inputs_aiter), self.reducer_timeout
+                )
 
                 if self.should_delay_results(self.peer_id):
-                    await done_sending.wait()
+                    # TODO: if inputs_aiter fails, we must ensure that done_sending is set anyway!
+                    # Otherwise, there is a risk that we will sleep here forever.
+                    try:
+                        logger.warning("Waiting for done_sending to be set")
+                        await asyncio.wait_for(done_sending.wait(), 300)
+                    except asyncio.TimeoutError:
+                        logger.error("Not done sending after 300s (will keep waiting)")
+                        await done_sending.wait()
+                        raise
 
                 part_index = 0
 
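Note on the TODO above: the "sleep here forever" risk goes away if the event is set from a finally clause inside the wrapper, so that it fires on success, failure, and cancellation alike. A minimal sketch of that pattern (illustrative only; the actual attach_event_on_finished helper may differ):

    import asyncio
    from typing import AsyncIterator, TypeVar

    T = TypeVar("T")

    async def attach_event_on_finished(iterable: AsyncIterator[T], event: asyncio.Event) -> AsyncIterator[T]:
        """Yield items from iterable; set the event once iteration ends for any reason."""
        try:
            async for item in iterable:
                yield item
        finally:
            # Runs on normal exhaustion, on exception, and on generator close,
            # so a coroutine awaiting the event can never hang forever.
            event.set()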
@@ -227,11 +237,11 @@ class AllReduceRunner(ServicerBase):
                         raise AllreduceException(f"{peer_id} sent {averaging_pb2.MessageCode.Name(msg.code)}")
                     return deserialize_torch_tensor(msg.tensor_part), msg
 
-                async for delta, msg in amap_in_executor(
+                async for delta, msg in aiter_with_timeout(amap_in_executor(
                     _try_deserialize,
-                    aiter_with_timeout(stream, self.reducer_timeout),
+                    stream,
                     max_prefetch=self.tensor_part_container.prefetch,
-                ):
+                ), self.reducer_timeout):
                     self.tensor_part_container.register_processed_part(peer_index, part_index, delta)
                     part_index += 1
 
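Note on the second hunk: moving aiter_with_timeout from the raw network stream to the outer amap_in_executor(...) iterator means reducer_timeout now bounds the time to obtain each deserialized part from the whole pipeline (network receive plus executor-side deserialization and prefetch), not just the arrival of raw messages. A minimal sketch of the per-item timeout semantics assumed here (the actual helper may differ in details):

    import asyncio
    from typing import AsyncIterator, Optional, TypeVar

    T = TypeVar("T")

    async def aiter_with_timeout(iterable: AsyncIterator[T], timeout: Optional[float]) -> AsyncIterator[T]:
        """Yield items from iterable; raise asyncio.TimeoutError if the next item takes longer than timeout."""
        iterator = iterable.__aiter__()
        while True:
            try:
                # The timeout applies to each item separately, not to the whole iteration.
                yield await asyncio.wait_for(iterator.__anext__(), timeout)
            except StopAsyncIteration:
                break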