
timeouts-and-debugprint

justheuristic, 3 years ago
commit 970cde890c
1 file changed, 15 insertions and 5 deletions
hivemind/averaging/allreduce.py (+15 -5)

@@ -215,10 +215,20 @@ class AllReduceRunner(ServicerBase):
             try:
                 done_sending = asyncio.Event()
                 inputs_aiter = attach_event_on_finished(self._generate_input_for_peer(peer_index), done_sending)
-                stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(inputs_aiter)
+                stream = await asyncio.wait_for(
+                    self._get_peer_stub(peer_id).rpc_aggregate_part(inputs_aiter), self.reducer_timeout
+                )
 
                 if self.should_delay_results(self.peer_id):
-                    await done_sending.wait()
+                    # TODO if inputs_aiter fails, we must ensure that done_sending is set anyway!
+                    # Otherwise, there is a risk that we will sleep here forever.
+                    try:
+                        logger.warning("WAITING FOR DONE_SENDING!!!")
+                        await asyncio.wait_for(done_sending.wait(), 300)
+                    except asyncio.TimeoutError:
+                        logger.error("Not done sending 300s!!! (will keep waiting)")
+                        await done_sending.wait()
+                        raise
 
                 part_index = 0
 
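The hunk above also wraps the wait on done_sending in a 300-second asyncio.wait_for, logging before falling back to an unbounded wait. The TODO notes that done_sending must be set even if inputs_aiter fails, otherwise this coroutine could sleep forever. A wrapper with that guarantee could look roughly like the sketch below (a hypothetical illustration, not hivemind's actual attach_event_on_finished): setting the event in a finally block covers normal completion, failure, and cancellation of the input iterator.

    import asyncio
    from typing import AsyncIterable, AsyncIterator, TypeVar

    T = TypeVar("T")

    async def attach_event_on_finished(iterable: AsyncIterable[T], event: asyncio.Event) -> AsyncIterator[T]:
        """Yield items from `iterable`; set `event` once iteration ends.

        Sketch only (not hivemind's code): the `finally` clause fires whether the
        source finishes normally, raises, or is cancelled, so a consumer awaiting
        event.wait() cannot block forever.
        """
        try:
            async for item in iterable:
                yield item
        finally:
            event.set()
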
@@ -227,11 +237,11 @@ class AllReduceRunner(ServicerBase):
                         raise AllreduceException(f"{peer_id} sent {averaging_pb2.MessageCode.Name(msg.code)}")
                     return deserialize_torch_tensor(msg.tensor_part), msg
 
-                async for delta, msg in amap_in_executor(
+                async for delta, msg in aiter_with_timeout(amap_in_executor(
                     _try_deserialize,
-                    aiter_with_timeout(stream, self.reducer_timeout),
+                    stream,
                     max_prefetch=self.tensor_part_container.prefetch,
-                ):
+                ), self.reducer_timeout):
                     self.tensor_part_container.register_processed_part(peer_index, part_index, delta)
                     part_index += 1
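
The second hunk moves the aiter_with_timeout wrapper from the inner stream to the outside of the amap_in_executor pipeline, so reducer_timeout now bounds the wait for each fully deserialized part rather than only each raw message from the stream. Roughly, a per-item timeout wrapper of this kind can be sketched as follows (hypothetical code assuming the helper bounds every iteration step; not hivemind's actual implementation):

    import asyncio
    from typing import AsyncIterable, AsyncIterator, TypeVar

    T = TypeVar("T")

    async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: float) -> AsyncIterator[T]:
        """Re-yield items from `iterable`, raising asyncio.TimeoutError if the
        next item does not arrive within `timeout` seconds (sketch, not hivemind's code)."""
        iterator = iterable.__aiter__()
        while True:
            try:
                # bound each individual __anext__ call, not the stream as a whole
                item = await asyncio.wait_for(iterator.__anext__(), timeout)
            except StopAsyncIteration:
                break
            yield item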