justheuristic 3 years ago
parent commit f54eabad16
1 changed file with 11 additions and 13 deletions
src/server/handler.py

@@ -225,24 +225,22 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
     assert hidden_states.ndim == 3
-    assert len(prompts) <= len(requested_backends), f"Expected at most {len(requested_backends)} prompts, one per layer"
-
-    for i in range(len(prompts)):
-        if not is_dummy(prompts[i]):
-            assert prompts[i].ndim == 3, "prompts must have shape [batch or 1, seq_len or prefix, hidden_size]"
-            prompts[i] = prompts[i].to(dtype)
-    prompts.extend((DUMMY for _ in range(len(prompts), len(requested_backends))))  # add missing prompts
-
-    seq_length = hidden_states.shape[1]
+    if not prompts or len(prompts) == 1 and is_dummy(prompts[0]):
+        prompts = [DUMMY] * len(requested_backends)
+    else:
+        pre_seq_len = prompts.shape[2]
 
-    # run forward pass for requested backends
+    # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
-        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
         if not is_dummy(prompt):
-            hidden_states[:, : min(seq_length, prompt.shape[1]), ...] += prompt
+            hidden_states[:, :pre_seq_len] += prompt
+        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
         assert isinstance(hidden_states, torch.Tensor)
-        assert hidden_states.ndim == 3, f"{type(backend)} must return a list with a single 3d tensor of hidden states"
+        assert (
+            hidden_states.ndim == 3
+        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
 
+    # Serialize the overall output
     return hidden_states
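
For readers unfamiliar with the handler, the hunk above boils down to this control flow: if the client passed no real prompts, every backend gets a DUMMY placeholder; otherwise each backend's prompt is added to the first pre_seq_len positions of hidden_states before (rather than after, as in the removed lines) that backend runs. Below is a minimal, self-contained sketch of that flow. It is an illustration under simplifying assumptions, not the server code itself: StubBackend and chained_forward are hypothetical names, the synchronous backend.forward call stands in for the real awaited backend.forward_pool.submit_task(hidden_states), and prompts are taken to be an optional stacked tensor of shape [num_backends, batch, pre_seq_len, hidden].

import torch

DUMMY = torch.empty(0)  # sentinel for "no prompt", mirroring the DUMMY tensor in the handler

def is_dummy(tensor: torch.Tensor) -> bool:
    # a prompt counts as dummy when it carries no elements
    return tensor.numel() == 0

class StubBackend:
    # hypothetical stand-in for one block backend; the real handler awaits
    # backend.forward_pool.submit_task(hidden_states) instead of calling a layer directly
    def __init__(self, hidden_size: int):
        self.layer = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.layer(hidden_states)

def chained_forward(hidden_states: torch.Tensor, prompts, backends) -> torch.Tensor:
    # hidden_states: [batch, seq_len, hidden]; prompts: [num_backends, batch, pre_seq_len, hidden] or None
    assert hidden_states.ndim == 3
    if prompts is None or (len(prompts) == 1 and is_dummy(prompts[0])):
        prompts = [DUMMY] * len(backends)  # no prompts given: pad with dummies so zip() covers every backend
    else:
        pre_seq_len = prompts.shape[2]

    # run the chain: add each layer's prompt to the prefix positions, then call that backend
    for backend, prompt in zip(backends, prompts):
        if not is_dummy(prompt):
            hidden_states[:, :pre_seq_len] += prompt
        hidden_states = backend.forward(hidden_states)
        assert hidden_states.ndim == 3
    return hidden_states

# toy usage: batch=2, seq_len=5, hidden=8, two backends, prompts covering the first 3 positions
with torch.no_grad():
    backends = [StubBackend(8), StubBackend(8)]
    hidden = torch.zeros(2, 5, 8)
    prompts = torch.randn(2, 2, 3, 8)  # [num_backends, batch, pre_seq_len, hidden]
    print(chained_forward(hidden, prompts, backends).shape)  # torch.Size([2, 5, 8])

Running the toy usage yields a [2, 5, 8] tensor; only the first three sequence positions of each layer's input are shifted by its prompt, which matches the prompt-as-prefix convention visible in the removed assert message ("[batch or 1, seq_len or prefix, hidden_size]").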