
hotfix for grads

justheuristic 3 years ago
parent
commit
7eb4ca3cf6
1 changed file with 6 additions and 3 deletions

+ 6 - 3
src/server/handler.py

@@ -227,8 +227,10 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     assert hidden_states.ndim == 3
     if not prompts or len(prompts) == 1 and is_dummy(prompts[0]):
         prompts = [DUMMY] * len(requested_backends)
+        pre_seq_len = 0
     else:
-        pre_seq_len = prompts.shape[2]
+        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        pre_seq_len = prompts[0].shape[-2]
 
     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
@@ -255,9 +257,10 @@ async def _rpc_backward(
 
     if is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
+        pre_seq_len = 0
     else:
-        pre_seq_len = prompts.shape[2]
-        prompts = [p.squeeze(0) for p in prompts.split(1)]
+        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        pre_seq_len = prompts[0].shape[-2]
 
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
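
The hunks above change how per-layer prompts are unpacked in `_rpc_forward` and `_rpc_backward`: in the dummy-prompt branch, `pre_seq_len` is now set explicitly to 0, and otherwise the stacked prompt tensor (`prompts[0]`) is split into one prompt per requested backend before `pre_seq_len` is read from its second-to-last dimension. The sketch below illustrates that split on an assumed 4-D layout; the tensor shape and all variable names are illustrative assumptions, not taken from the repository.

    # Minimal sketch of the splitting performed in the hunks above.
    # The layout [num_backends, batch_size, pre_seq_len, hidden_size] and the
    # variable names are illustrative assumptions, not code from the repository.
    import torch

    num_backends, batch_size, pre_seq_len, hidden_size = 3, 2, 4, 8
    stacked_prompts = torch.randn(num_backends, batch_size, pre_seq_len, hidden_size)

    # split(1) along dim 0 yields one [1, batch, pre_seq_len, hidden] slice per backend;
    # squeeze(0) drops the leading singleton so each backend gets [batch, pre_seq_len, hidden].
    per_backend_prompts = [p.squeeze(0) for p in stacked_prompts.split(1)]

    assert len(per_backend_prompts) == num_backends
    assert per_backend_prompts[0].shape == (batch_size, pre_seq_len, hidden_size)
    # The prompt length is read from the second-to-last dim of an individual prompt,
    # mirroring `pre_seq_len = prompts[0].shape[-2]` in the diff.
    assert per_backend_prompts[0].shape[-2] == pre_seq_len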