|
@@ -49,8 +49,11 @@ async def sequential_forward(
|
|
|
|
|
|
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
|
|
|
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
|
|
|
- assert len(block_kwargs) in (0, 1, end_index - start_index), \
|
|
|
- f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
|
|
|
+ assert len(block_kwargs) in (
|
|
|
+ 0,
|
|
|
+ 1,
|
|
|
+ end_index - start_index,
|
|
|
+ ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
|
|
|
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
|
|
|
assert is_dummy(prompts) or len(prompts) == len(
|
|
|
sequence_manager.block_uids
|