@@ -51,7 +51,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             if not requested_uids:
                 raise ValueError("User must specify at least one block for inference, but got none")
             assert isinstance(max_length, int), f"rpc_inference metadata must contain int seq_length, got {max_length}"
-            if max_length not in range(0, self.inference_max_length):
+            if not 0 <= max_length <= self.inference_max_length:
                 raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")

             batch_size = request.tensors[0].size[0] if request.tensors else 1
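
Note on the change: range(0, N) covers 0..N-1, so the old check rejected a request for exactly self.inference_max_length tokens; the chained comparison accepts that boundary value. A minimal sketch of the two checks side by side, assuming a hypothetical limit of 2048 tokens:

# Minimal sketch of the boundary fix; the 2048-token limit is a hypothetical value.
inference_max_length = 2048
max_length = 2048  # request for exactly the maximum

# Old check: range(0, N) stops at N - 1, so max_length == N triggered the ValueError.
old_rejects = max_length not in range(0, inference_max_length)  # True -> raised

# New check: the chained comparison includes the upper bound N.
new_rejects = not 0 <= max_length <= inference_max_length  # False -> allowed

print(old_rejects, new_rejects)  # True False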