|
@@ -191,6 +191,7 @@ async def sequential_backward(
|
|
|
|
|
|
# For now, we do not support mixed dummy and grad prompts
|
|
|
# Concat in num_layer dimension
|
|
|
+ assert not grad_prompts_reversed or len(grad_prompts_reversed) == len(prompts)
|
|
|
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
|
|
|
return grad_outputs, grad_prompts
|
|
|
|