Procházet zdrojové kódy

Show tracebacks in case of empty error messages

Aleksandr Borzunov před 2 roky
rodič
revize
e37b2f526a

+ 3 - 2
src/client/inference_session.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import itertools
+import logging
 import time
 from typing import AsyncIterator, List, Optional
 
@@ -18,7 +19,6 @@ from hivemind import (
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import aiter_with_timeout
 
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
@@ -300,7 +300,8 @@ class InferenceSession:
                         f"Caught exception when running inference from block {block_idx} "
                         f"(retry in {delay:.0f} sec): {repr(e)}"
                     )
-                    logger.debug("See detailed traceback below:", exc_info=True)
+                    traceback_level = logging.DEBUG if e.message else logging.WARNING
+                    logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                     time.sleep(delay)
 
         self._position += n_input_tokens

+ 5 - 2
src/client/sequential_autograd.py

@@ -3,6 +3,7 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
 """
 import asyncio
 import itertools
+import logging
 from collections import deque
 from typing import List, Optional, Sequence, Tuple
 
@@ -86,7 +87,8 @@ async def sequential_forward(
                     f"Caught exception when running forward from block {block_idx} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                logger.debug("See detailed traceback below:", exc_info=True)
+                traceback_level = logging.DEBUG if e.message else logging.WARNING
+                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                 await asyncio.sleep(delay)
 
     return outputs, intermediate_inputs, done_sequences
@@ -146,7 +148,8 @@ async def sequential_backward(
                     f"Caught exception when running backward between blocks {span.start}-{span.end} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                logger.debug("See detailed traceback below:", exc_info=True)
+                traceback_level = logging.DEBUG if e.message else logging.WARNING
+                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                 await asyncio.sleep(delay)
 
     # For now, we do not support mixed dummy and grad prompts