Browse Source

Deep distributed prompt tuning (#42)

* implemented an option to add learnable prompts to intermediate layers
* added support for prompts (as input) in rpc_forward and rpc_backward
* added a test to check that RemoteSequential works correctly with deep prompts

Co-authored-by: justheuristic <justheuristic@gmail.com>
Dmitry Baranchuk 3 years ago
parent
commit
6095f58681
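
For orientation, here is a minimal client-side sketch of the deep prompt-tuning path this commit adds. It is not part of the diff; it assumes a running swarm plus the MODEL_NAME and INITIAL_PEERS constants from the test utilities, and it mirrors the new test added below.

import torch
from hivemind import DHT

from src import RemoteSequential
from src.client.remote_model import DistributedBloomConfig

config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
blocks = RemoteSequential(config, dht)  # all transformer blocks, hosted by the swarm

batch_size, seq_len, pre_seq_len = 2, 5, 3
# one trainable prefix for the input plus one prefix per transformer block ("deep" prompts)
input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
deep_prompts = torch.zeros(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)

hidden = torch.randn(batch_size, seq_len, config.hidden_size)
hidden = torch.cat([input_prompts, hidden], dim=1)  # prepend the prefix positions
outputs = blocks(hidden, prompts=deep_prompts)  # servers add deep_prompts[i] to the prefix before block i
outputs.sum().backward()  # gradients flow back into both prompt tensors via rpc_backward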

+ 51 - 61
src/client/remote_model.py

@@ -1,5 +1,5 @@
 # this code is in active development, interfaces may change
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 import hivemind
 import torch
@@ -17,6 +17,7 @@ from src.bloom.model import (
 )
 from src.client.remote_generation import RemoteGenerationMixin
 from src.client.remote_sequential import RemoteSequential
+from src.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -33,6 +34,7 @@ class DistributedBloomConfig(BloomConfig):
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
+    tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
 
 
 class DistributedBloomModel(BloomModel):
@@ -60,10 +62,41 @@ class DistributedBloomModel(BloomModel):
         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)
 
+        if config.tuning_mode and "ptune" in config.tuning_mode:
+            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
+            self.pre_seq_len = config.pre_seq_len
+            self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
+            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+
+            if config.tuning_mode == "deep_ptune":
+                self.intermediate_prompt_embeddings = nn.Embedding(
+                    self.pre_seq_len,
+                    config.num_hidden_layers * config.hidden_size
+                    # ^-- TODO: should be num_hidden_layers - 1
+                )
+                self.intermediate_prompt_embeddings.weight.data.zero_()
+        elif config.tuning_mode:
+            raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
+
     def set_requires_grad(self, value):
         for p in self.parameters():
             p.requires_grad = value
 
+    def get_prompt(self, batch_size):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
+        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
+        prompts = self.prompt_embeddings(prefix_tokens)
+
+        if self.config.tuning_mode == "deep_ptune":
+            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
+            intermediate_prompts = intermediate_prompts.view(
+                batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size  # TODO: should be len(self.h) - 1
+            )
+            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
+        else:
+            intermediate_prompts = DUMMY
+        return prompts, intermediate_prompts
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -90,10 +123,22 @@ class DistributedBloomModel(BloomModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        # Note: it supports only float32 or bfloat16 inputs
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
         output_shape = input_shape + (hidden_states.size(-1),)
-        hidden_states = self.h(hidden_states)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
+        else:
+            hidden_states = self.h(hidden_states)
+
+        # Remove prefix
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)
@@ -106,55 +151,6 @@ class DistributedBloomModel(BloomModel):
         )
 
 
-class DistributedBloomPrefix(DistributedBloomModel):
-    """DistributedBloomModel with prefix tokens for prompt tuning"""
-
-    def __init__(self, config):
-        super().__init__(config)
-        assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
-        self.pre_seq_len = config.pre_seq_len
-
-        self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
-        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
-
-    def get_prompt(self, batch_size):
-        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
-        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
-        prompts = self.prompt_embeddings(prefix_tokens)
-        return prompts
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
-    ):
-        assert (
-            input_ids is None or inputs_embeds is None
-        ), "You cannot specify both input_ids and inputs_embeds at the same time"
-        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
-
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
-
-        batch_size = inputs_embeds.shape[0]
-
-        if attention_mask is not None:
-            prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
-            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
-
-        prompts = self.get_prompt(batch_size)
-        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
-
-        transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
-
-        # Remove prefix
-        last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
-        transformer_outputs["last_hidden_state"] = last_hidden_state
-        return transformer_outputs
-
-
 class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
@@ -162,10 +158,7 @@ class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
 
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
-        if config.pre_seq_len > 0:
-            self.transformer = DistributedBloomPrefix(config)
-        else:
-            self.transformer = DistributedBloomModel(config)
+        self.transformer = DistributedBloomModel(config)
         self.lm_head = LMHead(config, self.transformer.word_embeddings)
 
         # Initialize weights and apply final processing
@@ -195,10 +188,7 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
 
     def __init__(self, config: DistributedBloomConfig):
         super().__init__(config)
-        if config.pre_seq_len > 0:
-            self.transformer = DistributedBloomPrefix(config)
-        else:
-            self.transformer = DistributedBloomModel(config)
+        self.transformer = DistributedBloomModel(config)
 
         # Initialize weights and apply final processing
         self.post_init()

+ 3 - 2
src/client/remote_sequential.py

@@ -15,6 +15,7 @@ from src.client.sequence_manager import RemoteSequenceManager
 from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
+from src.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -52,8 +53,8 @@ class RemoteSequential(nn.Module):
             assert isinstance(sequence_manager.block_uids, list)
             self.is_subsequence = self.sequence_manager.block_uids != block_uids
 
-    def forward(self, inputs: torch.Tensor):
-        outputs = _RemoteSequentialAutogradFunction.apply(inputs, self.sequence_manager)
+    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
+        outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
         return outputs
 
     def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:

+ 83 - 35
src/client/sequential_autograd.py

@@ -12,6 +12,7 @@ from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
+from src.utils.misc import DUMMY, is_dummy
 
 MAX_TOKENS_IN_BATCH = 1024
 
@@ -33,7 +34,13 @@ async def run_expert_forward(
     # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
     forward_inputs = (inputs, kwargs)
 
-    if not nested_compare(forward_inputs, rpc_info["forward_schema"]):
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    # TODO: rm this assert when support arbitrary number of input tensors
+    assert len(args_schema) == 1 and len(inputs) == 2
+    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+
+    if not nested_compare(forward_inputs, forward_schema_with_prompts):
         raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
     forward_inputs = nested_flatten(forward_inputs)
@@ -44,7 +51,7 @@ async def run_expert_forward(
     serialized_tensors = await asyncio.gather(
         *(
             loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-            for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
+            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
         )
     )
 
@@ -57,8 +64,9 @@ async def run_expert_backward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
-    intemediate_inputs: List[torch.Tensor],
+    inputs: torch.Tensor,
     grad_outputs: List[torch.Tensor],
+    *extra_tensors: torch.Tensor,
 ) -> Sequence[torch.Tensor]:
     """
     Serializes grad outputs and calls "expert_backward".
@@ -67,8 +75,14 @@ async def run_expert_backward(
     """
 
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
-    inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
-    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
+    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
+
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
+    # TODO generalize this
+    prompts_schema = next(iter(args_schema))
+    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
 
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
@@ -84,7 +98,11 @@ async def run_expert_backward(
 
 
 async def sequential_forward(
-    inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
+    inputs: torch.Tensor,
+    prompts: torch.Tensor,
+    sequence_manager: RemoteSequenceManager,
+    start_index: int = 0,
+    end_index: Optional[int] = None,
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
     Constructs a routing path from <start_index> to <end_index>.
@@ -96,6 +114,9 @@ async def sequential_forward(
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
+    assert is_dummy(prompts) or len(prompts) == len(
+        sequence_manager.block_uids
+    )  # should be n_layers - 1 but add extra prompts for convenience
 
     sequences = sequence_manager.make_sequence(start_index, end_index)
     intermediate_inputs = []
@@ -107,7 +128,9 @@ async def sequential_forward(
                 span = sequences.pop(0)
                 span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
+                inputs_and_prompts = [inputs, prompts[span.start : span.end]]
+
+                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
 
                 assert isinstance(outputs, torch.Tensor)
                 assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -119,7 +142,7 @@ async def sequential_forward(
                 inputs = outputs
                 break
             except Exception as e:
-                logging.debug(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
+                logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
                 backup_sequences = sequence_manager.make_sequence(span.start)
                 assert backup_sequences[0].start == span.start
                 sequences = backup_sequences
@@ -129,58 +152,68 @@ async def sequential_forward(
 
 async def sequential_backward(
     grad_outputs: Sequence[torch.Tensor],
-    intermediate_inputs: Sequence[torch.Tensor],
-    forward_sequences: Sequence[RemoteSpanInfo],
+    intermediate_inputs: List[torch.Tensor],
+    prompts: torch.Tensor,
+    forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
 ) -> Sequence[torch.Tensor]:
     """
     Performs chained backward for each forward subsequence.
     If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
     """
-
     assert len(intermediate_inputs) == len(forward_sequences)
-    # TODO think about grads w.r.t. deep prompts
 
+    grad_prompts_reversed = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
         while True:
+            inputs = intermediate_inputs.pop(-1)
+            span = forward_sequences.pop(-1)
+            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
             try:
-                inputs = intermediate_inputs.pop(-1)
-                span = forward_sequences.pop(-1)
-
-                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-
-                grad_outputs = await run_expert_backward(
-                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
+                grad_outputs, *span_grad_prompts = await run_expert_backward(
+                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
                 )
+                grad_outputs = [grad_outputs]
+                grad_prompts_reversed.extend(span_grad_prompts)
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
                 _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
-                    inputs, sequence_manager, start_index=span.start, end_index=span.end
+                    inputs, prompts[span.start : span.end], sequence_manager, start_index=span.start, end_index=span.end
                 )
-
                 assert len(intermediate_inputs) == len(forward_sequences)
                 assert backup_forward_sequences[0].start == span.start
                 assert backup_forward_sequences[-1].end == span.end
 
                 forward_sequences.extend(backup_forward_sequences)
                 intermediate_inputs.extend(backup_intermediate_inputs)
-    return grad_outputs
+
+    # For now, we do not support mixed dummy and grad prompts
+    # Concat in num_layer dimension
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
+    return grad_outputs, grad_prompts
 
 
-async def _gather_forward(input_batches, sequence_manager):
+async def _gather_forward(input_batches, prompt_batches, sequence_manager):
     """Wrapper for asyncio.gather to perform parallel sequential forwards"""
-    return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
+    return await asyncio.gather(
+        *[
+            sequential_forward(input_batch, prompt_batch, sequence_manager)
+            for input_batch, prompt_batch in zip(input_batches, prompt_batches)
+        ]
+    )
 
 
-async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
+async def _gather_backward(
+    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
+):
     """Wrapper for asyncio.gather to perform parallel sequential backwards"""
     return await asyncio.gather(
         *[
-            sequential_backward((grad_output,), input_batch, spans, sequence_manager)
-            for grad_output, input_batch, spans in zip(
-                grad_output_batches, intermediate_input_batches, forward_sequences
+            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
+            for grad_output, input_batch, prompt_batch, spans in zip(
+                grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
             )
         ]
     )
@@ -193,18 +226,23 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
+    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
+        if is_dummy(prompts):
+            prompt_batches = [DUMMY] * len(input_batches)
+        else:
+            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
 
         sequence_manager.rpc_info  # lazy init
-        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
+        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
         assert len(outputs) == len(input_batches)
 
         output_batches = [output[0] for output in outputs]
         intemediate_input_batches = [output[1] for output in outputs]
         sequences_for_batches = [output[2] for output in outputs]
 
+        ctx.prompt_batches = prompt_batches
         ctx.sequence_manager = sequence_manager
         ctx.intemediate_input_batches = intemediate_input_batches
         ctx.sequences_for_batches = sequences_for_batches
@@ -220,9 +258,19 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
         grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
         assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)
 
-        grad_input_batches = RemoteExpertWorker.run_coroutine(
-            _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
+        outputs = RemoteExpertWorker.run_coroutine(
+            _gather_backward(
+                grad_output_batches,
+                intermediate_input_batches,
+                ctx.prompt_batches,
+                forward_sequences,
+                ctx.sequence_manager,
+            )
         )
         )
-        grad_inputs = torch.cat(grad_inputs, dim=0)
-        return (grad_inputs, None)
+        grad_input_batches = [output[0][0] for output in outputs]
+        grad_prompt_batches = [output[1] for output in outputs]
+
+        grad_inputs = torch.cat(grad_input_batches, dim=0)
+        dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
+        grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
+        return (grad_inputs, grad_prompts, None)

+ 129 - 92
src/server/handler.py

@@ -1,8 +1,16 @@
 import contextlib
-from typing import AsyncIterator, Dict, Sequence
+from typing import AsyncIterator, Dict, List, Optional, Sequence, Union
 
 import torch
-from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
+from hivemind import (
+    DHT,
+    MSGPackSerializer,
+    P2PContext,
+    TensorDescriptor,
+    deserialize_torch_tensor,
+    nested_flatten,
+    serialize_torch_tensor,
+)
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
@@ -12,6 +20,7 @@ from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
+from src.utils.misc import DUMMY, is_dummy
 
 
 class TransformerConnectionHandler(ConnectionHandler):
@@ -33,7 +42,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         try:
             print("OPENED RPC_INFERENCE")
             request = await anext(requests)
-            requested_uids = self._check_header(request)
+            requested_uids = self._check_uids(request.uid)
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
             batch_size = request.tensors[0].size[0] if request.tensors else 1
@@ -80,27 +89,18 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends
-        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
+        flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs to backend dtype
-        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
-        # Run a chain of requested backends
-        for backend in requested_backends:
-            assert isinstance(hidden_states, (list, tuple))
-            assert (
-                len(hidden_states) == 1 and hidden_states[0].ndim == 3
-            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
-
-        # Serialize the overall output and respond
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+        # Serialize output and respond to client
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
             ]
         )
 
@@ -108,29 +108,20 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         # Parse requests and prepare backends
-        uids_header, hidden_states = await self._gather_inputs(requests, context)
-        requested_uids = self._check_header_str(uids_header)
+        uid_str, flat_inputs = await self._gather_inputs(requests, context)
+        requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs to backend dtype
-        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
-
-        # Run a chain of requested backends
-        for backend in requested_backends:
-            assert isinstance(hidden_states, (list, tuple))
-            assert (
-                len(hidden_states) == 1 and hidden_states[0].ndim == 3
-            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
+        hidden_states = await _rpc_forward(flat_inputs, requested_backends)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
         # Serialize the overall output
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         serialized_output = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
         ]
 
-        # Split the serialized_output for streaming and respond
+        # Split the serialized_output for streaming and respond to client
         output_split = [
             part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
@@ -139,36 +130,25 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse requests and prepare backends
-        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
+        flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs & grad outputs to backend dtype
-        inputs = inputs.to(requested_backends[0].dtype)
-        grads = grads.to(requested_backends[-1].dtype)
-
-        # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its output
-        inter_inputs = [inputs]
-        for backend in requested_backends[:-1]:
-            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-            inputs = await backend.forward_pool.submit_task(inputs)
-            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-            inputs = inputs[0]
-            inter_inputs.append(inputs)
-
-        # Run a chain of requested backends
-        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
-            inputs_and_grads = [inp, grads]
-            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert isinstance(grads, (list, tuple)) and len(grads) == 1
-            grads = grads[0]
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
+
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )  # TODO generalize
 
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
             ]
         )
 
@@ -176,36 +156,23 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
 
-        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
-        inputs, grads = inputs_and_grads
-        requested_uids = self._check_header_str(uids_header)
+        uids_header, flat_tensors = await self._gather_inputs(requests, context)
+        requested_uids = self._check_uids(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs & grad outputs to backend dtype
-        inputs = inputs.to(requested_backends[0].dtype)
-        grads = grads.to(requested_backends[-1].dtype)
-
-        # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its outputs
-        inter_inputs = [inputs]
-        for backend in requested_backends[:-1]:
-            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-            inputs = await backend.forward_pool.submit_task(inputs)
-            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-            inputs = inputs[0]
-            inter_inputs.append(inputs)
-
-        # Run a backward chain for requested backends
-        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
-            inputs_and_grads = [inp, grads]
-            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert isinstance(grads, (list, tuple)) and len(grads) == 1
-            grads = grads[0]
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
+
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )  # TODO generalize
 
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
         ]
         # Split the serialized_grad_inputs for streaming and respond
         output_split = [
@@ -215,19 +182,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         async for part in as_aiter(*output_split):
             yield runtime_pb2.ExpertResponse(tensors=[part])
 
-    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
+    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
         """Check that the first request to rpc_inference is valid"""
-        uids = (request.uid or "").split(CHAIN_DELIMITER)
-        if not uids:
-            raise RuntimeError("User did not provide any uids")
-        for uid in uids:
-            if uid not in self.module_backends:
-                raise RuntimeError(f"Remote peer does not serve {uid}")
-        return tuple(uids)
-
-    def _check_header_str(self, header) -> Sequence[ModuleUID]:
-        """Check that the first request to rpc_inference is valid"""
-        uids = (header or "").split(CHAIN_DELIMITER)
+        uids = (uids or "").split(CHAIN_DELIMITER)
         if not uids:
             raise RuntimeError("User did not provide any uids")
         for uid in uids:
@@ -252,3 +209,83 @@ class TransformerConnectionHandler(ConnectionHandler):
                 handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
 
             yield handles
+
+
+async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor:
+    """
+    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
+
+    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
+    """
+    hidden_states, *prompts = flat_tensors
+    dtype = requested_backends[0].dtype
+    # check parse input tensors and cast dtypes
+    hidden_states = hidden_states.to(dtype)
+    assert hidden_states.ndim == 3
+    if not prompts or is_dummy(prompts[0]):
+        prompts = [DUMMY] * len(requested_backends)
+        pre_seq_len = 0
+    else:
+        prompts = [prompts[0].to(requested_backends[0].dtype)]
+        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        pre_seq_len = prompts[0].shape[-2]
+
+    # Run a chain of requested backends
+    for backend, prompt in zip(requested_backends, prompts):
+        if not is_dummy(prompt):
+            hidden_states[:, :pre_seq_len] += prompt
+        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
+        assert isinstance(hidden_states, torch.Tensor)
+        assert (
+            hidden_states.ndim == 3
+        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+
+    # Serialize the overall output
+    return hidden_states
+
+
+async def _rpc_backward(
+    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
+) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
+    inputs, grad_outputs, *prompts = flat_tensors
+    # Cast inputs & grad outputs to backend dtype
+    inputs = inputs.to(requested_backends[0].dtype)
+    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+
+    if not prompts or is_dummy(prompts[0]):
+        prompts = [DUMMY] * len(requested_backends)
+        pre_seq_len = 0
+    else:
+        prompts = [prompts[0].to(requested_backends[0].dtype)]
+        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        pre_seq_len = prompts[0].shape[-2]
+
+    # Run a forward chain to collect intermediate inputs
+    # Note that we do not forward for the last module since we do not need its output
+    inter_inputs = []
+    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
+        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+        if not is_dummy(prompt):
+            inputs[:, :pre_seq_len] += prompt
+        inter_inputs.append(inputs)
+        (inputs,) = await backend.forward_pool.submit_task(inputs)
+        assert isinstance(inputs, torch.Tensor)
+
+    if not is_dummy(prompts[-1]):
+        inputs[:, :pre_seq_len] += prompts[-1]
+    inter_inputs.append(inputs)
+
+    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
+    grad_prompts_reversed = []
+    # Run a chain of requested backends
+    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
+        assert isinstance(grad_outputs, torch.Tensor)
+        if not is_dummy(prompt):
+            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
+    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape

+ 7 - 0
src/utils/misc.py

@@ -0,0 +1,7 @@
+import torch
+
+DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
+
+
+def is_dummy(tensor: torch.Tensor):
+    return tensor.numel() == 0
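
For clarity (not part of the commit): DUMMY is a zero-element sentinel that stands in for absent prompts so the RPC schemas only ever carry tensors, and is_dummy is how both client and server detect it.

import torch
from src.utils.misc import DUMMY, is_dummy

assert is_dummy(DUMMY)  # the "no prompts" placeholder has zero elements
assert not is_dummy(torch.zeros(1, 3, 16))  # any real prompt tensor is non-empty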

+ 46 - 0
tests/test_remote_sequential.py

@@ -4,6 +4,7 @@ from hivemind import DHT, get_logger, use_hivemind_log_handler
 from test_utils import *
 
 from src import RemoteSequential
+from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_model import DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")
@@ -41,3 +42,48 @@ def test_remote_sequential():
 
     (second_half_outputs * grad_proj).sum().backward()
     assert torch.allclose(test_inputs.grad, full_grad)
+
+
+@pytest.mark.forked
+def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
+    remote_sequential = RemoteSequential(config, dht)
+
+    inputs = torch.randn(batch_size, seq_len, config.hidden_size)
+    output_proj = torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size)
+    input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
+    intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
+
+    input_prompts = input_prompts.detach().requires_grad_(True)
+    intermediate_prompts = intermediate_prompts.detach().requires_grad_(True)
+
+    inputs_with_prompts = torch.cat([inputs, input_prompts], dim=1)
+    assert inputs_with_prompts.shape == (batch_size, seq_len + pre_seq_len, config.hidden_size)
+
+    outputs = remote_sequential(inputs_with_prompts, prompts=intermediate_prompts)
+
+    (outputs * output_proj).sum().backward()
+    assert intermediate_prompts.grad is not None
+
+    input_prompts_ref = input_prompts.clone().detach().requires_grad_(True)
+    intermediate_prompts_ref = intermediate_prompts.clone().detach().requires_grad_(True)
+
+    assert input_prompts_ref.grad is None
+    assert intermediate_prompts_ref.grad is None
+
+    outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1)
+    for block_index in range(config.n_layer):
+        block_prompt = intermediate_prompts_ref[block_index]
+        outputs_ref[:, : block_prompt.shape[1]] += block_prompt
+
+        block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
+        (outputs_ref,) = block(outputs_ref)
+
+    assert torch.allclose(outputs_ref, outputs)
+
+    (outputs_ref * output_proj).sum().backward()
+    assert input_prompts_ref.grad is not None
+    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
+    assert intermediate_prompts_ref.grad is not None
+    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)