Quellcode durchsuchen

Deep distributed prompt tuning (#42)

* implemented an option to add learnable prompts to intermediate layers
* added support for prompts (as input) in rpc_forward and rpc_backward
* added a test to check that RemoteSequential works correctly with deep prompts

Co-authored-by: justheuristic <justheuristic@gmail.com>
Dmitry Baranchuk vor 3 Jahren
Ursprung
Commit
6095f58681

+ 51 - 61
src/client/remote_model.py

@@ -1,5 +1,5 @@
 # this code is in active development, interfaces may change
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 import hivemind
 import torch
@@ -17,6 +17,7 @@ from src.bloom.model import (
 )
 from src.client.remote_generation import RemoteGenerationMixin
 from src.client.remote_sequential import RemoteSequential
+from src.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -33,6 +34,7 @@ class DistributedBloomConfig(BloomConfig):
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
+    tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
 
 
 class DistributedBloomModel(BloomModel):
@@ -60,10 +62,41 @@ class DistributedBloomModel(BloomModel):
         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)
 
+        if config.tuning_mode and "ptune" in config.tuning_mode:
+            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
+            self.pre_seq_len = config.pre_seq_len
+            self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
+            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+
+            if config.tuning_mode == "deep_ptune":
+                self.intermediate_prompt_embeddings = nn.Embedding(
+                    self.pre_seq_len,
+                    config.num_hidden_layers * config.hidden_size
+                    # ^-- TODO: should be num_hidden_layers - 1
+                )
+                self.intermediate_prompt_embeddings.weight.data.zero_()
+        elif config.tuning_mode:
+            raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
+
     def set_requires_grad(self, value):
         for p in self.parameters():
             p.requires_grad = value
 
+    def get_prompt(self, batch_size):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
+        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
+        prompts = self.prompt_embeddings(prefix_tokens)
+
+        if self.config.tuning_mode == "deep_ptune":
+            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
+            intermediate_prompts = intermediate_prompts.view(
+                batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size  # TODO: should be len(self.h) - 1
+            )
+            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
+        else:
+            intermediate_prompts = DUMMY
+        return prompts, intermediate_prompts
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -90,10 +123,22 @@ class DistributedBloomModel(BloomModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        # Note: it supports only float32 or bfloat16 inputs
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
         output_shape = input_shape + (hidden_states.size(-1),)
-        hidden_states = self.h(hidden_states)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
+        else:
+            hidden_states = self.h(hidden_states)
+
+        # Remove prefix
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)
@@ -106,55 +151,6 @@ class DistributedBloomModel(BloomModel):
         )
 
 
-class DistributedBloomPrefix(DistributedBloomModel):
-    """DistributedBloomModel with prefix tokens for prompt tuning"""
-
-    def __init__(self, config):
-        super().__init__(config)
-        assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
-        self.pre_seq_len = config.pre_seq_len
-
-        self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
-        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
-
-    def get_prompt(self, batch_size):
-        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
-        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
-        prompts = self.prompt_embeddings(prefix_tokens)
-        return prompts
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
-    ):
-        assert (
-            input_ids is None or inputs_embeds is None
-        ), "You cannot specify both input_ids and inputs_embeds at the same time"
-        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
-
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
-
-        batch_size = inputs_embeds.shape[0]
-
-        if attention_mask is not None:
-            prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
-            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
-
-        prompts = self.get_prompt(batch_size)
-        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
-
-        transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
-
-        # Remove prefix
-        last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
-        transformer_outputs["last_hidden_state"] = last_hidden_state
-        return transformer_outputs
-
-
 class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
@@ -162,10 +158,7 @@ class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
 
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
-        if config.pre_seq_len > 0:
-            self.transformer = DistributedBloomPrefix(config)
-        else:
-            self.transformer = DistributedBloomModel(config)
+        self.transformer = DistributedBloomModel(config)
         self.lm_head = LMHead(config, self.transformer.word_embeddings)
 
         # Initialize weights and apply final processing
@@ -195,10 +188,7 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
 
     def __init__(self, config: DistributedBloomConfig):
         super().__init__(config)
-        if config.pre_seq_len > 0:
-            self.transformer = DistributedBloomPrefix(config)
-        else:
-            self.transformer = DistributedBloomModel(config)
+        self.transformer = DistributedBloomModel(config)
 
         # Initialize weights and apply final processing
         self.post_init()

