@@ -37,6 +37,11 @@ async def sequential_forward(
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
+ inputs_device = inputs.device
+ inputs_dtype = inputs.dtype
+ inputs = inputs.cpu()
+ prompts = prompts.cpu()
+
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
assert is_dummy(prompts) or len(prompts) == len(
sequence_manager.block_uids
)
@@ -87,10 +92,12 @@ async def sequential_forward(
f"Caught exception when running forward from block {block_idx} "
f"(retry in {delay:.0f} sec): {repr(e)}"
)
- traceback_level = logging.DEBUG if e.message else logging.WARNING
+ traceback_level = logging.DEBUG if str(e) else logging.WARNING
logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
await asyncio.sleep(delay)

+ outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+ intermediate_inputs = [tensor.to(device=inputs_device, dtype=inputs_dtype) for tensor in intermediate_inputs]
return outputs, intermediate_inputs, done_sequences
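
The pattern in the two hunks above: record the caller's device and dtype on entry, move the tensors to CPU before the remote calls (presumably so they are serialized from host memory rather than GPU memory), and restore device and dtype on the way out. A minimal sketch of that round trip, assuming a hypothetical `remote_forward` standing in for the actual RPC:

import torch

def forward_via_cpu(inputs: torch.Tensor, remote_forward) -> torch.Tensor:
    # Remember where the caller's tensor lives so the output can be restored.
    inputs_device = inputs.device
    inputs_dtype = inputs.dtype
    # Hand the remote call a CPU copy; results come back on CPU as well.
    outputs = remote_forward(inputs.cpu())
    # Restore the caller-visible device and dtype before returning.
    return outputs.to(device=inputs_device, dtype=inputs_dtype)

# Usage: the output matches the input's device and dtype regardless of
# what the "remote" side did with placement.
x = torch.randn(1, 8, 16)
y = forward_via_cpu(x, remote_forward=lambda t: t * 2)
assert y.device == x.device and y.dtype == x.dtype
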
@@ -100,13 +107,22 @@ async def sequential_backward(
prompts: torch.Tensor,
forward_sequences: List[RemoteSpanInfo],
sequence_manager: RemoteSequenceManager,
-) -> Sequence[torch.Tensor]:
+) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
"""
Performs chained backward for each forward subsequence.
If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
"""
assert len(intermediate_inputs) == len(forward_sequences)

+ grad_outputs_device = grad_outputs[0].device if grad_outputs else None
+ grad_outputs_dtype = grad_outputs[0].dtype if grad_outputs else None
+ prompts_device = prompts.device
+ prompts_dtype = prompts.dtype
+
+ grad_outputs = [tensor.cpu() for tensor in grad_outputs]
+ intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]
+ prompts = prompts.cpu()
+
grad_prompts_reversed = []
while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
inputs = intermediate_inputs.pop()
@@ -148,13 +164,18 @@ async def sequential_backward(
f"Caught exception when running backward between blocks {span.start}-{span.end} "
f"(retry in {delay:.0f} sec): {repr(e)}"
)
- traceback_level = logging.DEBUG if e.message else logging.WARNING
+ traceback_level = logging.DEBUG if str(e) else logging.WARNING
logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
await asyncio.sleep(delay)

# For now, we do not support mixed dummy and grad prompts
# Concat in num_layer dimension
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
+
+ if grad_outputs_dtype is not None:
+ grad_outputs = [tensor.to(device=grad_outputs_device, dtype=grad_outputs_dtype) for tensor in grad_outputs]
+ if grad_prompts is not None:
+ grad_prompts = grad_prompts.to(device=prompts_device, dtype=prompts_dtype)
return grad_outputs, grad_prompts
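
The `e.message` → `str(e)` change in both retry loops is a Python 3 correctness fix: built-in exceptions no longer have a `.message` attribute, so the old check raised `AttributeError` inside the except block and masked the original error. `str(e)` is the portable way to test whether an exception carries a message. A self-contained sketch of the surrounding retry loop (function and parameter names here are illustrative, not from this codebase):

import asyncio
import logging

logger = logging.getLogger(__name__)

async def call_with_retries(fn, max_attempts: int = 3, delay: float = 1.0):
    for attempt_no in range(max_attempts):
        try:
            return await fn()
        except Exception as e:
            if attempt_no + 1 == max_attempts:
                raise
            logger.warning(f"Caught exception (retry in {delay:.0f} sec): {repr(e)}")
            # If the exception has a message, repr(e) above already surfaces
            # it, so the full traceback is demoted to DEBUG; an empty message
            # keeps the traceback visible at WARNING.
            traceback_level = logging.DEBUG if str(e) else logging.WARNING
            logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
            await asyncio.sleep(delay)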