瀏覽代碼

hotfix for grads

justheuristic 3 年之前
父節點
當前提交
a66ff41635
共有 1 個文件被更改,包括 1 次插入和 3 次删除
  1. 1 3
      src/server/handler.py

+ 1 - 3
src/server/handler.py

@@ -272,13 +272,11 @@ async def _rpc_backward(
             inputs[:, :pre_seq_len] += prompt
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
-        inter_inputs.append(inputs.clone())  # TODO optimize: reduce the number of copies
+        inter_inputs.append(inputs)
 
     grad_prompts = []
     # Run a chain of requested backends
     for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
-        if not is_dummy(prompt):
-            inp[:, :pre_seq_len] += prompt
         (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):