Procházet zdrojové kódy

[Refactor] extract block forward, backward and inference into a separate file (#435)

This PR does not change any functionality. It merely moves stuff around.
List of changes:

handler.py/_rpc_forward became block_methods/rpc_forward
handler.py/_rpc_backward became block_methods/rpc_backward
the math bits of rpc_inference were extracted into block_methods/iterate_rpc_inference

---------

Co-authored-by: Your Name <you@example.com>
Co-authored-by: artek0chumak <artek.chumak@gmail.com>
Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
justheuristic před 2 roky
rodič
revize
ac9b546706
2 změnil soubory, kde provedl 214 přidání a 173 odebrání
  1. 195 0
      src/petals/server/block_functions.py
  2. 19 173
      src/petals/server/handler.py

+ 195 - 0
src/petals/server/block_functions.py

@@ -0,0 +1,195 @@
+"""
+This module implements server-side computations on served blocks: forward, backward and inference; used by handler
+"""
+from __future__ import annotations
+
+from typing import AsyncIterator, Optional, Sequence, Tuple, Union
+
+import torch
+from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.moe.expert_uid import ExpertUID
+from hivemind.proto import runtime_pb2
+from hivemind.utils.nested import nested_flatten
+
+from petals.data_structures import InferenceMetadata
+from petals.server.backend import TransformerBackend
+from petals.server.memory_cache import Handle
+from petals.server.task_pool import PrioritizedTaskPool
+from petals.server.task_prioritizer import TaskPrioritizerBase
+from petals.utils.misc import DUMMY, is_dummy
+
+
+async def run_rpc_forward(
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = "",
+    prioritizer: TaskPrioritizerBase,
+    points: int = 0,
+) -> torch.Tensor:
+    """
+    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
+
+    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
+    """
+    hidden_states, prompts = flat_tensors
+    dtype = requested_backends[0].dtype
+    # check parse input tensors and cast dtypes
+    hidden_states = hidden_states.to(dtype)
+    assert hidden_states.ndim == 3
+    if prompts is None or is_dummy(prompts):
+        prompts = [DUMMY] * len(requested_backends)
+    else:
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+    # Run a chain of requested backends
+    for backend, prompt in zip(requested_backends, prompts):
+        if not is_dummy(prompt):
+            hidden_states[:, : prompt.shape[1]] += prompt
+
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
+        )
+        (hidden_states,) = await backend.forward_pool.submit_task(
+            hidden_states,
+            active_adapter,
+            priority=priority,
+        )
+        assert isinstance(hidden_states, torch.Tensor)
+        assert (
+            hidden_states.ndim == 3
+        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+
+    return hidden_states
+
+
+async def run_rpc_backward(
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = "",
+    prioritizer: TaskPrioritizerBase,
+    points: int = 0,
+) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
+    inputs, grad_outputs, prompts = flat_tensors
+    # Cast inputs & grad outputs to backend dtype
+    inputs = inputs.to(requested_backends[0].dtype)
+    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+
+    if prompts is None or is_dummy(prompts):
+        prompts = [DUMMY] * len(requested_backends)
+    else:
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+    # Run a forward chain to collect intermediate inputs
+    # Note that we do not forward for the last module since we do not need its output
+    inter_inputs = []
+    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
+        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+        if not is_dummy(prompt):
+            inputs[:, : prompt.shape[1]] += prompt
+        inter_inputs.append(inputs)
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
+        )
+        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
+
+        assert isinstance(inputs, torch.Tensor)
+
+    if not is_dummy(prompts[-1]):
+        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
+    inter_inputs.append(inputs)
+
+    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
+    grad_prompts_reversed = []
+    # Run a chain of requested backends
+    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
+        )
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
+
+        assert isinstance(grad_outputs, torch.Tensor)
+        if not is_dummy(prompt):
+            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
+
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
+    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
+
+
+async def iterate_rpc_inference(
+    requested_uids: Sequence[ExpertUID],
+    requested_backends: Sequence[TransformerBackend],
+    active_adapter: Optional[str],
+    input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
+    cache_handles: Sequence[Sequence[Handle]],
+    max_length: int,
+    prioritizer: TaskPrioritizerBase,
+    points: int,
+) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
+    assert len(cache_handles) == len(requested_backends)
+
+    prefix_length = 0
+    point_per_piece = points / max_length if max_length > 0 else 0.0
+
+    async for request, step_metadata in input_iterator:
+        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+
+        # Cast inputs to backend dtype
+        hidden_states = hidden_states.to(requested_backends[0].dtype)
+        assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
+
+        # parse deep prompts (optional argument)
+        has_prompts = prompts is not None and not is_dummy(prompts)
+        if not has_prompts:
+            prompts = [None] * len(requested_backends)
+        else:
+            prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+            prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
+
+        if not (len(requested_backends) == len(prompts)):
+            raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
+
+        length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
+        if prefix_length + length_increment > max_length:
+            raise ValueError(
+                f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
+                f" exceeds pre-allocated maximum {max_length}"
+            )
+
+        priority = prioritizer.prioritize(
+            hidden_states,
+            hypo_ids,
+            points=point_per_piece,
+            requested_uids=requested_uids,
+            type="inference",
+        )
+
+        inference_infos = tuple(
+            InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
+            for uid, handles in zip(requested_uids, cache_handles)
+        )
+
+        if hidden_states.numel() == 0:
+            pass  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
+            # when user wants to pre-allocate cache or check that server *can* allocate that cache
+        else:
+            assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
+            (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
+                hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
+            )
+
+        # serialize and send last layer outputs
+        output_tensors = [
+            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
+        ]
+        can_push = not has_prompts
+        yield output_tensors, can_push
+
+        # prepare for next step
+        prefix_length += length_increment

