|
@@ -49,11 +49,11 @@ async def sequential_forward(
|
|
|
|
|
|
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
|
|
|
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
|
|
|
- assert (
|
|
|
- len(block_kwargs) in (0, 1, end_index - start_index)
|
|
|
- ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
|
|
|
if len(block_kwargs) == 1:
|
|
|
block_kwargs = block_kwargs * (end_index - start_index)
|
|
|
+ assert (
|
|
|
+ len(block_kwargs) in (0, end_index - start_index)
|
|
|
+ ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
|
|
|
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
|
|
|
assert is_dummy(prompts) or len(prompts) == len(
|
|
|
sequence_manager.block_uids
|