Răsfoiți Sursa

probably break everyting

Your Name 1 an în urmă
părinte
comite
a23bd73f3b

+ 24 - 8
src/petals/client/inference_session.py

@@ -4,7 +4,7 @@ import asyncio
 import itertools
 import time
 import uuid
-from typing import AsyncIterator, List, Optional, Sequence, Tuple
+from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
@@ -43,7 +43,7 @@ class _ServerInferenceSession:
         **metadata,
     ):
         self.config = config
-        self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
+        self.span, self.span_uids = span, span_uids
         self.num_blocks = len(span_uids)
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
@@ -67,7 +67,6 @@ class _ServerInferenceSession:
         **metadata,
     ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
-        # TODO YOZH you don't need rpc info here
         stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
         inputs_queue = asyncio.Queue()
         outputs_stream = await asyncio.wait_for(
@@ -89,7 +88,7 @@ class _ServerInferenceSession:
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
-        *,
+        *block_kwargs: Dict[str, Any],
         step_id: str,
     ) -> torch.Tensor:
         """
@@ -97,6 +96,7 @@ class _ServerInferenceSession:
         :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
           if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
         """
+        # TODO record previous kwargs in case of server failure!!!
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
@@ -115,6 +115,7 @@ class _ServerInferenceSession:
         else:
             inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
+        assert len(block_kwargs) in (0, self.span.length)
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
@@ -131,7 +132,7 @@ class _ServerInferenceSession:
             assert hypo_ids.dtype == torch.int64
 
         # serialize inputs and put them into the queue
-        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
+        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
 
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
@@ -141,7 +142,7 @@ class _ServerInferenceSession:
             if next_servers:
                 request_metadata["next_servers"] = next_servers
 
-        request_metadata["args_structure"] = args_structure
+        args_structure = request_metadata.setdefault("args_structure", args_structure)
 
         # TODO YOZH FIX THIS BEFORE THE END OF THIS PR
         # TODO: make possible to use different compression method for different tensors
@@ -277,11 +278,22 @@ class InferenceSession:
         assert not self._closed and not self._server_sessions
         return self
 
-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+    def step(
+        self,
+        inputs: torch.Tensor,
+        prompts: Optional[torch.Tensor] = None,
+        *block_kwargs: Sequence[Dict[str, torch.Tensor]],
+        **kwargs,
+    ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
 
+        num_blocks = len(self._sequence_manager)
+        if len(block_kwargs) == 1:
+            block_kwargs = block_kwargs * num_blocks
+        assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
+
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
@@ -312,7 +324,11 @@ class InferenceSession:
 
                     server_session = self._server_sessions[server_idx]
                     inputs = server_session.step(
-                        inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
+                        inputs,
+                        prompts[server_session.span.start : server_session.span.end],
+                        *block_kwargs[server_session.span.start : server_session.span.end],
+                        step_id=step_id,
+                        **kwargs,
                     )
 
                     server_idx += 1

+ 5 - 3
src/petals/client/sequential_autograd.py

@@ -52,7 +52,7 @@ async def sequential_forward(
     if len(block_kwargs) == 1:
         block_kwargs = block_kwargs * (end_index - start_index)
     assert (
-        len(block_kwargs) in (0, end_index - start_index)
+        not block_kwargs or len(block_kwargs) == end_index - start_index
     ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
     assert is_dummy(prompts) or len(prompts) == len(
@@ -222,7 +222,8 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
+    def forward(ctx, sequence_manager: RemoteSequenceManager, inputs: torch.Tensor, prompts: torch.Tensor):
+        # TODO add kwargs here; figure out a way to split kwargs across servers
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
         input_batches = tuple(batch.requires_grad_(inputs.requires_grad) for batch in input_batches)
@@ -271,4 +272,5 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
         grad_inputs = torch.cat(grad_input_batches, dim=0)
         dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
         grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
-        return (grad_inputs, grad_prompts, None)
+        # TODO return grads w.r.t. kwargs here
+        return (None, grad_inputs, grad_prompts)