justheuristic преди 3 години
родител
ревизия
a66ff41635
променен е 1 файл, в който е добавен 1 ред и са изтрити 3 реда
  1. 1 3
      src/server/handler.py

+ 1 - 3
src/server/handler.py

@@ -272,13 +272,11 @@ async def _rpc_backward(
             inputs[:, :pre_seq_len] += prompt
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
-        inter_inputs.append(inputs.clone())  # TODO optimize: reduce the number of copies
+        inter_inputs.append(inputs)
 
     grad_prompts = []
     # Run a chain of requested backends
     for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
-        if not is_dummy(prompt):
-            inp[:, :pre_seq_len] += prompt
         (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):