@@ -227,8 +227,10 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     assert hidden_states.ndim == 3
     if not prompts or len(prompts) == 1 and is_dummy(prompts[0]):
         prompts = [DUMMY] * len(requested_backends)
+        pre_seq_len = 0
     else:
-        pre_seq_len = prompts.shape[2]
+        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        pre_seq_len = prompts[0].shape[-2]
 
     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
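Note on the hunk above (editorial sketch, not part of the patch): the deep prompts now arrive as one tensor stacked along dim 0, one slice per requested backend, which `prompts[0].split(1)` unpacks into a per-backend list. A minimal shape check, with all sizes assumed for illustration:

import torch

num_backends, batch_size, pre_seq_len, hid_size = 4, 2, 8, 512
stacked = torch.randn(num_backends, batch_size, pre_seq_len, hid_size)  # assumed wire layout

prompts = [p.squeeze(0) for p in stacked.split(1)]  # num_backends chunks, each squeezed to 3D
assert len(prompts) == num_backends
assert prompts[0].shape == (batch_size, pre_seq_len, hid_size)
assert prompts[0].shape[-2] == pre_seq_len  # why the patch reads shape[-2], not shape[2]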
@@ -255,9 +257,10 @@ async def _rpc_backward(
 
     if is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
+        pre_seq_len = 0
     else:
-        pre_seq_len = prompts.shape[2]
-        prompts = [p.squeeze(0) for p in prompts.split(1)]
+        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        pre_seq_len = prompts[0].shape[-2]
 
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output