Преглед изворни кода

standardize checking block_kwargs

Your Name пре 1 година
родитељ
комит
056cd77f11
1 измењених фајлова са 3 додато и 3 уклоњено
  1. 3 3
      src/petals/client/sequential_autograd.py

+ 3 - 3
src/petals/client/sequential_autograd.py

@@ -49,11 +49,11 @@ async def sequential_forward(
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
-    assert (
-        len(block_kwargs) in (0, 1, end_index - start_index)
-    ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
     if len(block_kwargs) == 1:
         block_kwargs = block_kwargs * (end_index - start_index)
+    assert (
+        len(block_kwargs) in (0, end_index - start_index)
+    ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
     assert is_dummy(prompts) or len(prompts) == len(
         sequence_manager.block_uids