+ 3 - 2
src/client/remote_sequential.py

@@ -15,6 +15,7 @@ from src.client.sequence_manager import RemoteSequenceManager
 from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
+from src.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -52,8 +53,8 @@ class RemoteSequential(nn.Module):
             assert isinstance(sequence_manager.block_uids, list)
             self.is_subsequence = self.sequence_manager.block_uids != block_uids
 
-    def forward(self, inputs: torch.Tensor):
-        outputs = _RemoteSequentialAutogradFunction.apply(inputs, self.sequence_manager)
+    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
+        outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
         return outputs
 
     def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:

+ 83 - 35
src/client/sequential_autograd.py

@@ -12,6 +12,7 @@ from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
+from src.utils.misc import DUMMY, is_dummy
 
 MAX_TOKENS_IN_BATCH = 1024
 
@@ -33,7 +34,13 @@ async def run_expert_forward(
     # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
     forward_inputs = (inputs, kwargs)
 
-    if not nested_compare(forward_inputs, rpc_info["forward_schema"]):
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    # TODO: rm this assert when support arbitrary number of input tensors
+    assert len(args_schema) == 1 and len(inputs) == 2
+    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+
+    if not nested_compare(forward_inputs, forward_schema_with_prompts):
         raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
     forward_inputs = nested_flatten(forward_inputs)
@@ -44,7 +51,7 @@ async def run_expert_forward(
     serialized_tensors = await asyncio.gather(
         *(
             loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-            for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
+            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
         )
     )
 
@@ -57,8 +64,9 @@ async def run_expert_backward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
-    intemediate_inputs: List[torch.Tensor],
+    inputs: torch.Tensor,
     grad_outputs: List[torch.Tensor],
+    *extra_tensors: torch.Tensor,
 ) -> Sequence[torch.Tensor]:
     """
     Serializes grad outputs and calls "expert_backward".
@@ -67,8 +75,14 @@ async def run_expert_backward(
     """
 
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
-    inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
-    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
+    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
+
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
+    # TODO generalize this
+    prompts_schema = next(iter(args_schema))
+    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
 
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
@@ -84,7 +98,11 @@ async def run_expert_backward(
 
 
 async def sequential_forward(
-    inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
+    inputs: torch.Tensor,
+    prompts: torch.Tensor,
+    sequence_manager: RemoteSequenceManager,
+    start_index: int = 0,
+    end_index: Optional[int] = None,
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
     Constructs a routing path from <start_index> to <end_index>.
@@ -96,6 +114,9 @@ async def sequential_forward(
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
+    assert is_dummy(prompts) or len(prompts) == len(
+        sequence_manager.block_uids
+    )  # should be n_layers - 1 but add extra prompts for convenience
 
     sequences = sequence_manager.make_sequence(start_index, end_index)
     intermediate_inputs = []
@@ -107,7 +128,9 @@ async def sequential_forward(
                 span = sequences.pop(0)
                 span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
+                inputs_and_prompts = [inputs, prompts[span.start : span.end]]
+
+                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
 
                 assert isinstance(outputs, torch.Tensor)
                 assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -119,7 +142,7 @@ async def sequential_forward(
                 inputs = outputs
                 break
             except Exception as e:
-                logging.debug(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
+                logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
                 backup_sequences = sequence_manager.make_sequence(span.start)
                 assert backup_sequences[0].start == span.start
                 sequences = backup_sequences
@@ -129,58 +152,68 @@ async def sequential_forward(
 
 async def sequential_backward(
     grad_outputs: Sequence[torch.Tensor],
-    intermediate_inputs: Sequence[torch.Tensor],
-    forward_sequences: Sequence[RemoteSpanInfo],
+    intermediate_inputs: List[torch.Tensor],
+    prompts: torch.Tensor,
+    forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
 ) -> Sequence[torch.Tensor]:
     """
     Performs chained backward for each forward subsequence.
     If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
     """
-
     assert len(intermediate_inputs) == len(forward_sequences)
-    # TODO think about grads w.r.t. deep prompts
 
+    grad_prompts_reversed = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
         while True:
+            inputs = intermediate_inputs.pop(-1)
+            span = forward_sequences.pop(-1)
+            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
             try:
-                inputs = intermediate_inputs.pop(-1)
-                span = forward_sequences.pop(-1)
-
-                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-
-                grad_outputs = await run_expert_backward(
-                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
+                grad_outputs, *span_grad_prompts = await run_expert_backward(
+                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
                 )
+                grad_outputs = [grad_outputs]
+                grad_prompts_reversed.extend(span_grad_prompts)
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
                 _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
-                    inputs, sequence_manager, start_index=span.start, end_index=span.end
+                    inputs, prompts[span.start : span.end], sequence_manager, start_index=span.start, end_index=span.end
                 )
-
                 assert len(intermediate_inputs) == len(forward_sequences)
                 assert backup_forward_sequences[0].start == span.start
                 assert backup_forward_sequences[-1].end == span.end
 
                 forward_sequences.extend(backup_forward_sequences)
                 intermediate_inputs.extend(backup_intermediate_inputs)
-    return grad_outputs
+
+    # For now, we do not support mixed dummy and grad prompts
+    # Concat in num_layer dimension
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
+    return grad_outputs, grad_prompts
 
 
-async def _gather_forward(input_batches, sequence_manager):
+async def _gather_forward(input_batches, prompt_batches, sequence_manager):
     """Wrapper for asyncio.gather to perform parallel sequential forwards"""
-    return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
+    return await asyncio.gather(
+        *[
+            sequential_forward(input_batch, prompt_batch, sequence_manager)
+            for input_batch, prompt_batch in zip(input_batches, prompt_batches)
+        ]
+    )
 
 
-async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
+async def _gather_backward(
+    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
+):
     """Wrapper for asyncio.gather to perform parallel sequential backwards"""
     return await asyncio.gather(
         *[
-            sequential_backward((grad_output,), input_batch, spans, sequence_manager)
-            for grad_output, input_batch, spans in zip(
-                grad_output_batches, intermediate_input_batches, forward_sequences
+            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
+            for grad_output, input_batch, prompt_batch, spans in zip(
+                grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
             )
         ]
     )
@@ -193,18 +226,23 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
+    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
+        if is_dummy(prompts):
+            prompt_batches = [DUMMY] * len(input_batches)
+        else:
+            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
 
         sequence_manager.rpc_info  # lazy init
-        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
+        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
         assert len(outputs) == len(input_batches)
 
         output_batches = [output[0] for output in outputs]
         intemediate_input_batches = [output[1] for output in outputs]
         sequences_for_batches = [output[2] for output in outputs]
 
+        ctx.prompt_batches = prompt_batches
         ctx.sequence_manager = sequence_manager
         ctx.intemediate_input_batches = intemediate_input_batches
         ctx.sequences_for_batches = sequences_for_batches
@@ -220,9 +258,19 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
         grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
         assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)
 
-        grad_input_batches = RemoteExpertWorker.run_coroutine(
-            _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
+        outputs = RemoteExpertWorker.run_coroutine(
+            _gather_backward(
+                grad_output_batches,
+                intermediate_input_batches,
+                ctx.prompt_batches,
+                forward_sequences,
+                ctx.sequence_manager,
+            )
         )
-        grad_inputs = [grad_input_batch[0] for grad_input_batch in grad_input_batches]
-        grad_inputs = torch.cat(grad_inputs, dim=0)
-        return (grad_inputs, None)
+        grad_input_batches = [output[0][0] for output in outputs]
+        grad_prompt_batches = [output[1] for output in outputs]
+
+        grad_inputs = torch.cat(grad_input_batches, dim=0)
+        dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
+        grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
+        return (grad_inputs, grad_prompts, None)

+ 129 - 92
src/server/handler.py

@@ -1,8 +1,16 @@
 import contextlib
-from typing import AsyncIterator, Dict, Sequence
+from typing import AsyncIterator, Dict, List, Optional, Sequence, Union
 
 import torch
-from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
+from hivemind import (
+    DHT,
+    MSGPackSerializer,
+    P2PContext,
+    TensorDescriptor,
+    deserialize_torch_tensor,
+    nested_flatten,
+    serialize_torch_tensor,
+)
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
@@ -12,6 +20,7 @@ from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
+from src.utils.misc import DUMMY, is_dummy
 
 
 class TransformerConnectionHandler(ConnectionHandler):
@@ -33,7 +42,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         try:
             print("OPENED RPC_INFERENCE")
             request = await anext(requests)
-            requested_uids = self._check_header(request)
+            requested_uids = self._check_uids(request.uid)
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
             batch_size = request.tensors[0].size[0] if request.tensors else 1
@@ -80,27 +89,18 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends
-        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
+        flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs to backend dtype
-        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
-        # Run a chain of requested backends
-        for backend in requested_backends:
-            assert isinstance(hidden_states, (list, tuple))
-            assert (
-                len(hidden_states) == 1 and hidden_states[0].ndim == 3
-            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
-
-        # Serialize the overall output and respond
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+        # Serialize output and respond to client
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
             ]
         )
 
@@ -108,29 +108,20 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         # Parse requests and prepare backends
-        uids_header, hidden_states = await self._gather_inputs(requests, context)
-        requested_uids = self._check_header_str(uids_header)
+        uid_str, flat_inputs = await self._gather_inputs(requests, context)
+        requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs to backend dtype
-        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
-
-        # Run a chain of requested backends
-        for backend in requested_backends:
-            assert isinstance(hidden_states, (list, tuple))
-            assert (
-                len(hidden_states) == 1 and hidden_states[0].ndim == 3
-            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
+        hidden_states = await _rpc_forward(flat_inputs, requested_backends)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
         # Serialize the overall output
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         serialized_output = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
         ]
 
-        # Split the serialized_output for streaming and respond
+        # Split the serialized_output for streaming and respond to client
         output_split = [
             part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
@@ -139,36 +130,25 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse requests and prepare backends
-        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
+        flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs & grad outputs to backend dtype
-        inputs = inputs.to(requested_backends[0].dtype)
-        grads = grads.to(requested_backends[-1].dtype)
-
-        # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its output
-        inter_inputs = [inputs]
-        for backend in requested_backends[:-1]:
-            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-            inputs = await backend.forward_pool.submit_task(inputs)
-            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-            inputs = inputs[0]
-            inter_inputs.append(inputs)
-
-        # Run a chain of requested backends
-        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
-            inputs_and_grads = [inp, grads]
-            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert isinstance(grads, (list, tuple)) and len(grads) == 1
-            grads = grads[0]
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
+
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )  # TODO generalize
 
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
             ]
         )
 
@@ -176,36 +156,23 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
 
-        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
-        inputs, grads = inputs_and_grads
-        requested_uids = self._check_header_str(uids_header)
+        uids_header, flat_tensors = await self._gather_inputs(requests, context)
+        requested_uids = self._check_uids(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs & grad outputs to backend dtype
-        inputs = inputs.to(requested_backends[0].dtype)
-        grads = grads.to(requested_backends[-1].dtype)
-
-        # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its outputs
-        inter_inputs = [inputs]
-        for backend in requested_backends[:-1]:
-            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-            inputs = await backend.forward_pool.submit_task(inputs)
-            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-            inputs = inputs[0]
-            inter_inputs.append(inputs)
-
-        # Run a backward chain for requested backends
-        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
-            inputs_and_grads = [inp, grads]
-            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert isinstance(grads, (list, tuple)) and len(grads) == 1
-            grads = grads[0]
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
+
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )  # TODO generalize
 
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
         ]
         # Split the serialized_grad_inputs for streaming and respond
         output_split = [
@@ -215,19 +182,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         async for part in as_aiter(*output_split):
             yield runtime_pb2.ExpertResponse(tensors=[part])
 
-    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
+    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
         """Check that the first request to rpc_inference is valid"""
-        uids = (request.uid or "").split(CHAIN_DELIMITER)
-        if not uids:
-            raise RuntimeError("User did not provide any uids")
-        for uid in uids:
-            if uid not in self.module_backends:
-                raise RuntimeError(f"Remote peer does not serve {uid}")
-        return tuple(uids)
-
-    def _check_header_str(self, header) -> Sequence[ModuleUID]:
-        """Check that the first request to rpc_inference is valid"""
-        uids = (header or "").split(CHAIN_DELIMITER)
+        uids = (uids or "").split(CHAIN_DELIMITER)
         if not uids:
             raise RuntimeError("User did not provide any uids")
         for uid in uids:
@@ -252,3 +209,83 @@ class TransformerConnectionHandler(ConnectionHandler):
                 handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
 
             yield handles
+
+
+async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor:
+    """
+    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
+
+    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
+    """
+    hidden_states, *prompts = flat_tensors
+    dtype = requested_backends[0].dtype
+    # check parse input tensors and cast dtypes
+    hidden_states = hidden_states.to(dtype)
+    assert hidden_states.ndim == 3
+    if not prompts or is_dummy(prompts[0]):
+        prompts = [DUMMY] * len(requested_backends)
+        pre_seq_len = 0
+    else:
+        prompts = [prompts[0].to(requested_backends[0].dtype)]
+        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        pre_seq_len = prompts[0].shape[-2]
+
+    # Run a chain of requested backends
+    for backend, prompt in zip(requested_backends, prompts):
+        if not is_dummy(prompt):
+            hidden_states[:, :pre_seq_len] += prompt
+        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
+        assert isinstance(hidden_states, torch.Tensor)
+        assert (
+            hidden_states.ndim == 3
+        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+
+    # Serialize the overall output
+    return hidden_states
+
+
+async def _rpc_backward(
+    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
+) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
+    inputs, grad_outputs, *prompts = flat_tensors
+    # Cast inputs & grad outputs to backend dtype
+    inputs = inputs.to(requested_backends[0].dtype)
+    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+
+    if not prompts or is_dummy(prompts[0]):
+        prompts = [DUMMY] * len(requested_backends)
+        pre_seq_len = 0
+    else:
+        prompts = [prompts[0].to(requested_backends[0].dtype)]
+        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        pre_seq_len = prompts[0].shape[-2]
+
+    # Run a forward chain to collect intermediate inputs
+    # Note that we do not forward for the last module since we do not need its output
+    inter_inputs = []
+    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
+        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+        if not is_dummy(prompt):
+            inputs[:, :pre_seq_len] += prompt
+        inter_inputs.append(inputs)
+        (inputs,) = await backend.forward_pool.submit_task(inputs)
+        assert isinstance(inputs, torch.Tensor)
+
+    if not is_dummy(prompts[-1]):
+        inputs[:, :pre_seq_len] += prompts[-1]
+    inter_inputs.append(inputs)
+
+    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
+    grad_prompts_reversed = []
+    # Run a chain of requested backends
+    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
+        assert isinstance(grad_outputs, torch.Tensor)
+        if not is_dummy(prompt):
+            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
+    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape

+ 7 - 0
src/utils/misc.py

@@ -0,0 +1,7 @@
+import torch
+
+DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
+
+
+def is_dummy(tensor: torch.Tensor):
+    return tensor.numel() == 0

+ 46 - 0
tests/test_remote_sequential.py

@@ -4,6 +4,7 @@ from hivemind import DHT, get_logger, use_hivemind_log_handler
 from test_utils import *
 
 from src import RemoteSequential
+from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_model import DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")
@@ -41,3 +42,48 @@ def test_remote_sequential():
 
     (second_half_outputs * grad_proj).sum().backward()
     assert torch.allclose(test_inputs.grad, full_grad)
+
+
+@pytest.mark.forked
+def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
+    remote_sequential = RemoteSequential(config, dht)
+
+    inputs = torch.randn(batch_size, seq_len, config.hidden_size)
+    output_proj = torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size)
+    input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
+    intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
+
+    input_prompts = input_prompts.detach().requires_grad_(True)
+    intermediate_prompts = intermediate_prompts.detach().requires_grad_(True)
+
+    inputs_with_prompts = torch.cat([inputs, input_prompts], dim=1)
+    assert inputs_with_prompts.shape == (batch_size, seq_len + pre_seq_len, config.hidden_size)
+
+    outputs = remote_sequential(inputs_with_prompts, prompts=intermediate_prompts)
+
+    (outputs * output_proj).sum().backward()
+    assert intermediate_prompts.grad is not None
+
+    input_prompts_ref = input_prompts.clone().detach().requires_grad_(True)
+    intermediate_prompts_ref = intermediate_prompts.clone().detach().requires_grad_(True)
+
+    assert input_prompts_ref.grad is None
+    assert intermediate_prompts_ref.grad is None
+
+    outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1)
+    for block_index in range(config.n_layer):
+        block_prompt = intermediate_prompts_ref[block_index]
+        outputs_ref[:, : block_prompt.shape[1]] += block_prompt
+
+        block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
+        (outputs_ref,) = block(outputs_ref)
+
+    assert torch.allclose(outputs_ref, outputs)
+
+    (outputs_ref * output_proj).sum().backward()
+    assert input_prompts_ref.grad is not None
+    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
+    assert intermediate_prompts_ref.grad is not None
+    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)