Ver código fonte

standardize checking block_kwargs

Your Name 1 ano atrás
pai
commit
056cd77f11
1 arquivos alterados com 3 adições e 3 exclusões
  1. 3 3
      src/petals/client/sequential_autograd.py

+ 3 - 3
src/petals/client/sequential_autograd.py

@@ -49,11 +49,11 @@ async def sequential_forward(
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
-    assert (
-        len(block_kwargs) in (0, 1, end_index - start_index)
-    ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
     if len(block_kwargs) == 1:
         block_kwargs = block_kwargs * (end_index - start_index)
+    assert (
+        len(block_kwargs) in (0, end_index - start_index)
+    ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
     assert is_dummy(prompts) or len(prompts) == len(
         sequence_manager.block_uids