reuse existing variable

justheuristic committed 3 years ago
commit cb52c3ff4f
1 changed file with 3 additions and 8 deletions

+ 3 - 8
src/server/handler.py

@@ -251,7 +251,6 @@ async def _rpc_backward(
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
-    NO_PROMPTS = not prompts
     prompts = prompts.to(requested_backends[0].dtype) if prompts else DUMMY
 
     if is_dummy(prompts):
@@ -266,11 +265,10 @@ async def _rpc_backward(
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
-            inputs = inputs.clone()  # TODO
             inputs[:, :pre_seq_len] += prompt
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
-        inter_inputs.append(inputs)
+        inter_inputs.append(inputs.clone())  # TODO optimize: reduce the number of copies
 
     grad_prompts = []
     # Run a chain of requested backends
@@ -281,9 +279,6 @@ async def _rpc_backward(
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
             grad_prompts.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
-        else:
-            grad_prompts.append(DUMMY)
 
-    is_dummy_grad_prompts = [is_dummy(grad_param) for grad_param in grad_prompts]
-    grad_prompts = torch.cat(grad_prompts, dim=0) if not any(is_dummy_grad_prompts) else DUMMY
-    return [grad_outputs] if NO_PROMPTS else [grad_outputs, grad_prompts]  # TODO un-duct-tape
+    grad_prompts = torch.cat(grad_prompts, dim=0) if grad_prompts else DUMMY
+    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
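
The simplification above hinges on the DUMMY/is_dummy sentinel convention already used in this file: an empty tensor stands in for "no prompts", so the old NO_PROMPTS flag (captured before `prompts` was overwritten with DUMMY) carried no information that `grad_prompts` itself doesn't. The sketch below is a minimal reconstruction, not the project's actual definitions; DUMMY, is_dummy, and the placeholder tensors are assumptions inferred from the diff.

import torch

# Reconstruction of the sentinel convention the diff relies on (these
# definitions are inferred, not copied from the project): DUMMY is an empty
# placeholder tensor, and is_dummy() detects it by element count.
DUMMY = torch.empty(0)

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0

# Old flow: NO_PROMPTS = not prompts had to be captured *before* prompts was
# overwritten with DUMMY, then threaded down to the return statement.
# New flow: grad_prompts is only non-empty when real prompts produced
# gradients, so its emptiness already encodes the same fact.
grad_outputs = torch.randn(2, 4, 8)   # placeholder for the real gradients
grad_prompt_chunks = []               # empty list: the "no prompts" case
grad_prompts = torch.cat(grad_prompt_chunks, dim=0) if grad_prompt_chunks else DUMMY

result = [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]
assert len(result) == 1  # with a dummy grad_prompts, only grad_outputs is returned

This also explains why the `else: grad_prompts.append(DUMMY)` branch could go: once no dummies are ever appended, the old `any(is_dummy_grad_prompts)` guard reduces to a plain truthiness check on the list. The `clone()` move is a separate tweak in the same spirit: instead of defensively copying `inputs` before the in-place `+= prompt`, the copy now happens once when each activation is stored in `inter_inputs`, keeping the saved activations isolated from subsequent in-place prompt additions (with a TODO to cut the number of copies further).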