
temporary rollback: allow kwargs only at first inference step

Your Name 1 year ago
parent commit 3f06b53b1d
1 changed file with 69 additions and 64 deletions
  src/petals/client/inference_session.py  +69 −64

+69 −64  src/petals/client/inference_session.py

@@ -7,15 +7,13 @@ import uuid
 from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple
 
 import torch
-from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2P
 from hivemind.proto import runtime_pb2
-from hivemind.utils.tensor_descr import BatchTensorDescriptor
+from hivemind.utils import MSGPackSerializer, anext, get_logger, nested_flatten
 
-from petals.client.config import ClientConfig
 from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
-from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
+from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
 from petals.utils.packaging import pack_args_kwargs
@@ -32,23 +30,21 @@ class _ServerInferenceSession:
 
     def __init__(
         self,
-        config: ClientConfig,
+        sequence_manager: RemoteSequenceManager,
         span: RemoteSpanInfo,
         span_uids: Sequence[ModuleUID],
-        rpc_info: RPCInfo,
         inputs_queue: asyncio.Queue,
-        outputs_aiter: AsyncIterator,
-        *,
+        outputs_stream: AsyncIterator,
+        *block_kwargs,
         max_length: int,
-        **metadata,
     ):
-        self.config = config
-        self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
+        self.sequence_manager = sequence_manager
+        self.span, self.span_uids = span, span_uids
         self.num_blocks = len(span_uids)
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
-        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_stream
         self.session_id = str(uuid.uuid4())
-        self.session_metadata = dict(max_length=max_length, **metadata)
+        self.max_length = max_length
         self.stepped = False
         self.closed = False
 
@@ -56,24 +52,26 @@ class _ServerInferenceSession:
         self.history = None  # Used in case of server failures to regenerate attention caches on new servers
         self.next_session = None
 
+        self.block_kwargs = block_kwargs
+        assert len(self.block_kwargs) in (0, self.num_blocks)
+
     @classmethod
     async def create(
         cls,
-        config: ClientConfig,
-        p2p: P2P,
+        sequence_manager: RemoteSequenceManager,
         span: RemoteSpanInfo,
-        span_uids: Sequence[RemoteSpanInfo],
-        rpc_info: RPCInfo,
-        **metadata,
+        span_uids: Sequence[ModuleUID],
+        *block_kwargs: Dict[str, Any],
+        **kwargs,
     ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
-        stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
+        stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
         inputs_queue = asyncio.Queue()
         outputs_stream = await asyncio.wait_for(
             stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
-            config.connect_timeout,
+            sequence_manager.config.connect_timeout,
         )
-        return cls(config, span, span_uids, rpc_info, inputs_queue, outputs_stream, **metadata)
+        return cls(sequence_manager, span, span_uids, inputs_queue, outputs_stream, *block_kwargs, **kwargs)
 
     @staticmethod
     async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@@ -87,7 +85,7 @@ class _ServerInferenceSession:
         self,
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
-        *block_kwargs: Dict[str, Any],
+        *,
         hypo_ids: Optional[torch.Tensor] = None,
         step_id: str,
     ) -> torch.Tensor:
@@ -96,7 +94,6 @@ class _ServerInferenceSession:
         :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
           if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
         """
-        # TODO record previous kwargs in case of server failure!!!
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
@@ -112,10 +109,11 @@ class _ServerInferenceSession:
 
         if not self.stepped:
             inputs = self.history  # Pass full inputs including prefix
+            block_kwargs = self.block_kwargs
         else:
             inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
+            block_kwargs = []
 
-        assert len(block_kwargs) in (0, self.span.length)
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
@@ -131,39 +129,50 @@ class _ServerInferenceSession:
             assert len(hypo_ids) == len(inputs)
             assert hypo_ids.dtype == torch.int64
 
-        # serialize inputs and put them into the queue
-        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
-
-        request_metadata = dict(session_id=self.session_id, step_id=step_id)
-        if not self.stepped:
-            request_metadata.update(self.session_metadata)
-        elif self.config.use_server_to_server:
+        metadata = dict(session_id=self.session_id, step_id=step_id, max_length=self.max_length)
+        metadata.update(
+            self.sequence_manager.get_request_metadata(
+                self.span.peer_id,
+                "rpc_inference",
+                self.span_uids,
+                inputs,
+                prompts,
+                *block_kwargs,
+                max_length=self.max_length,
+                session_id=self.session_id,
+                step_id=step_id,
+            )
+        )
+        if self.stepped and self.sequence_manager.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
-                request_metadata["next_servers"] = next_servers
+                metadata["next_servers"] = next_servers
 
-        args_structure = request_metadata.setdefault("args_structure", args_structure)
+        codecs = self.sequence_manager.get_compression_codecs(
+            self.span.peer_id, "rpc_inference", self.span_uids, inputs, prompts, *block_kwargs
+        )
 
-        # TODO YOZH FIX THIS BEFORE THE END OF THIS PR
-        # TODO: make possible to use different compression method for different tensors
-        server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
-        compression = server_side_inference_schema[0].compression
-        inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)
+        # serialize inputs and put them into the queue
+        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
+        args_structure = metadata.setdefault("args_structure", args_structure)
 
-        # TODO: create more explicit way to check servers schema and client's structure
-        assert len(input_tensors) >= len(
-            server_side_inference_schema
-        ), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step"
+        if codecs is None:
+            codecs = [runtime_pb2.CompressionType.NONE] * len(input_tensors)
+        else:
+            codecs = list(nested_flatten(codecs))
+            assert len(codecs) == len(
+                input_tensors
+            ), f"got {len(input_tensors)} tensors but {len(codecs)} compression codecs"
 
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=CHAIN_DELIMITER.join(self.span_uids),
                     tensors=[
-                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(input_tensors, inference_schema)
+                        serialize_torch_tensor(tensor, compression)
+                        for tensor, compression in zip(input_tensors, codecs)
                     ],
-                    metadata=MSGPackSerializer.dumps(request_metadata),
+                    metadata=MSGPackSerializer.dumps(metadata),
                 )
             )
         )
@@ -190,7 +199,7 @@ class _ServerInferenceSession:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
         await self._inputs_queue.put(inputs_serialized)
         self.stepped = True
-        return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout)
+        return await asyncio.wait_for(anext(self._outputs_stream), self.sequence_manager.config.request_timeout)
 
     def close(self):
         """Finish a given inference session, close the underlying connection"""
@@ -227,7 +236,7 @@ class InferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
+    def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int, *block_kwargs: Dict[str, Any]):
         self._sequence_manager = sequence_manager
         self._closed = False
         self._server_sessions = []
@@ -235,6 +244,12 @@ class InferenceSession:
         self._max_length = max_length
         self.output_ids = None
 
+        num_blocks = len(self._sequence_manager)
+        if len(block_kwargs) == 1:
+            block_kwargs = block_kwargs * num_blocks
+        assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
+        self.block_kwargs = block_kwargs
+
     @property
     def num_blocks(self) -> int:
         return len(self._sequence_manager)
@@ -247,17 +262,13 @@ class InferenceSession:
         server_sessions = []
         try:
             for span in chosen_spans:
-                span_uids = self._sequence_manager.block_uids[span.start : span.end]
-                metadata = self._sequence_manager.get_request_metadata(span.peer_id, "rpc_inference", span_uids)
                 session = RemoteExpertWorker.run_coroutine(
                     _ServerInferenceSession.create(
-                        self._sequence_manager.config,
-                        self._sequence_manager.state.p2p,
+                        self._sequence_manager,
                         span,
-                        span_uids,
-                        rpc_info=self._sequence_manager.rpc_info,  # TODO not actually needed
+                        self._sequence_manager.block_uids[span.start : span.end],
+                        *self.block_kwargs[span.start : span.end],
                         max_length=self._max_length,
-                        **metadata,
                     )
                 )
                 server_sessions.append(session)
@@ -282,18 +293,13 @@ class InferenceSession:
         self,
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
-        *block_kwargs: Sequence[Dict[str, torch.Tensor]],
         hypo_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
 
-        num_blocks = len(self._sequence_manager)
-        if len(block_kwargs) == 1:
-            block_kwargs = block_kwargs * num_blocks
-        assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
-
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
@@ -326,9 +332,8 @@ class InferenceSession:
                     inputs = server_session.step(
                         inputs,
                         prompts[server_session.span.start : server_session.span.end],
-                        *block_kwargs[server_session.span.start : server_session.span.end],
-                        step_id=step_id,
                         hypo_ids=hypo_ids,
+                        step_id=step_id,
                     )
 
                     server_idx += 1
@@ -354,7 +359,7 @@ class InferenceSession:
         outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
         return outputs
 
-    def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:
+    def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int):
         # If there is a failed server session, this code closes it
         self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
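
For readers skimming the diff: the rollback moves per-block keyword arguments from InferenceSession.step() back to the InferenceSession constructor, so they are packed and sent to each server only with the first inference step. The sketch below illustrates the resulting calling convention; the model name, the way the sequence manager is obtained, and the kwarg contents are assumptions for illustration, not part of this commit.

# Usage sketch (illustrative, not from this commit): per-block kwargs are bound
# when the session is created and travel with the first step only.
import torch

from petals import AutoDistributedModelForCausalLM
from petals.client.inference_session import InferenceSession

model = AutoDistributedModelForCausalLM.from_pretrained("bigscience/bloom-560m-petals")
sequence_manager = model.transformer.h.sequence_manager  # RemoteSequenceManager over the remote blocks
hidden_size = model.config.hidden_size

# One kwargs dict per remote block; a single dict would be broadcast to every block.
# The kwarg name "attention_mask" is an assumption for illustration.
block_kwargs = [dict(attention_mask=torch.ones(1, 16, dtype=torch.bool))] * len(sequence_manager)

session = InferenceSession(sequence_manager, 128, *block_kwargs)  # max_length=128
with session:
    prefix = torch.randn(1, 16, hidden_size)
    hidden = session.step(prefix)         # block_kwargs are serialized along with this first step
    new_token = torch.randn(1, 1, hidden_size)
    hidden = session.step(new_token)      # later steps pass only hidden states, prompts, and hypo_ids

Binding the kwargs once per session sidesteps the replay problem flagged by the removed TODO ("record previous kwargs in case of server failure"): when a failed span is rerouted, the replacement _ServerInferenceSession receives the same block_kwargs at creation and resends them with its own first step together with self.history.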