+ 19 - 173
src/petals/server/handler.py

@@ -6,7 +6,7 @@ import multiprocessing as mp
 import sys
 from enum import Enum
 from itertools import chain
-from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple
 
 import torch
 from async_timeout import timeout
@@ -29,12 +29,11 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
 import petals
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID
 from petals.server.backend import TransformerBackend
+from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
 from petals.server.memory_cache import Handle
-from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
-from petals.utils.misc import DUMMY, is_dummy
 
 logger = get_logger(__name__)
 
@@ -147,7 +146,6 @@ class TransformerConnectionHandler(ConnectionHandler):
                 metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
                 requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
                 max_length = metadata.get("max_length")
-                active_adapter = self._get_active_adapter(metadata)
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
                 if not requested_uids:
@@ -163,78 +161,28 @@ class TransformerConnectionHandler(ConnectionHandler):
                         f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}"
                     )
 
-                point_per_piece = points / max_length if max_length > 0 else 0.0
                 batch_size = request.tensors[0].size[0] if request.tensors else 1
-                prefix_length = 0
 
                 async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
-                    assert len(cache_handles) == len(requested_backends)
-                    first_request = request
                     background_tasks = set()
-                    async for request, metadata in self._iterate_inference_steps(
-                        first_request, requests, session_id, requested_uids, context
+                    async for output_tensors, can_push in iterate_rpc_inference(
+                        requested_uids=requested_uids,
+                        requested_backends=requested_backends,
+                        active_adapter=self._get_active_adapter(metadata),
+                        input_iterator=self._iterate_inference_steps(
+                            request, requests, session_id, requested_uids, context
+                        ),
+                        cache_handles=cache_handles,
+                        max_length=max_length,
+                        prioritizer=self._prioritizer,
+                        points=points,
                     ):
-                        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
-
-                        # Cast inputs to backend dtype
-                        hidden_states = hidden_states.to(requested_backends[0].dtype)
-                        assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
-
-                        # parse deep prompts (optional argument)
-                        has_prompts = prompts is not None and not is_dummy(prompts)
-                        if not has_prompts:
-                            prompts = [None] * len(requested_backends)
-                        else:
-                            prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
-                            prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
-
-                        if not (len(requested_backends) == len(prompts)):
-                            raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
-
-                        length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
-                        if prefix_length + length_increment > max_length:
-                            raise ValueError(
-                                f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
-                                f" exceeds pre-allocated maximum {max_length}"
-                            )
-
-                        priority = self._prioritizer.prioritize(
-                            hidden_states,
-                            hypo_ids,
-                            points=point_per_piece,
-                            requested_uids=requested_uids,
-                            type="inference",
-                        )
-
-                        inference_infos = tuple(
-                            InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
-                            for uid, handles in zip(requested_uids, cache_handles)
-                        )
-
-                        if hidden_states.numel() == 0:
-                            pass  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
-                            # when user wants to pre-allocate cache or check that server *can* allocate that cache
-                        else:
-                            assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
-                            (hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task(
-                                hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
-                            )
-
-                        # serialize and send last layer outputs
-                        output_tensors = [
-                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                            for result, proto in zip(
-                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
-                            )
-                        ]
-                        if not has_prompts:
+                        if can_push:
                             task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
                             background_tasks.add(task)  # Keep reference until it is done to save it from GC
                             task.add_done_callback(background_tasks.discard)
                         yield runtime_pb2.ExpertResponse(tensors=output_tensors)
 
-                        # prepare for next step
-                        prefix_length += length_increment
             finally:
                 self._log_request("rpc_inference.close", requested_uids, context)
 
@@ -408,7 +356,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_forward should have number of points as number or None, got {points}"
 
-            hidden_states = await _rpc_forward(
+            hidden_states = await run_rpc_forward(
                 *flat_inputs,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -435,7 +383,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_forward_stream should have number of points as number or None, got {points}"
 
-            hidden_states = await _rpc_forward(
+            hidden_states = await run_rpc_forward(
                 *flat_inputs,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -486,7 +434,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_backward should have number of points as number or None, got {points}"
 
-            grads = await _rpc_backward(
+            grads = await run_rpc_backward(
                 *flat_tensors,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -511,7 +459,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_backward_stream should have number of points as number or None, got {points}"
 
-            grads = await _rpc_backward(
+            grads = await run_rpc_backward(
                 *flat_tensors,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -621,105 +569,3 @@ class TransformerConnectionHandler(ConnectionHandler):
             result.update(block_info)
 
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
-
-
-async def _rpc_forward(
-    *flat_tensors: torch.Tensor,
-    requested_backends: Sequence[TransformerBackend],
-    active_adapter: str = "",
-    prioritizer: TaskPrioritizerBase,
-    points: int = 0,
-) -> torch.Tensor:
-    """
-    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
-
-    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
-    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
-    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
-    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
-    """
-    hidden_states, prompts = flat_tensors
-    dtype = requested_backends[0].dtype
-    # check parse input tensors and cast dtypes
-    hidden_states = hidden_states.to(dtype)
-    assert hidden_states.ndim == 3
-    if prompts is None or is_dummy(prompts):
-        prompts = [DUMMY] * len(requested_backends)
-    else:
-        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
-
-    # Run a chain of requested backends
-    for backend, prompt in zip(requested_backends, prompts):
-        if not is_dummy(prompt):
-            hidden_states[:, : prompt.shape[1]] += prompt
-
-        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
-        priority = prioritizer.prioritize(
-            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
-        )
-        (hidden_states,) = await backend.forward_pool.submit_task(
-            hidden_states,
-            active_adapter,
-            priority=priority,
-        )
-        assert isinstance(hidden_states, torch.Tensor)
-        assert (
-            hidden_states.ndim == 3
-        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-
-    return hidden_states
-
-
-async def _rpc_backward(
-    *flat_tensors: torch.Tensor,
-    requested_backends: Sequence[TransformerBackend],
-    active_adapter: str = "",
-    prioritizer: TaskPrioritizerBase,
-    points: int = 0,
-) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, prompts = flat_tensors
-    # Cast inputs & grad outputs to backend dtype
-    inputs = inputs.to(requested_backends[0].dtype)
-    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
-
-    if prompts is None or is_dummy(prompts):
-        prompts = [DUMMY] * len(requested_backends)
-    else:
-        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
-
-    # Run a forward chain to collect intermediate inputs
-    # Note that we do not forward for the last module since we do not need its output
-    inter_inputs = []
-    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
-        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-        if not is_dummy(prompt):
-            inputs[:, : prompt.shape[1]] += prompt
-        inter_inputs.append(inputs)
-        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
-        priority = prioritizer.prioritize(
-            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
-        )
-        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
-
-        assert isinstance(inputs, torch.Tensor)
-
-    if not is_dummy(prompts[-1]):
-        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
-    inter_inputs.append(inputs)
-
-    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
-    grad_prompts_reversed = []
-    # Run a chain of requested backends
-    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
-        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
-        priority = prioritizer.prioritize(
-            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
-        )
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
-
-        assert isinstance(grad_outputs, torch.Tensor)
-        if not is_dummy(prompt):
-            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
-
-    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
-    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape