@@ -3,12 +3,13 @@ This module implements server-side computations on served blocks: forward, backw
 """
 from __future__ import annotations
 
-from typing import AsyncIterator, Optional, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
 
 import torch
 from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.expert_uid import ExpertUID
 from hivemind.proto import runtime_pb2
+from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_flatten
 
 from petals.data_structures import InferenceMetadata
@@ -18,6 +19,7 @@ from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import TaskPrioritizerBase
 from petals.utils.convert_block import QuantType
 from petals.utils.misc import DUMMY, is_dummy
+from petals.utils.packaging import unpack_args_kwargs
 
 # We prioritize short inference requests and make them use a *merged* inference pool,
 # so they are processed without interruptions and extra overheads
@@ -25,6 +27,8 @@ from petals.utils.misc import DUMMY, is_dummy
 MAX_SHORT_INFERENCE_TOKENS = 128
 MAX_NF4_SHORT_INFERENCE_TOKENS = 1
 
+logger = get_logger(__name__)
+
 
 async def run_rpc_forward(
     *flat_tensors: torch.Tensor,
@@ -32,6 +36,7 @@ async def run_rpc_forward(
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
+    args_structure: Any = None,
 ) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
@@ -41,7 +46,11 @@ async def run_rpc_forward(
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
-    hidden_states, prompts = flat_tensors
+    if args_structure is not None:
+        # TODO: kwargs are currently unused; they may be used later for peft-like adaptation
+        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+    hidden_states, prompts, *_ = flat_tensors
+
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
@@ -79,8 +88,13 @@ async def run_rpc_backward(
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
+    args_structure: Any = None,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, prompts = flat_tensors
+    if args_structure is not None:
+        # TODO: kwargs are currently unused; they may be used later for peft-like adaptation
+        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+    inputs, grad_outputs, prompts, *_ = flat_tensors
+
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
@@ -139,6 +153,7 @@ async def iterate_rpc_inference(
     prioritizer: TaskPrioritizerBase,
     points: int,
     quant_type: QuantType,
+    args_structure: Any = None,
 ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
     assert len(cache_handles) == len(requested_backends)
 
@@ -146,7 +161,12 @@ async def iterate_rpc_inference(
     point_per_piece = points / max_length if max_length > 0 else 0.0
 
     async for request, step_metadata in input_iterator:
-        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+        flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
+        if args_structure is not None:
+            # TODO: kwargs are currently unused; they may be used later for peft-like adaptation
+            flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+
+        hidden_states, prompts, hypo_ids, *_ = flat_tensors
         batch_size, length_increment, _ = hidden_states.shape
 
         # Cast inputs to backend dtype
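
Note (not part of the patch): args_structure is the serializable structure produced by client-side packing, while the tensors themselves travel as the flat list in request.tensors; the handlers above fall back to plain positional unpacking when it is None. Below is a minimal sketch of the intended round trip, assuming pack_args_kwargs in petals.utils.packaging is the client-side inverse of unpack_args_kwargs; the tensor shapes and the dummy-prompt convention are illustrative, not taken from this diff.

import torch

from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

# Client side: flatten positional args (and any tensor-valued kwargs) into a list of
# tensors plus a structure recording where each tensor belongs.
hidden_states = torch.randn(1, 8, 4096)  # [batch_size, seq_length, hid_size], illustrative sizes
prompts = torch.empty(0)                 # dummy prompts when prompt tuning is not used
hypo_ids = torch.arange(1)               # hypothesis indices for beam-search reordering
flat_tensors, args_structure = pack_args_kwargs(hidden_states, prompts, hypo_ids)

# Server side (as in iterate_rpc_inference above): restore args and kwargs from the
# deserialized tensors and the args_structure shipped with the request metadata.
(hidden_states, prompts, hypo_ids, *_), kwargs = unpack_args_kwargs(flat_tensors, args_structure)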