
Move inference/fwd/bwd outputs to the same devices and dtypes as inputs

Aleksandr Borzunov 2 years ago
parent
commit
6190a5909e
2 changed files with 33 additions and 5 deletions
  1. src/client/inference_session.py  (+9 −2)
  2. src/client/sequential_autograd.py  (+24 −3)
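
The commit applies one pattern in all three code paths (inference step, sequential forward, sequential backward): record the caller's device and dtype, move the tensors to CPU before they are handed to remote servers (presumably because the RPC layer serializes CPU tensors), and cast the result back on return. A minimal sketch of that round-trip, assuming a hypothetical run_on_cpu_and_restore helper that stands in for the remote call:

import torch

def run_on_cpu_and_restore(inputs: torch.Tensor, prompts: torch.Tensor) -> torch.Tensor:
    # remember where the caller keeps its tensors
    inputs_device, inputs_dtype = inputs.device, inputs.dtype
    inputs, prompts = inputs.cpu(), prompts.cpu()  # remote work happens on CPU tensors

    outputs = inputs  # placeholder for the remote forward/inference call

    # restore the caller's device and dtype so downstream code is unaffected
    return outputs.to(device=inputs_device, dtype=inputs_dtype)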

src/client/inference_session.py  (+9 −2)

@@ -218,6 +218,11 @@ class InferenceSession:
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
 
+        inputs_device = inputs.device
+        inputs_dtype = inputs.dtype
+        inputs = inputs.cpu()
+        prompts = prompts.cpu()
+
         n_input_tokens = inputs.shape[1]
         if self._position + n_input_tokens > self._max_length:
             raise ValueError(
@@ -300,12 +305,14 @@ class InferenceSession:
                         f"Caught exception when running inference from block {block_idx} "
                         f"(retry in {delay:.0f} sec): {repr(e)}"
                     )
-                    traceback_level = logging.DEBUG if e.message else logging.WARNING
+                    traceback_level = logging.DEBUG if str(e) else logging.WARNING
                     logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                     time.sleep(delay)
 
         self._position += n_input_tokens
-        return inputs
+
+        outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+        return outputs
 
     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""

src/client/sequential_autograd.py  (+24 −3)

@@ -37,6 +37,11 @@ async def sequential_forward(
 
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
 
+    inputs_device = inputs.device
+    inputs_dtype = inputs.dtype
+    inputs = inputs.cpu()
+    prompts = prompts.cpu()
+
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
     assert is_dummy(prompts) or len(prompts) == len(
@@ -87,10 +92,12 @@ async def sequential_forward(
                     f"Caught exception when running forward from block {block_idx} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                traceback_level = logging.DEBUG if e.message else logging.WARNING
+                traceback_level = logging.DEBUG if str(e) else logging.WARNING
                 logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                 await asyncio.sleep(delay)
 
+    outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+    intermediate_inputs = [tensor.to(device=inputs_device, dtype=inputs_dtype) for tensor in intermediate_inputs]
     return outputs, intermediate_inputs, done_sequences
 
 
@@ -100,13 +107,22 @@ async def sequential_backward(
     prompts: torch.Tensor,
     forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
-) -> Sequence[torch.Tensor]:
+) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
     """
     Performs chained backward for each forward subsequence.
     If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
     """
     assert len(intermediate_inputs) == len(forward_sequences)
 
+    grad_outputs_device = grad_outputs[0].device if grad_outputs else None
+    grad_outputs_dtype = grad_outputs[0].dtype if grad_outputs else None
+    prompts_device = prompts.device
+    prompts_dtype = prompts.dtype
+
+    grad_outputs = [tensor.cpu() for tensor in grad_outputs]
+    intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]
+    prompts = prompts.cpu()
+
     grad_prompts_reversed = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
         inputs = intermediate_inputs.pop()
@@ -148,13 +164,18 @@ async def sequential_backward(
                     f"Caught exception when running backward between blocks {span.start}-{span.end} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                traceback_level = logging.DEBUG if e.message else logging.WARNING
+                traceback_level = logging.DEBUG if str(e) else logging.WARNING
                 logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                 await asyncio.sleep(delay)
 
     # For now, we do not support mixed dummy and grad prompts
     # Concat in num_layer dimension
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
+
+    if grad_outputs_dtype is not None:
+        grad_outputs = [tensor.to(device=grad_outputs_device, dtype=grad_outputs_dtype) for tensor in grad_outputs]
+    if grad_prompts is not None:
+        grad_prompts = grad_prompts.to(device=prompts_device, dtype=prompts_dtype)
     return grad_outputs, grad_prompts
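
The backward path mirrors the forward one, with one extra guard: grad_outputs may be empty, so its device and dtype are recorded as None and the final cast is skipped in that case. A hedged sketch of that restore step (restore_grads is a hypothetical helper, not part of the codebase):

import torch

def restore_grads(grads_cpu, device, dtype):
    # when grad_outputs was empty, there is nothing to cast back
    if dtype is None:
        return grads_cpu
    return [g.to(device=device, dtype=dtype) for g in grads_cpu]

ref = torch.ones(2, 3, dtype=torch.float32)  # stands in for the caller's grad_outputs[0]
grads = restore_grads([torch.zeros(2, 3)], ref.device, ref.dtype)
assert grads[0].device == ref.device and grads[0].dtype == ref.dtype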