Your Name 1 vuosi sitten
vanhempi
commit
4393d99e78

+ 1 - 1
src/petals/client/inference_session.py

@@ -4,7 +4,7 @@ import asyncio
 import itertools
 import time
 import uuid
-from typing import AsyncIterator, List, Optional, Tuple, Sequence
+from typing import AsyncIterator, List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor

+ 5 - 2
src/petals/client/sequential_autograd.py

@@ -49,8 +49,11 @@ async def sequential_forward(
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
-    assert len(block_kwargs) in (0, 1, end_index - start_index), \
-        f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
+    assert len(block_kwargs) in (
+        0,
+        1,
+        end_index - start_index,
+    ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
     assert is_dummy(prompts) or len(prompts) == len(
         sequence_manager.block_uids