Ver código fonte

actually fix tests

justheuristic 3 anos atrás
pai
commit
b3b3264e13
1 arquivos alterados com 3 adições e 3 exclusões
  1. 3 3
      src/client/sequential_autograd.py

+ 3 - 3
src/client/sequential_autograd.py

@@ -173,11 +173,11 @@ async def sequential_backward(
                 span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
 
-                grad_outputs, span_grad_prompts = await run_expert_backward(
+                grad_outputs, *span_grad_prompts = await run_expert_backward(
                     span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts
                 )
                 grad_outputs = [grad_outputs]
-                grad_prompts.append(span_grad_prompts)
+                grad_prompts.extend(span_grad_prompts)
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
@@ -194,7 +194,7 @@ async def sequential_backward(
     dummy_grad_prompts = [is_dummy(grad_prompt) for grad_prompt in grad_prompts]
     # For now, we do not support mixed dummy and grad prompts
     # Concat in num_layer dimension
-    grad_prompts = torch.cat(grad_prompts, dim=0) if not any(dummy_grad_prompts) else None
+    grad_prompts = torch.cat(grad_prompts, dim=0) if grad_prompts else None
     return grad_outputs, grad_prompts