
Implement direct server-to-server communication (#331)

Implement #226.
Alexander Borzunov · 2 years ago
parent commit 158013a671
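In short, during autoregressive inference each server can now push its output activations directly to the next server in the chain instead of routing every hop through the client. The client still keeps a stream open to every server and relays inputs as before, but it attaches a shared `step_id` and a `next_servers` route to each request so servers know where to forward their results; a server simply skips whichever copy of a step arrives second. A rough sketch of the per-step metadata built in `_ServerInferenceSession.step()` below (field names come from this diff, the concrete values are invented for illustration):

```python
# Illustrative only: shape of the request metadata on a later inference step
# when config.use_server_to_server is enabled.
request_metadata = {
    "session_id": "6f1c0c6e-...",   # uuid4 of the session on the server receiving this request
    "step_id": "9b2d51aa-...",      # uuid4 shared by all servers for this generation step
    "next_servers": [
        # (peer_id in base58, session_id on that server, span.start, span.end),
        # as collected by _collect_next_servers()
        ("12D3KooW...", "d41f7b3c-...", 8, 16),
        ("12D3KooW...", "0a9e2c77-...", 16, 24),
    ],
}
```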

+ 1 - 1
src/petals/__init__.py

@@ -9,7 +9,7 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "1.2.0.dev0"
+__version__ = "1.2.0.dev1"
 
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

+ 1 - 2
src/petals/cli/run_server.py

@@ -27,8 +27,7 @@ def main():
 
     parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
     parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
-    parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default,"
-                                                                 "use the same name as in the converted model.")
+    parser.add_argument('--dht_prefix', type=str, default=None, help="Announce all blocks with this DHT prefix")
 
     parser.add_argument('--port', type=int, required=False,
                         help='Port this server listens to. '
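For context, the renamed flag feeds the DHT announcement format: each served block is announced under `<dht_prefix>.<block_index>` (the old code hardcoded the `.` separator, now spelled `UID_DELIMITER` in server.py further down). A tiny sketch; the chained form for a span of blocks is assumed here to use a space as `CHAIN_DELIMITER`:

```python
# Illustrative only: how --dht_prefix maps to announced block UIDs.
dht_prefix = "bigscience/bloom-petals"      # example value; defaults to the converted model's prefix
module_uids = [f"{dht_prefix}.{i}" for i in range(8, 12)]
print(module_uids[0])          # bigscience/bloom-petals.8
print(" ".join(module_uids))   # assumed chained-UID form for a span of blocks
```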

+ 132 - 96
src/petals/client/inference_session.py

@@ -3,7 +3,8 @@ from __future__ import annotations
 import asyncio
 import itertools
 import time
-from typing import AsyncIterator, List, Optional
+import uuid
+from typing import AsyncIterator, List, Optional, Tuple
 
 import torch
 from hivemind import (
@@ -15,10 +16,10 @@ from hivemind import (
     serialize_torch_tensor,
 )
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import StubBase
+from hivemind.p2p import P2P
 from hivemind.proto import runtime_pb2
 
-from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
+from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig, maybe_log_traceback
 from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy
@@ -35,35 +36,48 @@ class _ServerInferenceSession:
 
     def __init__(
         self,
+        config: SequenceManagerConfig,
+        span: RemoteSpanInfo,
         uid: ModuleUID,
         rpc_info: RPCInfo,
         inputs_queue: asyncio.Queue,
         outputs_aiter: AsyncIterator,
         *,
-        timeout: float,
         max_length: int,
         **metadata,
     ):
-        self.uid, self.rpc_info = uid, rpc_info
+        self.config = config
+        self.span, self.uid, self.rpc_info = span, uid, rpc_info
         self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
-        self.timeout = timeout
-        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, **metadata))
+        self.session_id = str(uuid.uuid4())
+        self.session_metadata = dict(max_length=max_length, **metadata)
         self.stepped = False
         self.closed = False
 
+        self._position = 0
+        self.history = None  # Used in case of server failures to regenerate attention caches on new servers
+        self.next_session = None
+
     @classmethod
     async def create(
-        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata
+        cls,
+        config: SequenceManagerConfig,
+        p2p: P2P,
+        span: RemoteSpanInfo,
+        uid: ModuleUID,
+        rpc_info: RPCInfo,
+        **metadata,
     ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
+        stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
         inputs_queue = asyncio.Queue()
         outputs_stream = await asyncio.wait_for(
             stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
-            timeout,
+            config.request_timeout,
         )
-        return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
+        return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)
 
     @staticmethod
     async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@@ -75,9 +89,11 @@ class _ServerInferenceSession:
 
     def step(
         self,
-        new_hidden_states: torch.Tensor,
+        inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
+        *,
+        step_id: str,
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -86,44 +102,84 @@ class _ServerInferenceSession:
         """
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
+
+        n_input_tokens = inputs.shape[1]
+        if self.history is None:
+            self.history = inputs
+        elif self.history.shape[1] == self._position:
+            self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1)
+        assert self.history.shape[1] == self._position + n_input_tokens, (
+            f"Broken input cache: span={self.span} shape={self.history.shape} "
+            f"position={self._position} n_input_tokens={n_input_tokens}"
+        )
+
+        if not self.stepped:
+            inputs = self.history  # Pass full inputs including prefix
+        else:
+            inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
+
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
-            assert prompts.ndim == 4, "deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]"
+            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
             assert prompts.shape[0] == self.num_blocks
-            assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
-            assert prompts.shape[2] <= new_hidden_states.shape[1]
-            assert prompts.shape[3] == new_hidden_states.shape[2]
+            assert prompts.shape[1] in (inputs.shape[0], 1)
+            assert prompts.shape[2] <= inputs.shape[1]
+            assert prompts.shape[3] == inputs.shape[2]
 
         if hypo_ids is None or is_dummy(hypo_ids):
             hypo_ids = DUMMY
         else:
-            assert len(hypo_ids) == len(new_hidden_states)
+            assert len(hypo_ids) == len(inputs)
             assert hypo_ids.dtype == torch.int64
 
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states, prompts, hypo_ids)
+        input_tensors = (inputs, prompts, hypo_ids)
+
+        request_metadata = dict(session_id=self.session_id, step_id=step_id)
+        if not self.stepped:
+            request_metadata.update(self.session_metadata)
+        elif self.config.use_server_to_server:
+            next_servers = self._collect_next_servers()
+            if next_servers:
+                request_metadata["next_servers"] = next_servers
+
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
+                        for tensor, proto in zip(input_tensors, nested_flatten(self.rpc_info["inference_schema"]))
                     ],
-                    metadata=self._serialized_metadata if not self.stepped else None,
+                    metadata=MSGPackSerializer.dumps(request_metadata),
                 )
             )
         )
         outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
-        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
+        assert (
+            outputs[0].shape == inputs.shape
+        ), f"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}"
+
+        self._position += n_input_tokens
+
         return outputs[0]
 
+    def _collect_next_servers(self) -> List[Tuple[str, str, int, int]]:
+        next_servers = []
+        session = self.next_session
+        while session is not None and session.stepped:
+            next_servers.append(
+                (session.span.peer_id.to_base58(), session.session_id, session.span.start, session.span.end)
+            )
+            session = session.next_session
+        return next_servers
+
     async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
         await self._inputs_queue.put(inputs_serialized)
         self.stepped = True
-        return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
+        return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout)
 
     def close(self):
         """Finish a given inference session, close the underlying connection"""
@@ -163,13 +219,15 @@ class InferenceSession:
     def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
         self._sequence_manager = sequence_manager
         self._closed = False
-        self._chosen_spans = []
         self._server_sessions = []
-        self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
         self._position = 0
         self._max_length = max_length
         self.last_token_id = None
 
+    @property
+    def num_blocks(self) -> int:
+        return len(self._sequence_manager)
+
     @property
     def position(self) -> int:
         return self._position
@@ -178,15 +236,15 @@ class InferenceSession:
         server_sessions = []
         try:
             for span in chosen_spans:
-                stub = TransformerConnectionHandler.get_stub(self._sequence_manager.state.p2p, span.peer_id)
                 span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
                 metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
                 session = RemoteExpertWorker.run_coroutine(
                     _ServerInferenceSession.create(
-                        stub,
+                        self._sequence_manager.config,
+                        self._sequence_manager.state.p2p,
+                        span,
                         span_uids,
                         rpc_info=self._sequence_manager.rpc_info,
-                        timeout=self._sequence_manager.config.request_timeout,
                         max_length=self._max_length,
                         **metadata,
                     )
@@ -206,7 +264,7 @@ class InferenceSession:
                 logger.debug("Caught exception while closing connection to server:", exc_info=True)
 
     def __enter__(self) -> "InferenceSession":
-        assert not self._closed and not self._chosen_spans
+        assert not self._closed and not self._server_sessions
         return self
 
     def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
@@ -214,16 +272,17 @@ class InferenceSession:
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
 
-        n_blocks = len(self._sequence_manager)
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
-            assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
+            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
+            assert prompts.shape[0] == self.num_blocks
 
         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()
         prompts = prompts.cpu()
+        step_id = str(uuid.uuid4())
 
         n_input_tokens = inputs.shape[1]
         if self._position + n_input_tokens > self._max_length:
@@ -233,97 +292,74 @@ class InferenceSession:
 
         server_idx = 0
         block_idx = 0
-        recovery_until = -1  # Recovery mode is disabled until a failure happens
-        while block_idx < n_blocks:
+        while block_idx < self.num_blocks:
             for attempt_no in itertools.count():
                 logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
-                span = None
+                server_session = None
                 try:
-                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
-                        # If there is a failed server session, this code closes it
-                        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
-
-                        n_prev_spans = len(self._chosen_spans)
-                        update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks
-                        if attempt_no >= 1 and update_end > recovery_until:
-                            logger.info(
-                                f"Due to a server failure, remote attention caches "
-                                f"from block {block_idx} to {update_end} will be regenerated"
-                            )
-                        recovery_until = max(recovery_until, update_end)
-
-                        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency")
-                        # make_sequence() could return a longer sequence
-                        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
-                        updated_sessions = self._enter_server_sessions(updated_spans)
-                        logger.debug(
-                            f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers"
-                        )
-
-                        # If there is a failed span, this code replaces it, otherwise it just adds new ones
-                        self._chosen_spans[server_idx : server_idx + 1] = updated_spans
-                        self._server_sessions[server_idx : server_idx + 1] = updated_sessions
-                        recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
-                        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (
-                            len(updated_spans) - 1
-                        )
-                        assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), (
-                            f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, "
-                            f"{len(self._server_inputs)} inputs"
-                        )
-
-                    session = self._server_sessions[server_idx]
-                    span = self._chosen_spans[server_idx]
-
-                    if self._server_inputs[server_idx] is None:
-                        self._server_inputs[server_idx] = inputs
-                    elif self._server_inputs[server_idx].shape[1] == self._position:
-                        self._server_inputs[server_idx] = torch.cat(
-                            [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
-                        )
-                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, (
-                        f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} "
-                        f"position={self._position} n_input_tokens={n_input_tokens}"
-                    )
-
-                    if not session.stepped:
-                        inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
-                    else:
-                        inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
+                    if not self._server_sessions or attempt_no >= 1:
+                        self._update_sequence(server_idx, block_idx, attempt_no)
 
-                    outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
-                    assert (
-                        inputs.shape == outputs.shape
-                    ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
+                    server_session = self._server_sessions[server_idx]
+                    inputs = server_session.step(
+                        inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
+                    )
 
-                    inputs = outputs
                     server_idx += 1
-                    block_idx = span.end
-                    self._sequence_manager.on_request_success(span.peer_id)
+                    block_idx = server_session.span.end
+                    self._sequence_manager.on_request_success(server_session.span.peer_id)
                     break
                 except Exception as e:
-                    self._sequence_manager.on_request_failure(span.peer_id if span is not None else None)
+                    self._sequence_manager.on_request_failure(
+                        server_session.span.peer_id if server_session is not None else None
+                    )
                     if attempt_no + 1 == self._sequence_manager.config.max_retries:
                         raise
                     delay = self._sequence_manager.get_retry_delay(attempt_no)
                     logger.warning(
-                        f"Caught exception when running inference via {span} (retry in {delay:.0f} sec): {repr(e)}"
+                        f"Caught exception when running inference via {server_session.span if server_session is not None else None} "
+                        f"(retry in {delay:.0f} sec): {repr(e)}"
                     )
                     maybe_log_traceback(e)
                     time.sleep(delay)
 
         self._position += n_input_tokens
-        inputs = inputs[:, -n_input_tokens:]
-        outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+        outputs = inputs[:, -n_input_tokens:]
+        outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
         return outputs
 
+    def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:
+        # If there is a failed server session, this code closes it
+        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
+
+        n_prev_spans = len(self._server_sessions)
+        update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
+        if attempt_no >= 1:
+            logger.info(
+                f"Due to a server failure, remote attention caches "
+                f"from block {block_idx} to {update_end} will be regenerated"
+            )
+
+        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency")
+        # make_sequence() could return a longer sequence
+        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
+        updated_sessions = self._enter_server_sessions(updated_spans)
+        logger.debug(f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers")
+
+        # If there is a failed span, this code replaces it, otherwise it just adds new ones
+        if server_idx < n_prev_spans:
+            updated_sessions[0].history = self._server_sessions[server_idx].history
+        self._server_sessions[server_idx : server_idx + 1] = updated_sessions
+
+        # Update links to the next server session for direct server-to-server communication via rpc_push()
+        for i in range(max(server_idx - 1, 0), min(server_idx + len(updated_spans), len(self._server_sessions) - 1)):
+            self._server_sessions[i].next_session = self._server_sessions[i + 1]
+
     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""
         if not self._closed:
-            self._server_inputs.clear()
             self._exit_server_sessions(self._server_sessions)
             self._server_sessions.clear()
-            self._chosen_spans.clear()
             self._closed = True
 
     def __exit__(self, *exc_details):
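Client-side, the attention-cache bookkeeping moved from `InferenceSession` into each `_ServerInferenceSession`: `history` accumulates everything sent through a span so the full prefix can be replayed if that span is re-assigned to a new server, while ordinary steps send only the newest tokens. A standalone toy of that prefix logic (not the real class):

```python
import torch

# Toy illustration of the history bookkeeping in _ServerInferenceSession.step().
history, position = None, 0

def fake_step(inputs: torch.Tensor, stepped: bool) -> torch.Tensor:
    """Send the full history on the first call to a server, only new tokens afterwards."""
    global history, position
    n_input_tokens = inputs.shape[1]
    if history is None:
        history = inputs
    elif history.shape[1] == position:
        history = torch.cat([history, inputs[:, -n_input_tokens:]], dim=1)
    to_send = history if not stepped else inputs[:, -n_input_tokens:]
    position += n_input_tokens
    return to_send

prompt = torch.randn(1, 5, 8)                   # 5 prefix tokens, hidden size 8
print(fake_step(prompt, stepped=False).shape)   # torch.Size([1, 5, 8]) -- full prefix
token = torch.randn(1, 1, 8)
print(fake_step(token, stepped=True).shape)     # torch.Size([1, 1, 8]) -- only the new token
# After a server failure, history (all 6 tokens) is handed to the replacement
# session so it can regenerate its attention cache from scratch.
```

The `next_session` links set in `_update_sequence()` are what `_collect_next_servers()` walks to build the push route shown in the metadata sketch above.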

+ 1 - 0
src/petals/client/routing/sequence_manager.py

@@ -34,6 +34,7 @@ class SequenceManagerConfig:
     daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
 
     allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
+    use_server_to_server: bool = True  # Use direct server-to-server communication
 
     request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
     update_period: float = 60  # refresh DHT information once in this many seconds
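Direct pushes are on by default; disabling the flag restores the old behavior, where the client alone relays activations between servers. A minimal sketch, assuming the remaining `SequenceManagerConfig` fields all have defaults:

```python
from petals.client.routing.sequence_manager import SequenceManagerConfig

# Fall back to pure client-side relaying between servers.
config = SequenceManagerConfig(use_server_to_server=False)
print(config.use_server_to_server)  # False
```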

+ 177 - 23
src/petals/server/handler.py

@@ -2,6 +2,9 @@ from __future__ import annotations
 
 import asyncio
 import contextlib
+import multiprocessing.managers
+import sys
+from concurrent.futures import ThreadPoolExecutor
 from itertools import chain
 from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
@@ -11,6 +14,7 @@ from hivemind import (
     DHT,
     MSGPackSerializer,
     P2PContext,
+    PeerID,
     deserialize_tensor_stream,
     deserialize_torch_tensor,
     nested_flatten,
@@ -25,7 +29,7 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
 import petals
-from petals.data_structures import CHAIN_DELIMITER, InferenceMetadata, ModuleUID
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID
 from petals.server.backend import TransformerBackend
 from petals.server.memory_cache import Handle
 from petals.server.task_pool import PrioritizedTaskPool
@@ -34,6 +38,23 @@ from petals.utils.misc import DUMMY, is_dummy
 
 logger = get_logger(__name__)
 
+
+# Fix pickling protobufs, see https://stackoverflow.com/a/74873028
+sys.modules["runtime_pb2"] = runtime_pb2
+
+# Fix queues in multiprocessing.Manager in Python < 3.9.7, see https://bugs.python.org/issue30256
+
+_OriginalAutoProxy = multiprocessing.managers.AutoProxy
+
+
+def patched_autoproxy(*args, manager_owned=True, **kwargs):
+    # Calling original AutoProxy without the unwanted key argument
+    return _OriginalAutoProxy(*args, **kwargs)
+
+
+multiprocessing.managers.AutoProxy = patched_autoproxy
+
+
 CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
 
 
@@ -47,6 +68,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
         *,
+        dht_prefix: str,
+        push_manager: multiprocessing.managers.SyncManager,
+        session_queues: Dict[str, multiprocessing.managers.BaseProxy],  # BaseProxy for queue.Queue
         inference_max_length: int,
         request_timeout: float,
         session_timeout: float,
@@ -56,6 +80,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         super().__init__(dht, module_backends)
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
+        self.dht_prefix = dht_prefix
+        self._push_manager = push_manager
+        self._session_queues = session_queues
+        self._executor = ThreadPoolExecutor(max_workers=float("inf"))  # For waiting on self.session_queues
+
         self.inference_max_length = inference_max_length
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
@@ -96,7 +125,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         self,
         requests: AsyncIterator[runtime_pb2.ExpertRequest],
         context: P2PContext,
-    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
 
         async with timeout(self.session_timeout):
@@ -113,6 +142,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
                 max_length = metadata.get("max_length")
                 points = metadata.get("points", 0)
+                session_id = metadata.get("session_id")
 
                 if not requested_uids:
                     raise ValueError("User must specify at least one block for inference, but got none")
@@ -133,7 +163,11 @@ class TransformerConnectionHandler(ConnectionHandler):
 
                 async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
                     assert len(cache_handles) == len(requested_backends)
-                    while request.tensors:  # iterate while user is willing to supply tensors
+                    first_request = request
+                    background_tasks = set()
+                    async for request, metadata in self._iterate_inference_steps(
+                        first_request, requests, session_id, requested_uids, context
+                    ):
                         hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
 
                         # Cast inputs to backend dtype
@@ -141,7 +175,8 @@ class TransformerConnectionHandler(ConnectionHandler):
                         assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
 
                         # parse deep prompts (optional argument)
-                        if prompts is None or is_dummy(prompts):
+                        has_prompts = prompts is not None and not is_dummy(prompts)
+                        if not has_prompts:
                             prompts = [None] * len(requested_backends)
                         else:
                             prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
@@ -180,25 +215,136 @@ class TransformerConnectionHandler(ConnectionHandler):
                             )
 
                         # serialize and send last layer outputs
-                        yield runtime_pb2.ExpertResponse(
-                            tensors=[
-                                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                                for result, proto in zip(
-                                    (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
-                                )
-                            ]
-                        )
+                        output_tensors = [
+                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                            for result, proto in zip(
+                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
+                            )
+                        ]
+                        if not has_prompts:
+                            task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
+                            background_tasks.add(task)  # Keep reference until it is done to save it from GC
+                            task.add_done_callback(background_tasks.discard)
+                        yield runtime_pb2.ExpertResponse(tensors=output_tensors)
 
                         # prepare for next step
-                        prefix_length += hidden_states.shape[1]
-                        try:
-                            request = await asyncio.wait_for(anext(requests), self.step_timeout)
-                        except asyncio.TimeoutError:
-                            self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
-                            return
+                        prefix_length += length_increment
             finally:
                 self._log_request("rpc_inference.close", requested_uids, context)
 
+    async def _iterate_inference_steps(
+        self,
+        first_request: runtime_pb2.ExpertRequest,
+        requests: AsyncIterator[runtime_pb2.ExpertRequest],
+        session_id: Optional[str],
+        requested_uids: Sequence[str],
+        context: P2PContext,
+    ) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]:
+        loop = asyncio.get_event_loop()
+        if session_id is not None:
+            push_queue = self._push_manager.Queue()
+            self._session_queues[session_id] = push_queue
+
+        processed_step_ids = set()
+        n_pushes = n_late_pushes = 0
+        request = first_request
+        anext_task = get_push_task = None
+        try:
+            while request.tensors:  # iterate while user is willing to supply tensors
+                metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+                step_id = metadata.get("step_id")
+
+                pushed = metadata.get("pushed")
+                if pushed:
+                    n_pushes += 1
+
+                if step_id is None or step_id not in processed_step_ids:
+                    yield request, metadata
+                    if step_id is not None:
+                        processed_step_ids.add(step_id)
+                elif pushed:
+                    n_late_pushes += 1
+                    self._log_request(
+                        "rpc_inference.push",
+                        requested_uids,
+                        context,
+                        warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time",
+                    )
+
+                # Wait for the next request, coming either from the `requests` iterator or `push_queue`
+                if anext_task is None:
+                    anext_task = asyncio.create_task(anext(requests))
+                if get_push_task is None:
+                    if session_id is not None:
+                        get_push_task = loop.run_in_executor(self._executor, push_queue.get)
+                    else:
+                        get_push_task = asyncio.create_task(asyncio.Event().wait())  # Dummy never-ending task
+                done, _ = await asyncio.wait(
+                    [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED
+                )
+
+                if anext_task in done:
+                    request = await anext_task
+                    anext_task = None
+                elif get_push_task in done:
+                    request = await get_push_task
+                    get_push_task = None
+                else:
+                    self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
+                    anext_task.cancel()
+                    get_push_task.cancel()
+                    return
+        except:
+            logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True)
+            raise
+        finally:
+            if session_id is not None:
+                push_queue.put(None)  # Stop thread for get_push_task
+                del self._session_queues[session_id]
+
+    async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+        """Directly push activation tensors from one server to another"""
+
+        requested_uids = self._check_uids(request.uid)
+        self._log_request("rpc_push", requested_uids, context)
+
+        metadata = MSGPackSerializer.loads(request.metadata)
+        session_id = metadata["session_id"]
+        self._session_queues[session_id].put(request)
+        return runtime_pb2.ExpertResponse()
+
+    async def _push_outputs(
+        self, request: runtime_pb2.ExpertRequest, serialized_outputs: runtime_pb2.Tensor, metadata: dict
+    ) -> None:
+        try:
+            next_servers = metadata.get("next_servers")
+            if not next_servers:
+                return
+
+            next_peer_id, next_session_id, next_start, next_end = next_servers[0]
+            next_peer_id = PeerID.from_base58(next_peer_id)
+            next_uid = CHAIN_DELIMITER.join(f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(next_start, next_end))
+
+            # Sending hidden states serialized with output_schema to avoid double serialization
+            next_tensors = [serialized_outputs] + request.tensors[1:]
+            next_metadata = metadata.copy()
+            next_metadata.update(session_id=next_session_id, next_servers=next_servers[1:], pushed=True)
+
+            stub = self.get_stub(self._p2p, next_peer_id)
+            await stub.rpc_push(
+                runtime_pb2.ExpertRequest(
+                    uid=next_uid,
+                    tensors=next_tensors,
+                    metadata=MSGPackSerializer.dumps(next_metadata),
+                ),
+                timeout=self.request_timeout,
+            )
+        except Exception:
+            logger.debug(
+                f"Failed to push outputs to peer_id={next_peer_id}, session_id={next_session_id}, blocks={next_start}:{next_end}:",
+                exc_info=True,
+            )
+
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         async with timeout(self.request_timeout):
             # Parse request and prepare backends
@@ -348,7 +494,7 @@ class TransformerConnectionHandler(ConnectionHandler):
     @contextlib.asynccontextmanager
     async def _allocate_cache(
         self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
-    ) -> Sequence[Sequence[Handle, ...]]:
+    ) -> Sequence[Sequence[Handle]]:
         """
         Allocate memory cache for all transformer blocks, return cache handle
         :returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend
@@ -358,7 +504,13 @@ class TransformerConnectionHandler(ConnectionHandler):
             yield nested_pack(handles, descriptors)
 
     def _log_request(
-        self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None
+        self,
+        method: str,
+        uids: Optional[Sequence[ModuleUID]],
+        context: P2PContext,
+        *,
+        debug: Optional[str] = None,
+        warning: Optional[str] = None,
     ) -> None:
         if uids is not None:
             friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
@@ -370,10 +522,12 @@ class TransformerConnectionHandler(ConnectionHandler):
         friendly_remote_id = "..." + str(context.remote_id)[-6:]
 
         message = f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})"
-        if warning is None:
-            logger.info(message)
-        else:
+        if warning is not None:
             logger.warning(f"{message}: {warning}")
+        elif debug is not None:
+            logger.debug(f"{message}: {debug}")
+        else:
+            logger.info(message)
 
     async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
         """Return metadata about stored block uids and current load"""

+ 22 - 10
src/petals/server/server.py

@@ -45,7 +45,7 @@ class Server:
         self,
         *,
         initial_peers: List[str],
-        prefix: Optional[str],
+        dht_prefix: Optional[str],
         converted_model_name_or_path: str,
         throughput: Union[float, str],
         num_blocks: Optional[int] = None,
@@ -105,13 +105,13 @@ class Server:
             revision=revision,
         )
 
-        if prefix is None:
-            prefix = self.block_config.dht_prefix
-        assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
+        if dht_prefix is None:
+            dht_prefix = self.block_config.dht_prefix
+        assert UID_DELIMITER not in dht_prefix and CHAIN_DELIMITER not in dht_prefix, (
             f"DHT prefix should not contain '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'. "
-            f"Please specify another --prefix manually when starting a server"
+            f"Please specify another --dht_prefix manually when starting a server"
         )
-        self.prefix = prefix
+        self.dht_prefix = dht_prefix
 
         if expiration is None:
             expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
@@ -121,7 +121,8 @@ class Server:
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
 
         self.module_uids = [
-            f"{self.prefix}.{block_index}" for block_index in range(self.block_config.num_hidden_layers)
+            f"{self.dht_prefix}{UID_DELIMITER}{block_index}"
+            for block_index in range(self.block_config.num_hidden_layers)
         ]
 
         if dht_client_mode is None:
@@ -258,7 +259,7 @@ class Server:
             block_indices = self._choose_blocks()
             self.module_container = ModuleContainer.create(
                 dht=self.dht,
-                prefix=self.prefix,
+                dht_prefix=self.dht_prefix,
                 converted_model_name_or_path=self.converted_model_name_or_path,
                 block_config=self.block_config,
                 attn_cache_bytes=self.attn_cache_bytes,
@@ -359,7 +360,7 @@ class ModuleContainer(threading.Thread):
         cls,
         *,
         dht: DHT,
-        prefix: str,
+        dht_prefix: str,
         converted_model_name_or_path: str,
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
@@ -382,7 +383,7 @@ class ModuleContainer(threading.Thread):
         should_validate_reachability: bool,
         **kwargs,
     ) -> ModuleContainer:
-        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
+        module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
         joining_announcer = ModuleAnnouncerThread(
             module_uids,
             dht,
@@ -459,6 +460,7 @@ class ModuleContainer(threading.Thread):
 
         return cls(
             dht,
+            dht_prefix,
             blocks,
             throughput=throughput,
             update_period=update_period,
@@ -469,6 +471,7 @@ class ModuleContainer(threading.Thread):
     def __init__(
         self,
         dht: DHT,
+        dht_prefix: str,
         module_backends: Dict[str, TransformerBackend],
         *,
         inference_max_length: int,
@@ -486,10 +489,17 @@ class ModuleContainer(threading.Thread):
 
         self.dht, self.module_backends = dht, module_backends
         self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
+
+        self.push_manager = mp.Manager()
+        self.push_manager.__enter__()
+        session_queues = self.push_manager.dict()
         self.conn_handlers = [
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
+                dht_prefix=dht_prefix,
+                push_manager=self.push_manager,
+                session_queues=session_queues,
                 inference_max_length=inference_max_length,
                 request_timeout=request_timeout,
                 session_timeout=session_timeout,
@@ -497,6 +507,7 @@ class ModuleContainer(threading.Thread):
             )
             for _ in range(num_handlers)
         ]
+
         self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
         # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
         self.online_announcer = ModuleAnnouncerThread(
@@ -577,6 +588,7 @@ class ModuleContainer(threading.Thread):
         logger.debug("Shutting down connection handlers")
         for handler in self.conn_handlers:
             handler.shutdown()
+        self.push_manager.__exit__(None, None, None)
 
         logger.debug(f"Shutting down pools")
         for pool in self.runtime.pools: