|
@@ -225,24 +225,22 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
|
|
|
# check parse input tensors and cast dtypes
|
|
|
hidden_states = hidden_states.to(dtype)
|
|
|
assert hidden_states.ndim == 3
|
|
|
- assert len(prompts) <= len(requested_backends), f"Expected at most {len(requested_backends)} prompts, one per layer"
|
|
|
-
|
|
|
- for i in range(len(prompts)):
|
|
|
- if not is_dummy(prompts[i]):
|
|
|
- assert prompts[i].ndim == 3, "prompts must have shape [batch or 1, seq_len or prefix, hidden_size]"
|
|
|
- prompts[i] = prompts[i].to(dtype)
|
|
|
- prompts.extend((DUMMY for _ in range(len(prompts), len(requested_backends)))) # add missing prompts
|
|
|
-
|
|
|
- seq_length = hidden_states.shape[1]
|
|
|
+ if not prompts or len(prompts) == 1 and is_dummy(prompts[0]):
|
|
|
+ prompts = [DUMMY] * len(requested_backends)
|
|
|
+ else:
|
|
|
+ pre_seq_len = prompts.shape[2]
|
|
|
|
|
|
- # run forward pass for requested backends
|
|
|
+ # Run a chain of requested backends
|
|
|
for backend, prompt in zip(requested_backends, prompts):
|
|
|
- (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
|
|
|
if not is_dummy(prompt):
|
|
|
- hidden_states[:, : min(seq_length, prompt.shape[1]), ...] += prompt
|
|
|
+ hidden_states[:, :pre_seq_len] += prompt
|
|
|
+ (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
|
|
|
assert isinstance(hidden_states, torch.Tensor)
|
|
|
- assert hidden_states.ndim == 3, f"{type(backend)} must return a list with a single 3d tensor of hidden states"
|
|
|
+ assert (
|
|
|
+ hidden_states.ndim == 3
|
|
|
+ ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
|
|
|
|
|
|
+ # Serialize the overall output
|
|
|
return hidden_states
|
|
|
|
|
|
|