|
@@ -3,6 +3,7 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
|
|
|
"""
|
|
|
import asyncio
|
|
|
import itertools
|
|
|
+import logging
|
|
|
from collections import deque
|
|
|
from typing import List, Optional, Sequence, Tuple
|
|
|
|
|
@@ -86,7 +87,8 @@ async def sequential_forward(
|
|
|
f"Caught exception when running forward from block {block_idx} "
|
|
|
f"(retry in {delay:.0f} sec): {repr(e)}"
|
|
|
)
|
|
|
- logger.debug("See detailed traceback below:", exc_info=True)
|
|
|
+ traceback_level = logging.DEBUG if e.message else logging.WARNING
|
|
|
+ logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
|
|
|
await asyncio.sleep(delay)
|
|
|
|
|
|
return outputs, intermediate_inputs, done_sequences
|
|
@@ -146,7 +148,8 @@ async def sequential_backward(
|
|
|
f"Caught exception when running backward between blocks {span.start}-{span.end} "
|
|
|
f"(retry in {delay:.0f} sec): {repr(e)}"
|
|
|
)
|
|
|
- logger.debug("See detailed traceback below:", exc_info=True)
|
|
|
+ traceback_level = logging.DEBUG if e.message else logging.WARNING
|
|
|
+ logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
|
|
|
await asyncio.sleep(delay)
|
|
|
|
|
|
# For now, we do not support mixed dummy and grad prompts
|