@@ -16,8 +16,15 @@ from petals.server.backend import TransformerBackend
 from petals.server.memory_cache import Handle
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import TaskPrioritizerBase
+from petals.utils.convert_block import QuantType
 from petals.utils.misc import DUMMY, is_dummy
 
+# We prioritize short inference requests and make them use a *merged* inference pool,
+# so they are processed without interruptions or extra overhead
+# TODO: Increase the NF4 threshold once bitsandbytes ships an efficient NF4 kernel for parallel forward
+MAX_SHORT_INFERENCE_TOKENS = 128
+MAX_NF4_SHORT_INFERENCE_TOKENS = 1
+
 
 async def run_rpc_forward(
     *flat_tensors: torch.Tensor,
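
The constants above, together with the new quant_type argument and the can_merge_pools check added further down, decide whether an inference step runs in the merged ("short inference") pool. The snippet below is a minimal, illustrative sketch of that decision, not part of the patch itself: the helper name uses_merged_pool and the example shapes are invented, while QuantType and the constant values are taken from the hunks in this diff.

from petals.utils.convert_block import QuantType

# Values copied from the constants introduced in the hunk above
MAX_SHORT_INFERENCE_TOKENS = 128
MAX_NF4_SHORT_INFERENCE_TOKENS = 1


def uses_merged_pool(batch_size: int, length_increment: int, quant_type: QuantType) -> bool:
    """Mirror of the can_merge_pools condition: is this step small enough for the merged pool?"""
    merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
    return batch_size * length_increment <= merge_max_tokens


# Token-by-token decoding (1 sequence, 1 new token per step) stays on the merged pool:
assert uses_merged_pool(batch_size=1, length_increment=1, quant_type=QuantType.NF4)
# A long prefill (1 sequence, 512 prompt tokens) falls back to the per-block pools:
assert not uses_merged_pool(batch_size=1, length_increment=512, quant_type=QuantType.NF4)
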
@@ -127,9 +134,11 @@ async def iterate_rpc_inference(
     active_adapter: Optional[str],
     input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
     cache_handles: Sequence[Sequence[Handle]],
+    *,
     max_length: int,
     prioritizer: TaskPrioritizerBase,
     points: int,
+    quant_type: QuantType,
 ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
     assert len(cache_handles) == len(requested_backends)
 
@@ -138,6 +147,7 @@ async def iterate_rpc_inference(
 
     async for request, step_metadata in input_iterator:
         hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+        batch_size, length_increment, _ = hidden_states.shape
 
         # Cast inputs to backend dtype
         hidden_states = hidden_states.to(requested_backends[0].dtype)
@@ -154,34 +164,40 @@ async def iterate_rpc_inference(
         if not (len(requested_backends) == len(prompts)):
             raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
 
-        length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
         if prefix_length + length_increment > max_length:
             raise ValueError(
                 f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                 f" exceeds pre-allocated maximum {max_length}"
             )
 
+        merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
+        can_merge_pools = batch_size * length_increment <= merge_max_tokens
         priority = prioritizer.prioritize(
             hidden_states,
             hypo_ids,
             points=point_per_piece,
             requested_uids=requested_uids,
-            type="inference",
-        )
-
-        inference_infos = tuple(
-            InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
-            for uid, handles in zip(requested_uids, cache_handles)
+            type="short_inference" if can_merge_pools else "inference",
         )
 
-        if hidden_states.numel() == 0:
-            pass  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
-            # when user wants to pre-allocate cache or check that server *can* allocate that cache
-        else:
+        # A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.,
+        # when the client wants to pre-allocate the cache or check that the server *can* allocate it.
+        if hidden_states.numel() > 0:
             assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
-            (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
-                hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
-            )
+            if can_merge_pools:
+                inference_infos = tuple(
+                    InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
+                    for uid, handles in zip(requested_uids, cache_handles)
+                )
+                (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
+                    hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
+                )
+            else:
+                for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
+                    inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
+                    (hidden_states,) = await backend.inference_pool.submit_task(
+                        hidden_states, hypo_ids, inference_infos, prompt, priority=priority
+                    )
 
         # serialize and send last layer outputs
         output_tensors = [