dbaranchuk 3 lat temu
rodzic
commit
15e49f9026
1 zmienionych plików z 3 dodań i 1 usunięć
  1. 3 1
      src/server/handler.py

+ 3 - 1
src/server/handler.py

@@ -229,6 +229,8 @@ async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
 
     if is_dummy(prompts):
         prompts = [make_dummy_batch(batch_size)] * len(requested_backends)
+    else:
+        prompts = [p.squeeze(0) for p in prompts.split(1)]
 
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
@@ -241,7 +243,7 @@ async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
 
     grad_prompts = []
     # Run a chain of requested backends
-    for inp, prompt, backend in zip(inter_inputs[::-1], prompts.flip(0), requested_backends[::-1]):
+    for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
         grads = await backend.backward_pool.submit_task(inp, prompt, grad_outputs)
         assert isinstance(grads, (list, tuple)) and len(grads) == 2
         grad_outputs, grad_prompt = grads