Your Name преди 1 година
родител
ревизия
62e780c054
променени са 1 файла, в които са добавени 4 реда и са изтрити 4 реда
  1. 4 4
      src/petals/client/sequential_autograd.py

+ 4 - 4
src/petals/client/sequential_autograd.py

@@ -49,11 +49,11 @@ async def sequential_forward(
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
-    assert len(block_kwargs) in (
-        0,
-        1,
-        end_index - start_index,
+    assert (
+        len(block_kwargs) in (0, 1, end_index - start_index)
     ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
+    if len(block_kwargs) == 1:
+        block_kwargs = block_kwargs * (end_index - start_index)
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
     assert is_dummy(prompts) or len(prompts) == len(
         sequence_manager.block_uids