
Add customizable input tensors (#445)

Artem Chumachenko, 2 years ago
commit 568f21dc3b

+ 19 - 12
src/petals/client/inference_session.py

@@ -7,22 +7,17 @@ import uuid
 from typing import AsyncIterator, List, Optional, Tuple
 
 import torch
-from hivemind import (
-    MSGPackSerializer,
-    anext,
-    deserialize_torch_tensor,
-    get_logger,
-    nested_flatten,
-    serialize_torch_tensor,
-)
+from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import P2P
 from hivemind.proto import runtime_pb2
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig, maybe_log_traceback
 from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from petals.server.handler import TransformerConnectionHandler
-from petals.utils.misc import DUMMY, is_dummy
+from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
+from petals.utils.packaging import pack_args_kwargs
 
 logger = get_logger(__name__)
 
@@ -128,13 +123,13 @@ class _ServerInferenceSession:
             assert prompts.shape[3] == inputs.shape[2]
 
         if hypo_ids is None or is_dummy(hypo_ids):
-            hypo_ids = DUMMY
+            hypo_ids = DUMMY_INT64
         else:
             assert len(hypo_ids) == len(inputs)
             assert hypo_ids.dtype == torch.int64
 
         # serialize inputs and put them into the queue
-        input_tensors = (inputs, prompts, hypo_ids)
+        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
 
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
@@ -144,13 +139,25 @@ class _ServerInferenceSession:
             if next_servers:
                 request_metadata["next_servers"] = next_servers
 
+        request_metadata["args_structure"] = args_structure
+
+        # TODO: make possible to use different compression method for different tensors
+        server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
+        compression = server_side_inference_schema[0].compression
+        inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)
+
+        # TODO: create more explicit way to check servers schema and client's structure
+        assert len(input_tensors) >= len(
+            server_side_inference_schema
+        ), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step"
+
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(input_tensors, nested_flatten(self.rpc_info["inference_schema"]))
+                        for tensor, proto in zip(input_tensors, inference_schema)
                     ],
                     metadata=MSGPackSerializer.dumps(request_metadata),
                 )
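
For orientation, here is a minimal client-side sketch of the packing path introduced above. The tensor shapes, the session id, and the use of CompressionType.NONE are placeholder assumptions; in the real step() the compression comes from the server-reported inference schema.

import torch
from hivemind import MSGPackSerializer, serialize_torch_tensor
from hivemind.proto import runtime_pb2
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from petals.utils.packaging import pack_args_kwargs

inputs = torch.randn(1, 4, 4096)               # hidden states (hypothetical shape)
prompts = torch.empty(0)                       # DUMMY: no prompts for this span
hypo_ids = torch.empty(0, dtype=torch.int64)   # DUMMY_INT64 keeps the int64 dtype expected by the server

# Flatten the arguments into a tensor list plus a structure the server can use to rebuild them
flat_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
request_metadata = dict(session_id="example-session", step_id="step-0", args_structure=args_structure)

# One descriptor per tensor instead of a fixed three-tensor schema
compression = runtime_pb2.CompressionType.NONE  # placeholder for the server's compression choice
inference_schema = tuple(BatchTensorDescriptor.from_tensor(t, compression) for t in flat_tensors)
serialized = [
    serialize_torch_tensor(t.to(proto.dtype), proto.compression)
    for t, proto in zip(flat_tensors, inference_schema)
]
metadata_bytes = MSGPackSerializer.dumps(request_metadata)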

+ 15 - 24
src/petals/client/remote_forward_backward.py

@@ -12,6 +12,7 @@ from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_U
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
 from hivemind.utils.streaming import split_for_streaming
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.data_structures import ModuleUID, RPCInfo
@@ -84,26 +85,20 @@ async def run_remote_forward(
     kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
 
     # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
-    forward_inputs = (inputs, kwargs)
-
-    # Modify forward_schema to support prompts
+    forward_inputs = tuple(nested_flatten((inputs, kwargs)))
     args_schema, kwargs_schema = rpc_info["forward_schema"]
-    # TODO: rm this assert when support arbitrary number of input tensors
-    assert len(args_schema) == 1 and len(inputs) == 2
-    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
-
-    if not nested_compare(forward_inputs, forward_schema_with_prompts):
-        raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
-
-    forward_inputs = nested_flatten(forward_inputs)
+    compression = args_schema[0].compression
+    forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs)
     inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
+    # TODO: create more explicit way to check servers schema and client's structure
+    assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step"
 
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
     serialized_tensors = await asyncio.gather(
         *(
             loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
+            for tensor, proto in zip(inputs, forward_schema)
         )
     )
 
@@ -119,9 +114,7 @@ async def run_remote_backward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
-    inputs: torch.Tensor,
-    grad_outputs: List[torch.Tensor],
-    *extra_tensors: torch.Tensor,
+    *inputs_and_grad_outputs: torch.Tensor,
     config: SequenceManagerConfig,
     metadata: Optional[bytes] = None,
     **kwargs,
@@ -131,16 +124,14 @@ async def run_remote_backward(
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
     but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
-
-    grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
-    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
-
-    # Modify forward_schema to support prompts
     args_schema, kwargs_schema = rpc_info["forward_schema"]
-    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
-    # TODO generalize this
-    prompts_schema = next(iter(args_schema))
-    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
+    outputs_schema = rpc_info["outputs_schema"]
+    compression = args_schema[0].compression
+    backward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in inputs_and_grad_outputs)
+    # TODO: create more explicit way to check servers schema and client's structure
+    assert (
+        len(inputs_and_grad_outputs) >= len(args_schema) + len(outputs_schema) + 1
+    ), "Inputs, grad_outputs and prompt tensors are necessary for a backward step"
 
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
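
Both run_remote_forward and run_remote_backward now build a per-tensor schema on the fly and serialize off the event loop. Below is a self-contained sketch of that asynchronous serialization pattern; the tensors and the NONE compression are illustrative stand-ins for the values derived from rpc_info.

import asyncio
import torch
from hivemind import serialize_torch_tensor
from hivemind.proto import runtime_pb2
from hivemind.utils.tensor_descr import BatchTensorDescriptor

async def serialize_all(tensors):
    compression = runtime_pb2.CompressionType.NONE  # the real code reuses args_schema[0].compression
    schema = tuple(BatchTensorDescriptor.from_tensor(t, compression) for t in tensors)
    loop = asyncio.get_running_loop()
    # CPU-bound protobuf encoding runs in the default executor so the event loop stays responsive
    return await asyncio.gather(
        *(
            loop.run_in_executor(None, serialize_torch_tensor, t.to(proto.dtype), proto.compression)
            for t, proto in zip(tensors, schema)
        )
    )

serialized = asyncio.run(serialize_all([torch.randn(2, 3), torch.zeros(1)]))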

+ 9 - 2
src/petals/client/routing/sequence_manager.py

@@ -487,14 +487,21 @@ class RemoteSequenceManager:
             return 0
         return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
 
-    def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
+    def get_request_metadata(
+        self, protocol: str, args_structure: Any = None, *args, **kwargs
+    ) -> Optional[Dict[str, Any]]:
         """
         :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
+        :param args_structure: the structure of flattened tensors from pack_args_kwargs in petals.utils.packaging
         :param args: request-specific inputs, typically block uids and input tensors
         :param kwargs: additional request context, such as remote peer ID
         :returns: msgpack-serialized metadata dict that will be passed alongside a given request
         """
-        return dict(points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter)
+        return dict(
+            points=self.policy.get_points(protocol, *args, **kwargs),
+            active_adapter=self.config.active_adapter,
+            args_structure=args_structure,
+        )
 
     def shutdown(self):
         self._thread.shutdown()
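
The extra args_structure key travels inside the msgpack-serialized request metadata. A small round-trip sketch with placeholder tensors and metadata values; a real call goes through RemoteSequenceManager.get_request_metadata.

import torch
from hivemind import MSGPackSerializer
from petals.utils.packaging import pack_args_kwargs

flat_tensors, args_structure = pack_args_kwargs(torch.ones(2, 3), torch.empty(0))
metadata = dict(points=0.0, active_adapter="", args_structure=args_structure)  # mirrors get_request_metadata
payload = MSGPackSerializer.dumps(metadata)

restored = MSGPackSerializer.loads(payload)
assert restored["args_structure"] is not None  # handler.py reads this key via metadata.get("args_structure")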

+ 13 - 7
src/petals/client/sequential_autograd.py

@@ -16,6 +16,7 @@ from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_
 from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy
+from petals.utils.packaging import pack_args_kwargs
 
 logger = get_logger(__name__)
 
@@ -67,15 +68,17 @@ async def sequential_forward(
                 span = sequences.popleft()
 
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
-                inputs_and_prompts = [inputs, prompts[span.start : span.end]]
+                flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end])
 
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
-                metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
+                metadata = sequence_manager.get_request_metadata(
+                    "rpc_forward", args_structure, span_uids, *flat_tensors
+                )
                 (outputs,) = await run_remote_forward(
                     span_uids,
                     stub,
                     sequence_manager.rpc_info,
-                    *inputs_and_prompts,
+                    *flat_tensors,
                     config=sequence_manager.config,
                     metadata=MSGPackSerializer.dumps(metadata),
                 )
@@ -149,18 +152,21 @@ async def sequential_backward(
                     inputs = intermediate_inputs.pop()
                     span = forward_sequences.pop()
 
+                grad_outputs_cpu = [grad.cpu() for grad in grad_outputs]
+                flat_tensors, args_structure = pack_args_kwargs(
+                    inputs, *grad_outputs_cpu, prompts[span.start : span.end]
+                )
+
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
                 metadata = sequence_manager.get_request_metadata(
-                    "rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
+                    "rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id
                 )
                 grad_outputs, *span_grad_prompts = await run_remote_backward(
                     span_uids,
                     stub,
                     sequence_manager.rpc_info,
-                    inputs,
-                    grad_outputs,
-                    prompts[span.start : span.end],
+                    *flat_tensors,
                     config=sequence_manager.config,
                     metadata=MSGPackSerializer.dumps(metadata),
                 )
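
To make the new backward packing concrete, a tiny worked example with hypothetical shapes; in the real code the prompts slice comes from prompts[span.start : span.end].

import torch
from petals.utils.packaging import pack_args_kwargs

inputs = torch.randn(1, 4, 8, requires_grad=True)  # hidden states entering the span (hypothetical shape)
grad_outputs = [torch.randn(1, 4, 8)]              # one gradient tensor per span output
prompts_slice = torch.empty(0)                     # DUMMY prompts for this span

grad_outputs_cpu = [grad.cpu() for grad in grad_outputs]
flat_tensors, args_structure = pack_args_kwargs(inputs, *grad_outputs_cpu, prompts_slice)
assert len(flat_tensors) == 3  # inputs, one grad_output, and the prompts slice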

+ 24 - 4
src/petals/server/block_functions.py

@@ -3,12 +3,13 @@ This module implements server-side computations on served blocks: forward, backw
 """
 from __future__ import annotations
 
-from typing import AsyncIterator, Optional, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
 
 import torch
 from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.expert_uid import ExpertUID
 from hivemind.proto import runtime_pb2
+from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_flatten
 
 from petals.data_structures import InferenceMetadata
@@ -18,6 +19,7 @@ from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import TaskPrioritizerBase
 from petals.utils.convert_block import QuantType
 from petals.utils.misc import DUMMY, is_dummy
+from petals.utils.packaging import unpack_args_kwargs
 
 # We prioritize short inference requests and make them use a *merged* inference pool,
 # so they are processed without interruptions and extra overheads
@@ -25,6 +27,8 @@ from petals.utils.misc import DUMMY, is_dummy
 MAX_SHORT_INFERENCE_TOKENS = 128
 MAX_NF4_SHORT_INFERENCE_TOKENS = 1
 
+logger = get_logger(__name__)
+
 
 async def run_rpc_forward(
     *flat_tensors: torch.Tensor,
@@ -32,6 +36,7 @@ async def run_rpc_forward(
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
+    args_structure: Any = None,
 ) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
@@ -41,7 +46,11 @@ async def run_rpc_forward(
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
-    hidden_states, prompts = flat_tensors
+    if args_structure is not None:
+        # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
+        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+    hidden_states, prompts, *_ = flat_tensors
+
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
@@ -79,8 +88,13 @@ async def run_rpc_backward(
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
+    args_structure: Any = None,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, prompts = flat_tensors
+    if args_structure is not None:
+        # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
+        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+    inputs, grad_outputs, prompts, *_ = flat_tensors
+
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
@@ -139,6 +153,7 @@ async def iterate_rpc_inference(
     prioritizer: TaskPrioritizerBase,
     points: int,
     quant_type: QuantType,
+    args_structure: Any = None,
 ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
     assert len(cache_handles) == len(requested_backends)
 
@@ -146,7 +161,12 @@ async def iterate_rpc_inference(
     point_per_piece = points / max_length if max_length > 0 else 0.0
 
     async for request, step_metadata in input_iterator:
-        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+        flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
+        if args_structure is not None:
+            # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
+            flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+
+        hidden_states, prompts, hypo_ids, *_ = flat_tensors
         batch_size, length_increment, _ = hidden_states.shape
 
         # Cast inputs to backend dtype
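
On the server side the unpacking mirrors the client. A minimal sketch of the guard shared by run_rpc_forward, run_rpc_backward, and iterate_rpc_inference; the tensors below simulate what a new client would send.

import torch
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

# Simulate the client side: flat tensors plus the structure carried in request metadata
sent_tensors, args_structure = pack_args_kwargs(
    torch.randn(1, 4, 8), torch.empty(0), torch.empty(0, dtype=torch.int64)
)

flat_tensors = tuple(sent_tensors)
if args_structure is not None:
    # Older clients send no structure; their positional (inputs, prompts, ...) layout is kept as-is
    flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
hidden_states, prompts, hypo_ids, *_ = flat_tensors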

+ 10 - 0
src/petals/server/handler.py

@@ -151,6 +151,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 max_length = metadata.get("max_length")
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
+                args_structure = metadata.get("args_structure")
                 if not requested_uids:
                     raise ValueError("User must specify at least one block for inference, but got none")
                 assert isinstance(
@@ -180,6 +181,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         prioritizer=self._prioritizer,
                         points=points,
                         quant_type=self.quant_type,
+                        args_structure=args_structure,
                     ):
                         if can_push:
                             task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
@@ -356,6 +358,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
+            args_structure = metadata.get("args_structure")
             assert isinstance(
                 points, (float, int)
             ), f"rpc_forward should have number of points as number or None, got {points}"
@@ -366,6 +369,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
                 points=points,
+                args_structure=args_structure,
             )
             return runtime_pb2.ExpertResponse(
                 tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
@@ -383,6 +387,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
+            args_structure = metadata.get("args_structure")
             assert isinstance(
                 points, (float, int)
             ), f"rpc_forward_stream should have number of points as number or None, got {points}"
@@ -393,6 +398,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
                 points=points,
+                args_structure=args_structure,
             )
 
             # Split the serialized_output for streaming and respond to client
@@ -434,6 +440,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
+            args_structure = metadata.get("args_structure")
             assert isinstance(
                 points, (float, int)
             ), f"rpc_backward should have number of points as number or None, got {points}"
@@ -444,6 +451,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
                 points=points,
+                args_structure=args_structure,
             )
 
             return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
@@ -459,6 +467,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
+            args_structure = metadata.get("args_structure")
             assert isinstance(
                 points, (float, int)
             ), f"rpc_backward_stream should have number of points as number or None, got {points}"
@@ -469,6 +478,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
                 points=points,
+                args_structure=args_structure,
             )
             # Split the serialized_grad_inputs for streaming and respond
             for tensor in self._serialize_grads(grads, requested_backends, metadata):
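
Each handler forwards args_structure from the request metadata, defaulting to None so that clients predating this change keep working. A hedged sketch of that parsing step; the request below is a stand-in message with illustrative metadata.

from hivemind import MSGPackSerializer
from hivemind.proto import runtime_pb2

# A stand-in request, as a handler would receive it
request = runtime_pb2.ExpertRequest(
    uid="block.0", metadata=MSGPackSerializer.dumps(dict(points=0, args_structure=None))
)

metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
args_structure = metadata.get("args_structure")  # None for old clients; block_functions then keeps the legacy layout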

+ 2 - 0
src/petals/utils/misc.py

@@ -2,6 +2,8 @@ import torch
 
 DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
 
+DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
+
 
 def is_dummy(tensor: torch.Tensor):
     return tensor.numel() == 0
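
DUMMY_INT64 exists so that an omitted hypo_ids placeholder keeps the int64 dtype that _ServerInferenceSession.step() asserts for real hypothesis ids; a short check:

import torch
from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy

assert is_dummy(DUMMY) and is_dummy(DUMMY_INT64)  # both are empty placeholder tensors
assert DUMMY_INT64.dtype == torch.int64           # matches the dtype expected for hypo_ids
# DUMMY keeps the default floating-point dtype, which suits prompts and adapter placeholders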

+ 49 - 0
src/petals/utils/packaging.py

@@ -0,0 +1,49 @@
+from typing import Any, Dict, List, Tuple
+
+import torch
+from hivemind import nested_flatten, nested_pack
+
+# TODO: Move functions to hivemind
+
+
+def _mark_masked_tensor(index: int) -> bytes:
+    return b"__T" + str(index).encode()
+
+
+def _is_masked_tensor(item: Any) -> bool:
+    return isinstance(item, bytes) and item.startswith(b"__T")
+
+
+def _get_tensor_index(item: bytes) -> int:
+    return int(item[3:])
+
+
+def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
+    """
+    Check the function's arguments and pack all tensors into different flattened lists.
+    :returns: a flattened list of tensors and args and kwargs, where tensors were masked
+    """
+    masked_flat_values, flat_tensors, tensor_to_index = [], [], {}
+    for value in nested_flatten((args, kwargs)):
+        if isinstance(value, torch.Tensor):
+            tensor_index = tensor_to_index.setdefault(value, len(flat_tensors))
+            if tensor_index == len(flat_tensors):
+                flat_tensors.append(value)
+            masked_flat_values.append(_mark_masked_tensor(tensor_index))
+        else:
+            masked_flat_values.append(value)
+    return flat_tensors, nested_pack(masked_flat_values, (args, kwargs))
+
+
+def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any):
+    """
+    Restore arguments after `pack_args_kwargs` function.
+    :returns: list of args and dict of kwargs
+    """
+    return nested_pack(
+        (
+            value if not _is_masked_tensor(value) else flat_tensors[_get_tensor_index(value)]
+            for value in nested_flatten(args_structure)
+        ),
+        args_structure,
+    )
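
A tiny worked example of the masking scheme above; the values are chosen for illustration only.

import torch
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

x = torch.zeros(2)
flat_tensors, structure = pack_args_kwargs(x, x, scale=0.5)

# The same tensor object is stored once; the structure keeps b"__T0" markers where it appeared
assert len(flat_tensors) == 1

args, kwargs = unpack_args_kwargs(flat_tensors, structure)
assert torch.equal(args[0], x) and torch.equal(args[1], x)
assert kwargs == {"scale": 0.5}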

+ 29 - 0
tests/test_aux_functions.py

@@ -3,10 +3,13 @@ import sys
 
 import pytest
 import torch
+from hivemind import nested_compare, nested_flatten
 
 from petals import AutoDistributedConfig
 from petals.server.throughput import measure_compute_rps
 from petals.utils.convert_block import QuantType
+from petals.utils.misc import DUMMY, is_dummy
+from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs
 from test_utils import MODEL_NAME
 
 
@@ -44,3 +47,29 @@ def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: boo
         inference=inference,
     )
     assert isinstance(compute_rps, float) and compute_rps > 0
+
+
+@pytest.mark.forked
+def test_pack_inputs():
+    x = torch.ones(3)
+    y = torch.arange(5)
+    z = DUMMY
+
+    args = (x, z, None, (y, y), z)
+    kwargs = dict(foo=torch.zeros(1, 1), bar={"l": "i", "g": "h", "t": ("y", "e", "a", "r", torch.rand(1), x, y)})
+
+    flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)
+
+    assert len(flat_tensors) == 5
+    assert all(isinstance(t, torch.Tensor) for t in flat_tensors)
+
+    restored_args, restored_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+
+    assert len(restored_args) == len(args)
+    assert torch.all(restored_args[0] == x).item() and restored_args[2] is None
+    assert nested_compare((args, kwargs), (restored_args, restored_kwargs))
+    for original, restored in zip(nested_flatten((args, kwargs)), nested_flatten((restored_args, restored_kwargs))):
+        if isinstance(original, torch.Tensor):
+            assert torch.all(original == restored)
+        else:
+            assert original == restored