소스 검색

check num block kwargs

Your Name 1 년 전
부모
커밋
62e780c054
1개의 변경된 파일4개의 추가작업 그리고 4개의 파일을 삭제
  1. 4 4
      src/petals/client/sequential_autograd.py

+ 4 - 4
src/petals/client/sequential_autograd.py

@@ -49,11 +49,11 @@ async def sequential_forward(
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
-    assert len(block_kwargs) in (
-        0,
-        1,
-        end_index - start_index,
+    assert (
+        len(block_kwargs) in (0, 1, end_index - start_index)
     ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
+    if len(block_kwargs) == 1:
+        block_kwargs = block_kwargs * (end_index - start_index)
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
     assert is_dummy(prompts) or len(prompts) == len(
         sequence_manager.block_uids