
Implement direct server-to-server communication (#331)

Implement #226.
Alexander Borzunov 2 years ago
commit 158013a671
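In short: the client now tags every inference session and step with a unique ID and tells each server which peers hold the later blocks, so a server can push its output activations straight to the next server instead of routing them back through the client. A minimal sketch of the per-step metadata involved (peer IDs, session IDs, and block spans below are made up; the field names mirror the request_metadata built in _ServerInferenceSession.step()):

import uuid

session_id = str(uuid.uuid4())  # identifies this client's stream on the first server
step_id = str(uuid.uuid4())     # lets servers deduplicate a step that arrives twice

request_metadata = {
    "session_id": session_id,
    "step_id": step_id,
    # Remaining servers in the chain, as produced by _collect_next_servers():
    # (peer_id in base58, that server's session_id, span start, span end)
    "next_servers": [
        ("QmPeerB", str(uuid.uuid4()), 8, 16),   # hypothetical peer and span
        ("QmPeerC", str(uuid.uuid4()), 16, 24),  # hypothetical peer and span
    ],
}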

+ 1 - 1
src/petals/__init__.py

@@ -9,7 +9,7 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "1.2.0.dev0"
+__version__ = "1.2.0.dev1"
 
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

+ 1 - 2
src/petals/cli/run_server.py

@@ -27,8 +27,7 @@ def main():
 
     parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
     parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
-    parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default,"
-                                                                 "use the same name as in the converted model.")
+    parser.add_argument('--dht_prefix', type=str, default=None, help="Announce all blocks with this DHT prefix")
 
     parser.add_argument('--port', type=int, required=False,
                         help='Port this server listens to. '

+ 132 - 96
src/petals/client/inference_session.py

@@ -3,7 +3,8 @@ from __future__ import annotations
 import asyncio
 import itertools
 import time
-from typing import AsyncIterator, List, Optional
+import uuid
+from typing import AsyncIterator, List, Optional, Tuple
 
 import torch
 from hivemind import (
@@ -15,10 +16,10 @@ from hivemind import (
     serialize_torch_tensor,
 )
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import StubBase
+from hivemind.p2p import P2P
 from hivemind.proto import runtime_pb2
 
-from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
+from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig, maybe_log_traceback
 from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy
@@ -35,35 +36,48 @@ class _ServerInferenceSession:
 
     def __init__(
         self,
+        config: SequenceManagerConfig,
+        span: RemoteSpanInfo,
         uid: ModuleUID,
         rpc_info: RPCInfo,
         inputs_queue: asyncio.Queue,
         outputs_aiter: AsyncIterator,
         *,
-        timeout: float,
         max_length: int,
         **metadata,
     ):
-        self.uid, self.rpc_info = uid, rpc_info
+        self.config = config
+        self.span, self.uid, self.rpc_info = span, uid, rpc_info
         self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
-        self.timeout = timeout
-        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, **metadata))
+        self.session_id = str(uuid.uuid4())
+        self.session_metadata = dict(max_length=max_length, **metadata)
         self.stepped = False
         self.closed = False
 
+        self._position = 0
+        self.history = None  # Used in case of server failures to regenerate attention caches on new servers
+        self.next_session = None
+
     @classmethod
     async def create(
-        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata
+        cls,
+        config: SequenceManagerConfig,
+        p2p: P2P,
+        span: RemoteSpanInfo,
+        uid: ModuleUID,
+        rpc_info: RPCInfo,
+        **metadata,
     ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
+        stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
         inputs_queue = asyncio.Queue()
         outputs_stream = await asyncio.wait_for(
             stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
-            timeout,
+            config.request_timeout,
         )
-        return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
+        return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)
 
     @staticmethod
     async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@@ -75,9 +89,11 @@ class _ServerInferenceSession:
 
     def step(
         self,
-        new_hidden_states: torch.Tensor,
+        inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
+        *,
+        step_id: str,
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tesors and receive a chunk of outputs
@@ -86,44 +102,84 @@ class _ServerInferenceSession:
         """
         """
         if self.closed:
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
             raise Exception("Session is closed, cannot perform step")
+
+        n_input_tokens = inputs.shape[1]
+        if self.history is None:
+            self.history = inputs
+        elif self.history.shape[1] == self._position:
+            self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1)
+        assert self.history.shape[1] == self._position + n_input_tokens, (
+            f"Broken input cache: span={self.span} shape={self.history.shape} "
+            f"position={self._position} n_input_tokens={n_input_tokens}"
+        )
+
+        if not self.stepped:
+            inputs = self.history  # Pass full inputs including prefix
+        else:
+            inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
+
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
-            assert prompts.ndim == 4, "deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]"
+            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
             assert prompts.shape[0] == self.num_blocks
-            assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
-            assert prompts.shape[2] <= new_hidden_states.shape[1]
-            assert prompts.shape[3] == new_hidden_states.shape[2]
+            assert prompts.shape[1] in (inputs.shape[0], 1)
+            assert prompts.shape[2] <= inputs.shape[1]
+            assert prompts.shape[3] == inputs.shape[2]
 
         if hypo_ids is None or is_dummy(hypo_ids):
             hypo_ids = DUMMY
         else:
-            assert len(hypo_ids) == len(new_hidden_states)
+            assert len(hypo_ids) == len(inputs)
             assert hypo_ids.dtype == torch.int64
 
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states, prompts, hypo_ids)
+        input_tensors = (inputs, prompts, hypo_ids)
+
+        request_metadata = dict(session_id=self.session_id, step_id=step_id)
+        if not self.stepped:
+            request_metadata.update(self.session_metadata)
+        elif self.config.use_server_to_server:
+            next_servers = self._collect_next_servers()
+            if next_servers:
+                request_metadata["next_servers"] = next_servers
+
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
+                        for tensor, proto in zip(input_tensors, nested_flatten(self.rpc_info["inference_schema"]))
                     ],
-                    metadata=self._serialized_metadata if not self.stepped else None,
+                    metadata=MSGPackSerializer.dumps(request_metadata),
                 )
             )
         )
         outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
-        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
+        assert (
+            outputs[0].shape == inputs.shape
+        ), f"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}"
+
+        self._position += n_input_tokens
+
         return outputs[0]
 
+    def _collect_next_servers(self) -> List[Tuple[str, str, int, int]]:
+        next_servers = []
+        session = self.next_session
+        while session is not None and session.stepped:
+            next_servers.append(
+                (session.span.peer_id.to_base58(), session.session_id, session.span.start, session.span.end)
+            )
+            session = session.next_session
+        return next_servers
+
     async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
         await self._inputs_queue.put(inputs_serialized)
         self.stepped = True
-        return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
+        return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout)
 
     def close(self):
         """Finish a given inference session, close the underlying connection"""
@@ -163,13 +219,15 @@ class InferenceSession:
     def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
         self._sequence_manager = sequence_manager
         self._closed = False
-        self._chosen_spans = []
         self._server_sessions = []
-        self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
         self._position = 0
         self._max_length = max_length
         self.last_token_id = None
 
+    @property
+    def num_blocks(self) -> int:
+        return len(self._sequence_manager)
+
     @property
     def position(self) -> int:
         return self._position
@@ -178,15 +236,15 @@ class InferenceSession:
         server_sessions = []
         try:
             for span in chosen_spans:
-                stub = TransformerConnectionHandler.get_stub(self._sequence_manager.state.p2p, span.peer_id)
                 span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
                 metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
                 session = RemoteExpertWorker.run_coroutine(
                     _ServerInferenceSession.create(
-                        stub,
+                        self._sequence_manager.config,
+                        self._sequence_manager.state.p2p,
+                        span,
                         span_uids,
                         rpc_info=self._sequence_manager.rpc_info,
-                        timeout=self._sequence_manager.config.request_timeout,
                         max_length=self._max_length,
                         **metadata,
                     )
@@ -206,7 +264,7 @@ class InferenceSession:
                 logger.debug("Caught exception while closing connection to server:", exc_info=True)
 
     def __enter__(self) -> "InferenceSession":
-        assert not self._closed and not self._chosen_spans
+        assert not self._closed and not self._server_sessions
         return self
 
     def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
@@ -214,16 +272,17 @@ class InferenceSession:
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
 
-        n_blocks = len(self._sequence_manager)
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
-            assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
+            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
+            assert prompts.shape[0] == self.num_blocks
 
         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()
         prompts = prompts.cpu()
+        step_id = str(uuid.uuid4())
 
         n_input_tokens = inputs.shape[1]
         if self._position + n_input_tokens > self._max_length:
@@ -233,97 +292,74 @@ class InferenceSession:
 
         server_idx = 0
         block_idx = 0
-        recovery_until = -1  # Recovery mode is disabled until a failure happens
-        while block_idx < n_blocks:
+        while block_idx < self.num_blocks:
             for attempt_no in itertools.count():
                 logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
-                span = None
+                server_session = None
                 try:
-                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
-                        # If there is a failed server session, this code closes it
-                        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
-
-                        n_prev_spans = len(self._chosen_spans)
-                        update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks
-                        if attempt_no >= 1 and update_end > recovery_until:
-                            logger.info(
-                                f"Due to a server failure, remote attention caches "
-                                f"from block {block_idx} to {update_end} will be regenerated"
-                            )
-                        recovery_until = max(recovery_until, update_end)
-
-                        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency")
-                        # make_sequence() could return a longer sequence
-                        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
-                        updated_sessions = self._enter_server_sessions(updated_spans)
-                        logger.debug(
-                            f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers"
-                        )
-
-                        # If there is a failed span, this code replaces it, otherwise it just adds new ones
-                        self._chosen_spans[server_idx : server_idx + 1] = updated_spans
-                        self._server_sessions[server_idx : server_idx + 1] = updated_sessions
-                        recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
-                        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (
-                            len(updated_spans) - 1
-                        )
-                        assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), (
-                            f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, "
-                            f"{len(self._server_inputs)} inputs"
-                        )
-
-                    session = self._server_sessions[server_idx]
-                    span = self._chosen_spans[server_idx]
-
-                    if self._server_inputs[server_idx] is None:
-                        self._server_inputs[server_idx] = inputs
-                    elif self._server_inputs[server_idx].shape[1] == self._position:
-                        self._server_inputs[server_idx] = torch.cat(
-                            [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
-                        )
-                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, (
-                        f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} "
-                        f"position={self._position} n_input_tokens={n_input_tokens}"
-                    )
-
-                    if not session.stepped:
-                        inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
-                    else:
-                        inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
+                    if not self._server_sessions or attempt_no >= 1:
+                        self._update_sequence(server_idx, block_idx, attempt_no)
 
-                    outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
-                    assert (
-                        inputs.shape == outputs.shape
-                    ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
+                    server_session = self._server_sessions[server_idx]
+                    inputs = server_session.step(
+                        inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
+                    )
 
-                    inputs = outputs
                     server_idx += 1
-                    block_idx = span.end
-                    self._sequence_manager.on_request_success(span.peer_id)
+                    block_idx = server_session.span.end
+                    self._sequence_manager.on_request_success(server_session.span.peer_id)
                     break
                 except Exception as e:
-                    self._sequence_manager.on_request_failure(span.peer_id if span is not None else None)
+                    self._sequence_manager.on_request_failure(
+                        server_session.span.peer_id if server_session is not None else None
+                    )
                     if attempt_no + 1 == self._sequence_manager.config.max_retries:
                         raise
                     delay = self._sequence_manager.get_retry_delay(attempt_no)
                     logger.warning(
-                        f"Caught exception when running inference via {span} (retry in {delay:.0f} sec): {repr(e)}"
+                        f"Caught exception when running inference via {server_session.span if server_session is not None else None} "
+                        f"(retry in {delay:.0f} sec): {repr(e)}"
                     )
                     maybe_log_traceback(e)
                     time.sleep(delay)
 
         self._position += n_input_tokens
-        inputs = inputs[:, -n_input_tokens:]
-        outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+        outputs = inputs[:, -n_input_tokens:]
+        outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
         return outputs
 
+    def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:
+        # If there is a failed server session, this code closes it
+        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
+
+        n_prev_spans = len(self._server_sessions)
+        update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
+        if attempt_no >= 1:
+            logger.info(
+                f"Due to a server failure, remote attention caches "
+                f"from block {block_idx} to {update_end} will be regenerated"
+            )
+
+        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency")
+        # make_sequence() could return a longer sequence
+        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
+        updated_sessions = self._enter_server_sessions(updated_spans)
+        logger.debug(f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers")
+
+        # If there is a failed span, this code replaces it, otherwise it just adds new ones
+        if server_idx < n_prev_spans:
+            updated_sessions[0].history = self._server_sessions[server_idx].history
+        self._server_sessions[server_idx : server_idx + 1] = updated_sessions
+
+        # Update links to the next server session for direct server-to-server communication via rpc_push()
+        for i in range(max(server_idx - 1, 0), min(server_idx + len(updated_spans), len(self._server_sessions) - 1)):
+            self._server_sessions[i].next_session = self._server_sessions[i + 1]
+
     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""
         if not self._closed:
-            self._server_inputs.clear()
             self._exit_server_sessions(self._server_sessions)
             self._server_sessions.clear()
-            self._chosen_spans.clear()
             self._closed = True
 
     def __exit__(self, *exc_details):
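For illustration, here is a toy, self-contained walk of the next_session chain introduced above; the class and values are stand-ins rather than petals code, and it mirrors what _ServerInferenceSession._collect_next_servers() returns:

from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class ToySession:  # stand-in for _ServerInferenceSession
    peer_id: str
    session_id: str
    start: int
    end: int
    stepped: bool = True
    next_session: Optional["ToySession"] = None

def collect_next_servers(session: ToySession) -> List[Tuple[str, str, int, int]]:
    # Walk the chain of sessions that are already open (stepped == True),
    # just like _collect_next_servers() above.
    result, current = [], session.next_session
    while current is not None and current.stepped:
        result.append((current.peer_id, current.session_id, current.start, current.end))
        current = current.next_session
    return result

first = ToySession("QmPeerA", "sess-a", 0, 8)
first.next_session = ToySession("QmPeerB", "sess-b", 8, 16)
first.next_session.next_session = ToySession("QmPeerC", "sess-c", 16, 24)
print(collect_next_servers(first))
# [('QmPeerB', 'sess-b', 8, 16), ('QmPeerC', 'sess-c', 16, 24)]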

+ 1 - 0
src/petals/client/routing/sequence_manager.py

@@ -34,6 +34,7 @@ class SequenceManagerConfig:
     daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
 
     allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
+    use_server_to_server: bool = True  # Use direct server-to-server communication
 
     request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
     update_period: float = 60  # refresh DHT information once in this many seconds
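The flag defaults to True, so clients use direct pushes automatically. A hedged sketch of opting out on the client side, assuming the config's remaining fields all have defaults (as the ones visible in this hunk do):

from petals.client.routing.sequence_manager import SequenceManagerConfig

config = SequenceManagerConfig()        # assumption: constructible with defaults only
config.use_server_to_server = False     # route every activation back through the client instead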

+ 177 - 23
src/petals/server/handler.py

@@ -2,6 +2,9 @@ from __future__ import annotations
 
 import asyncio
 import contextlib
+import multiprocessing.managers
+import sys
+from concurrent.futures import ThreadPoolExecutor
 from itertools import chain
 from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
@@ -11,6 +14,7 @@ from hivemind import (
     DHT,
     MSGPackSerializer,
     P2PContext,
+    PeerID,
     deserialize_tensor_stream,
     deserialize_torch_tensor,
     nested_flatten,
@@ -25,7 +29,7 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
 import petals
-from petals.data_structures import CHAIN_DELIMITER, InferenceMetadata, ModuleUID
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID
 from petals.server.backend import TransformerBackend
 from petals.server.memory_cache import Handle
 from petals.server.task_pool import PrioritizedTaskPool
@@ -34,6 +38,23 @@ from petals.utils.misc import DUMMY, is_dummy
 
 logger = get_logger(__name__)
 
+
+# Fix pickling protobufs, see https://stackoverflow.com/a/74873028
+sys.modules["runtime_pb2"] = runtime_pb2
+
+# Fix queues in multiprocessing.Manager in Python < 3.9.7, see https://bugs.python.org/issue30256
+
+_OriginalAutoProxy = multiprocessing.managers.AutoProxy
+
+
+def patched_autoproxy(*args, manager_owned=True, **kwargs):
+    # Calling original AutoProxy without the unwanted key argument
+    return _OriginalAutoProxy(*args, **kwargs)
+
+
+multiprocessing.managers.AutoProxy = patched_autoproxy
+
+
 CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
 
 
@@ -47,6 +68,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
         *,
+        dht_prefix: str,
+        push_manager: multiprocessing.managers.SyncManager,
+        session_queues: Dict[str, multiprocessing.managers.BaseProxy],  # BaseProxy for queue.Queue
         inference_max_length: int,
         request_timeout: float,
         session_timeout: float,
@@ -56,6 +80,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         super().__init__(dht, module_backends)
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
+        self.dht_prefix = dht_prefix
+        self._push_manager = push_manager
+        self._session_queues = session_queues
+        self._executor = ThreadPoolExecutor(max_workers=float("inf"))  # For waiting on self.session_queues
+
         self.inference_max_length = inference_max_length
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
@@ -96,7 +125,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         self,
         requests: AsyncIterator[runtime_pb2.ExpertRequest],
         context: P2PContext,
-    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
 
         async with timeout(self.session_timeout):
@@ -113,6 +142,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
                 max_length = metadata.get("max_length")
                 points = metadata.get("points", 0)
+                session_id = metadata.get("session_id")
 
                 if not requested_uids:
                     raise ValueError("User must specify at least one block for inference, but got none")
@@ -133,7 +163,11 @@ class TransformerConnectionHandler(ConnectionHandler):
 
                 async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
                     assert len(cache_handles) == len(requested_backends)
-                    while request.tensors:  # iterate while user is willing to supply tensors
+                    first_request = request
+                    background_tasks = set()
+                    async for request, metadata in self._iterate_inference_steps(
+                        first_request, requests, session_id, requested_uids, context
+                    ):
                         hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
 
                         # Cast inputs to backend dtype
@@ -141,7 +175,8 @@ class TransformerConnectionHandler(ConnectionHandler):
                         assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
 
                         # parse deep prompts (optional argument)
-                        if prompts is None or is_dummy(prompts):
+                        has_prompts = prompts is not None and not is_dummy(prompts)
+                        if not has_prompts:
                             prompts = [None] * len(requested_backends)
                         else:
                             prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
@@ -180,25 +215,136 @@ class TransformerConnectionHandler(ConnectionHandler):
                             )
 
                         # serialize and send last layer outputs
-                        yield runtime_pb2.ExpertResponse(
-                            tensors=[
-                                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                                for result, proto in zip(
-                                    (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
-                                )
-                            ]
-                        )
+                        output_tensors = [
+                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                            for result, proto in zip(
+                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
+                            )
+                        ]
+                        if not has_prompts:
+                            task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
+                            background_tasks.add(task)  # Keep reference until it is done to save it from GC
+                            task.add_done_callback(background_tasks.discard)
+                        yield runtime_pb2.ExpertResponse(tensors=output_tensors)
 
                         # prepare for next step
-                        prefix_length += hidden_states.shape[1]
-                        try:
-                            request = await asyncio.wait_for(anext(requests), self.step_timeout)
-                        except asyncio.TimeoutError:
-                            self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
-                            return
+                        prefix_length += length_increment
             finally:
                 self._log_request("rpc_inference.close", requested_uids, context)
 
+    async def _iterate_inference_steps(
+        self,
+        first_request: runtime_pb2.ExpertRequest,
+        requests: AsyncIterator[runtime_pb2.ExpertRequest],
+        session_id: Optional[str],
+        requested_uids: Sequence[str],
+        context: P2PContext,
+    ) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]:
+        loop = asyncio.get_event_loop()
+        if session_id is not None:
+            push_queue = self._push_manager.Queue()
+            self._session_queues[session_id] = push_queue
+
+        processed_step_ids = set()
+        n_pushes = n_late_pushes = 0
+        request = first_request
+        anext_task = get_push_task = None
+        try:
+            while request.tensors:  # iterate while user is willing to supply tensors
+                metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+                step_id = metadata.get("step_id")
+
+                pushed = metadata.get("pushed")
+                if pushed:
+                    n_pushes += 1
+
+                if step_id is None or step_id not in processed_step_ids:
+                    yield request, metadata
+                    if step_id is not None:
+                        processed_step_ids.add(step_id)
+                elif pushed:
+                    n_late_pushes += 1
+                    self._log_request(
+                        "rpc_inference.push",
+                        requested_uids,
+                        context,
+                        warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time",
+                    )
+
+                # Wait for the next request, coming either from the `requests` iterator or `push_queue`
+                if anext_task is None:
+                    anext_task = asyncio.create_task(anext(requests))
+                if get_push_task is None:
+                    if session_id is not None:
+                        get_push_task = loop.run_in_executor(self._executor, push_queue.get)
+                    else:
+                        get_push_task = asyncio.create_task(asyncio.Event().wait())  # Dummy never-ending task
+                done, _ = await asyncio.wait(
+                    [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED
+                )
+
+                if anext_task in done:
+                    request = await anext_task
+                    anext_task = None
+                elif get_push_task in done:
+                    request = await get_push_task
+                    get_push_task = None
+                else:
+                    self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
+                    anext_task.cancel()
+                    get_push_task.cancel()
+                    return
+        except:
+            logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True)
+            raise
+        finally:
+            if session_id is not None:
+                push_queue.put(None)  # Stop thread for get_push_task
+                del self._session_queues[session_id]
+
+    async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+        """Directly push activation tensors from one server to another"""
+
+        requested_uids = self._check_uids(request.uid)
+        self._log_request("rpc_push", requested_uids, context)
+
+        metadata = MSGPackSerializer.loads(request.metadata)
+        session_id = metadata["session_id"]
+        self._session_queues[session_id].put(request)
+        return runtime_pb2.ExpertResponse()
+
+    async def _push_outputs(
+        self, request: runtime_pb2.ExpertRequest, serialized_outputs: runtime_pb2.Tensor, metadata: dict
+    ) -> None:
+        try:
+            next_servers = metadata.get("next_servers")
+            if not next_servers:
+                return
+
+            next_peer_id, next_session_id, next_start, next_end = next_servers[0]
+            next_peer_id = PeerID.from_base58(next_peer_id)
+            next_uid = CHAIN_DELIMITER.join(f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(next_start, next_end))
+
+            # Sending hidden states serialized with output_schema to avoid double serialization
+            next_tensors = [serialized_outputs] + request.tensors[1:]
+            next_metadata = metadata.copy()
+            next_metadata.update(session_id=next_session_id, next_servers=next_servers[1:], pushed=True)
+
+            stub = self.get_stub(self._p2p, next_peer_id)
+            await stub.rpc_push(
+                runtime_pb2.ExpertRequest(
+                    uid=next_uid,
+                    tensors=next_tensors,
+                    metadata=MSGPackSerializer.dumps(next_metadata),
+                ),
+                timeout=self.request_timeout,
+            )
+        except Exception:
+            logger.debug(
+                f"Failed to push outputs to peer_id={next_peer_id}, session_id={next_session_id}, blocks={next_start}:{next_end}:",
+                exc_info=True,
+            )
+
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         async with timeout(self.request_timeout):
             # Parse request and prepare backends
@@ -348,7 +494,7 @@ class TransformerConnectionHandler(ConnectionHandler):
     @contextlib.asynccontextmanager
     async def _allocate_cache(
         self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
-    ) -> Sequence[Sequence[Handle, ...]]:
+    ) -> Sequence[Sequence[Handle]]:
         """
         Allocate memory cache for all transformer blocks, return cache handle
         :returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend
@@ -358,7 +504,13 @@ class TransformerConnectionHandler(ConnectionHandler):
             yield nested_pack(handles, descriptors)
 
     def _log_request(
-        self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None
+        self,
+        method: str,
+        uids: Optional[Sequence[ModuleUID]],
+        context: P2PContext,
+        *,
+        debug: Optional[str] = None,
+        warning: Optional[str] = None,
     ) -> None:
         if uids is not None:
             friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
@@ -370,10 +522,12 @@ class TransformerConnectionHandler(ConnectionHandler):
         friendly_remote_id = "..." + str(context.remote_id)[-6:]
 
         message = f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})"
-        if warning is None:
-            logger.info(message)
-        else:
+        if warning is not None:
             logger.warning(f"{message}: {warning}")
+        elif debug is not None:
+            logger.debug(f"{message}: {debug}")
+        else:
+            logger.info(message)
 
     async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
         """Return metadata about stored block uids and current load"""

+ 22 - 10
src/petals/server/server.py

@@ -45,7 +45,7 @@ class Server:
         self,
         *,
         initial_peers: List[str],
-        prefix: Optional[str],
+        dht_prefix: Optional[str],
         converted_model_name_or_path: str,
         throughput: Union[float, str],
         num_blocks: Optional[int] = None,
@@ -105,13 +105,13 @@ class Server:
             revision=revision,
         )
 
-        if prefix is None:
-            prefix = self.block_config.dht_prefix
-        assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
+        if dht_prefix is None:
+            dht_prefix = self.block_config.dht_prefix
+        assert UID_DELIMITER not in dht_prefix and CHAIN_DELIMITER not in dht_prefix, (
             f"DHT prefix should not contain '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'. "
-            f"Please specify another --prefix manually when starting a server"
+            f"Please specify another --dht_prefix manually when starting a server"
         )
-        self.prefix = prefix
+        self.dht_prefix = dht_prefix
 
         if expiration is None:
             expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
@@ -121,7 +121,8 @@ class Server:
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
 
         self.module_uids = [
-            f"{self.prefix}.{block_index}" for block_index in range(self.block_config.num_hidden_layers)
+            f"{self.dht_prefix}{UID_DELIMITER}{block_index}"
+            for block_index in range(self.block_config.num_hidden_layers)
         ]
 
         if dht_client_mode is None:
@@ -258,7 +259,7 @@ class Server:
             block_indices = self._choose_blocks()
             self.module_container = ModuleContainer.create(
                 dht=self.dht,
-                prefix=self.prefix,
+                dht_prefix=self.dht_prefix,
                 converted_model_name_or_path=self.converted_model_name_or_path,
                 block_config=self.block_config,
                 attn_cache_bytes=self.attn_cache_bytes,
@@ -359,7 +360,7 @@ class ModuleContainer(threading.Thread):
         cls,
         *,
         dht: DHT,
-        prefix: str,
+        dht_prefix: str,
         converted_model_name_or_path: str,
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
@@ -382,7 +383,7 @@ class ModuleContainer(threading.Thread):
         should_validate_reachability: bool,
         **kwargs,
     ) -> ModuleContainer:
-        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
+        module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
         joining_announcer = ModuleAnnouncerThread(
             module_uids,
             dht,
@@ -459,6 +460,7 @@ class ModuleContainer(threading.Thread):
 
         return cls(
             dht,
+            dht_prefix,
             blocks,
             throughput=throughput,
             update_period=update_period,
@@ -469,6 +471,7 @@ class ModuleContainer(threading.Thread):
     def __init__(
         self,
         dht: DHT,
+        dht_prefix: str,
         module_backends: Dict[str, TransformerBackend],
         *,
         inference_max_length: int,
@@ -486,10 +489,17 @@ class ModuleContainer(threading.Thread):
 
         self.dht, self.module_backends = dht, module_backends
         self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
+
+        self.push_manager = mp.Manager()
+        self.push_manager.__enter__()
+        session_queues = self.push_manager.dict()
         self.conn_handlers = [
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
+                dht_prefix=dht_prefix,
+                push_manager=self.push_manager,
+                session_queues=session_queues,
                 inference_max_length=inference_max_length,
                 request_timeout=request_timeout,
                 session_timeout=session_timeout,
@@ -497,6 +507,7 @@ class ModuleContainer(threading.Thread):
             )
             for _ in range(num_handlers)
         ]
+
         self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
         # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
         self.online_announcer = ModuleAnnouncerThread(
@@ -577,6 +588,7 @@ class ModuleContainer(threading.Thread):
         logger.debug("Shutting down connection handlers")
         for handler in self.conn_handlers:
             handler.shutdown()
+        self.push_manager.__exit__(None, None, None)
 
         logger.debug(f"Shutting down pools")
         for pool in self.runtime.pools:
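The push_manager wiring above follows a standard multiprocessing pattern: a Manager owns a dict of per-session queues, so whichever connection-handler process receives an rpc_push() can hand the request to the process holding that session's response stream. A standalone sketch of the pattern (not petals code):

import multiprocessing as mp

if __name__ == "__main__":
    manager = mp.Manager()
    session_queues = manager.dict()  # shared across handler processes

    session_id = "example-session"
    session_queues[session_id] = manager.Queue()  # created when rpc_inference opens the session

    session_queues[session_id].put("pushed request")  # what rpc_push() does from another process
    print(session_queues[session_id].get())           # what _iterate_inference_steps() consumes

    del session_queues[session_id]  # cleaned up when the session ends
    manager.shutdown()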