distributed deep & shallow ptune

dbaranchuk 3 years ago
parent
commit
d87b8d69b6

+ 6 - 0
src/bloom/block.py

@@ -18,6 +18,7 @@ from src.bloom.ops import (
     pre_process_alibi_for_pad,
     split_tensor_along_last_dim,
 )
+from src.utils.misc import DUMMY, is_dummy
 
 
 class BloomAttention(nn.Module):
@@ -202,6 +203,7 @@ class BloomBlock(nn.Module):
     def forward(
         self,
         hidden_states,
+        prompts=DUMMY,
         layer_past=None,
         attention_mask=None,
         head_mask=None,
@@ -247,6 +249,10 @@ class BloomBlock(nn.Module):
         # MLP.
         output = self.mlp(layernorm_output, residual)
 
+        if not is_dummy(prompts):
+            pre_seq_len = prompts.shape[1]
+            output[:, :pre_seq_len] = output[:, :pre_seq_len] + prompts
+
         if use_cache:
             outputs = (output,) + outputs
         else:
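
The new lines in BloomBlock.forward add per-layer ("deep") prompts to the first pre_seq_len positions of the block output. Below is a minimal standalone sketch of just that step, with the DUMMY sentinel and is_dummy helper inlined for illustration (toy shapes, not the real model dimensions):

    import torch

    DUMMY = torch.empty(0)                          # same sentinel as src/utils/misc.py
    def is_dummy(tensor): return tensor.numel() == 0

    batch, seq_len, hidden = 2, 8, 16
    output = torch.randn(batch, seq_len, hidden)    # block output (hidden states)
    prompts = torch.randn(batch, 3, hidden)         # 3 trainable prefix vectors for this layer

    if not is_dummy(prompts):                       # DUMMY => plain forward, no prompt added
        pre_seq_len = prompts.shape[1]
        output[:, :pre_seq_len] = output[:, :pre_seq_len] + prompts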

+ 49 - 52
src/client/remote_model.py

@@ -17,6 +17,7 @@ from src.bloom.model import (
 )
 from src.client.remote_generation import RemoteGenerationMixin
 from src.client.remote_sequential import RemoteSequential
+from src.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -33,6 +34,9 @@ class DistributedBloomConfig(BloomConfig):
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
+    tuning_mode: Optional[
+        str
+    ] = None  # One of the available options for fine-tuning: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
 
 
 class DistributedBloomModel(BloomModel):
@@ -60,10 +64,40 @@ class DistributedBloomModel(BloomModel):
         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)
 
+        self.tuning_mode = config.tuning_mode
+        if self.tuning_mode and "ptune" in config.tuning_mode:
+            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
+            self.pre_seq_len = config.pre_seq_len
+            self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
+            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+
+            if config.tuning_mode == "deep_ptune":
+                self.intermediate_prompt_embeddings = nn.Embedding(
+                    self.pre_seq_len, (config.num_hidden_layers - 1) * config.hidden_size
+                )
+                self.intermediate_prompt_embeddings.weight.data.zero_()
+        elif self.tuning_mode:
+            raise NotImplementedError(f"TODO: {self.tuning_mode}")
+
     def set_requires_grad(self, value):
         for p in self.parameters():
             p.requires_grad = value
 
+    def get_prompt(self, batch_size):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
+        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
+        prompts = self.prompt_embeddings(prefix_tokens)
+
+        if self.tuning_mode == "deep_ptune":
+            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
+            intermediate_prompts = intermediate_prompts.view(
+                batch_size, self.pre_seq_len, len(self.h) - 1, self.hidden_size
+            )
+            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
+        else:
+            intermediate_prompts = DUMMY
+        return prompts, intermediate_prompts
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -90,10 +124,22 @@ class DistributedBloomModel(BloomModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        # Note: it supports only float32 or bfloat16 inputs
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        if self.tuning_mode and "ptune" in self.tuning_mode:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
         output_shape = input_shape + (hidden_states.size(-1),)
-        hidden_states = self.h(hidden_states)
+
+        if self.tuning_mode and "ptune" in self.tuning_mode:
+            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
+        else:
+            hidden_states = self.h(hidden_states)
+
+        # Remove prefix
+        if self.tuning_mode and "ptune" in self.tuning_mode:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)
@@ -106,55 +152,6 @@ class DistributedBloomModel(BloomModel):
         )
 
 
-class DistributedBloomPrefix(DistributedBloomModel):
-    """DistributedBloomModel with prefix tokens for prompt tuning"""
-
-    def __init__(self, config):
-        super().__init__(config)
-        assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
-        self.pre_seq_len = config.pre_seq_len
-
-        self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
-        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
-
-    def get_prompt(self, batch_size):
-        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
-        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
-        prompts = self.prompt_embeddings(prefix_tokens)
-        return prompts
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
-    ):
-        assert (
-            input_ids is None or inputs_embeds is None
-        ), "You cannot specify both input_ids and inputs_embeds at the same time"
-        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
-
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
-
-        batch_size = inputs_embeds.shape[0]
-
-        if attention_mask is not None:
-            prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
-            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
-
-        prompts = self.get_prompt(batch_size)
-        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
-
-        transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
-
-        # Remove prefix
-        last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
-        transformer_outputs["last_hidden_state"] = last_hidden_state
-        return transformer_outputs
-
-
 class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
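For reference, a small sketch of the tensor shapes produced by the new get_prompt() path under deep p-tuning (toy sizes; variable names mirror the code above, but this is a standalone illustration, not the repo classes):

    import torch
    import torch.nn as nn

    batch_size, pre_seq_len, hidden_size, num_hidden_layers = 2, 4, 16, 6

    prompt_embeddings = nn.Embedding(pre_seq_len, hidden_size)
    intermediate_prompt_embeddings = nn.Embedding(pre_seq_len, (num_hidden_layers - 1) * hidden_size)
    intermediate_prompt_embeddings.weight.data.zero_()          # deep prompts start as no-ops
    prefix_tokens = torch.arange(pre_seq_len).long()

    prefix = prefix_tokens.unsqueeze(0).expand(batch_size, -1)
    prompts = prompt_embeddings(prefix)                         # (batch, pre_seq_len, hidden)
    inter = intermediate_prompt_embeddings(prefix)
    inter = inter.view(batch_size, pre_seq_len, num_hidden_layers - 1, hidden_size)
    inter = inter.permute([2, 0, 1, 3])                         # (layers - 1, batch, pre_seq_len, hidden)
    print(prompts.shape, inter.shape)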

+ 3 - 2
src/client/remote_sequential.py

@@ -15,6 +15,7 @@ from src.client.sequence_manager import RemoteSequenceManager
 from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
+from src.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -52,8 +53,8 @@ class RemoteSequential(nn.Module):
             assert isinstance(sequence_manager.block_uids, list)
             self.is_subsequence = self.sequence_manager.block_uids != block_uids
 
-    def forward(self, inputs: torch.Tensor):
-        outputs = _RemoteSequentialAutogradFunction.apply(inputs, self.sequence_manager)
+    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
+        outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
         return outputs
 
     def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:

+ 62 - 25
src/client/sequential_autograd.py

@@ -12,6 +12,7 @@ from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
+from src.utils.misc import DUMMY, is_dummy
 
 MAX_TOKENS_IN_BATCH = 1024
 
@@ -57,7 +58,7 @@ async def run_expert_backward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
-    intemediate_inputs: List[torch.Tensor],
+    inputs: List[torch.Tensor],
     grad_outputs: List[torch.Tensor],
 ) -> Sequence[torch.Tensor]:
     """
@@ -67,7 +68,7 @@ async def run_expert_backward(
     """
 
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
-    inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
+    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu)))
     backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
 
     # Asynchronous serialization
@@ -84,7 +85,11 @@ async def run_expert_backward(
 
 
 async def sequential_forward(
-    inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
+    inputs: torch.Tensor,
+    prompts: torch.Tensor,
+    sequence_manager: RemoteSequenceManager,
+    start_index: int = 0,
+    end_index: Optional[int] = None,
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
     Constructs a routing path from <start_index> to <end_index>.
@@ -96,6 +101,8 @@ async def sequential_forward(
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
+    if not is_dummy(prompts):
+        assert len(prompts) == end_index - start_index + 1
 
     sequences = sequence_manager.make_sequence(start_index, end_index)
     intermediate_inputs = []
@@ -107,7 +114,9 @@ async def sequential_forward(
                 span = sequences.pop(0)
                 span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
+                inputs_and_prompts = [inputs, prompts[span.start : span.end]]
+
+                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
 
                 assert isinstance(outputs, torch.Tensor)
                 assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -119,7 +128,7 @@ async def sequential_forward(
                 inputs = outputs
                 break
             except Exception as e:
-                logging.debug(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
+                logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
                 backup_sequences = sequence_manager.make_sequence(span.start)
                 assert backup_sequences[0].start == span.start
                 sequences = backup_sequences
@@ -130,6 +139,7 @@ async def sequential_forward(
 async def sequential_backward(
     grad_outputs: Sequence[torch.Tensor],
     intermediate_inputs: Sequence[torch.Tensor],
+    prompts: Sequence[torch.Tensor],
     forward_sequences: Sequence[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
 ) -> Sequence[torch.Tensor]:
@@ -137,10 +147,9 @@ async def sequential_backward(
     Performs chained backward for each forward subsequence.
     If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
     """
-
     assert len(intermediate_inputs) == len(forward_sequences)
-    # TODO think about grads w.r.t. deep prompts
 
+    grad_prompts = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
         while True:
             try:
@@ -150,37 +159,50 @@ async def sequential_backward(
                 span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
 
-                grad_outputs = await run_expert_backward(
-                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
+                inputs_and_prompts = [inputs, prompts[span.start : span.end]]
+                grad_outputs, span_grad_prompts = await run_expert_backward(
+                    span_uids, stub, sequence_manager.rpc_info, inputs_and_prompts, grad_outputs
                 )
+                grad_prompts.append(span_grad_prompts)
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
                 _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
-                    inputs, sequence_manager, start_index=span.start, end_index=span.end
+                    inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
                 )
-
                 assert len(intermediate_inputs) == len(forward_sequences)
                 assert backup_forward_sequences[0].start == span.start
                 assert backup_forward_sequences[-1].end == span.end
 
                 forward_sequences.extend(backup_forward_sequences)
                 intermediate_inputs.extend(backup_intermediate_inputs)
-    return grad_outputs
+
+    dummy_grad_prompts = [is_dummy(grad_prompt) for grad_prompt in grad_prompts]
+    # For now, we do not support mixed dummy and grad prompts
+    # Concat in num_layer dimension
+    grad_prompts = torch.cat(grad_prompts, dim=0) if not any(dummy_grad_prompts) else None
+    return grad_outputs, grad_prompts
 
 
-async def _gather_forward(input_batches, sequence_manager):
+async def _gather_forward(input_batches, prompt_batches, sequence_manager):
     """Wrapper for asyncio.gather to perform parallel sequential forwards"""
-    return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
+    return await asyncio.gather(
+        *[
+            sequential_forward(input_batch, prompt_batch, sequence_manager)
+            for input_batch, prompt_batch in zip(input_batches, prompt_batches)
+        ]
+    )
 
 
-async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
+async def _gather_backward(
+    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
+):
     """Wrapper for asyncio.gather to perform parallel sequential backwards"""
     return await asyncio.gather(
         *[
-            sequential_backward((grad_output,), input_batch, spans, sequence_manager)
-            for grad_output, input_batch, spans in zip(
-                grad_output_batches, intermediate_input_batches, forward_sequences
+            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
+            for grad_output, input_batch, prompt_batch, spans in zip(
+                grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
             )
         ]
     )
@@ -193,18 +215,23 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
+    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
+        if is_dummy(prompts):
+            prompt_batches = [DUMMY] * len(input_batches)
+        else:
+            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
 
         sequence_manager.rpc_info  # lazy init
-        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
+        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
         assert len(outputs) == len(input_batches)
 
         output_batches = [output[0] for output in outputs]
         intemediate_input_batches = [output[1] for output in outputs]
         sequences_for_batches = [output[2] for output in outputs]
 
+        ctx.prompt_batches = prompt_batches
         ctx.sequence_manager = sequence_manager
         ctx.intemediate_input_batches = intemediate_input_batches
         ctx.sequences_for_batches = sequences_for_batches
@@ -220,9 +247,19 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
         grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
         assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)
 
-        grad_input_batches = RemoteExpertWorker.run_coroutine(
-            _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
+        outputs = RemoteExpertWorker.run_coroutine(
+            _gather_backward(
+                grad_output_batches,
+                intermediate_input_batches,
+                ctx.prompt_batches,
+                forward_sequences,
+                ctx.sequence_manager,
+            )
         )
-        grad_inputs = [grad_input_batch[0] for grad_input_batch in grad_input_batches]
-        grad_inputs = torch.cat(grad_inputs, dim=0)
-        return (grad_inputs, None)
+        grad_input_batches = [output[0] for output in outputs]
+        grad_prompt_batches = [output[1] for output in outputs]
+
+        grad_inputs = torch.cat(grad_input_batches, dim=0)
+        dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
+        grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
+        return (grad_inputs, grad_prompts, None)
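
A toy illustration of the batching convention used in this autograd function: hidden states are split along the batch dimension (dim 0), while deep prompts carry the batch in dim 1 and are therefore split and re-joined along dim 1 (no swarm involved; sizes are arbitrary):

    import torch

    num_blocks, batch, pre_seq_len, seq_len, hidden = 5, 6, 4, 10, 16
    inputs = torch.randn(batch, seq_len, hidden)
    prompts = torch.randn(num_blocks, batch, pre_seq_len, hidden)   # one prompt slice per block

    MAX_TOKENS_IN_BATCH = 1024
    batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
    input_batches = inputs.split(batch_size)                # split over dim 0 (batch)
    prompt_batches = prompts.split(batch_size, dim=1)       # split over dim 1 (batch)
    assert len(input_batches) == len(prompt_batches)

    # gradients w.r.t. prompts come back per mini-batch and are re-joined along the same dim
    grad_prompt_batches = [torch.zeros_like(p) for p in prompt_batches]
    grad_prompts = torch.cat(grad_prompt_batches, dim=1)
    assert grad_prompts.shape == prompts.shape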

+ 65 - 67
src/server/handler.py

@@ -12,6 +12,7 @@ from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
+from src.utils.misc import DUMMY, is_dummy
 
 
 class TransformerConnectionHandler(ConnectionHandler):
@@ -80,20 +81,11 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends
-        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs to backend dtype
-        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
-
-        # Run a chain of requested backends
-        for backend in requested_backends:
-            assert isinstance(hidden_states, (list, tuple))
-            assert (
-                len(hidden_states) == 1 and hidden_states[0].ndim == 3
-            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
+        hidden_states = await _rpc_forward(inputs, requested_backends)
 
         # Serialize the overall output and respond
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
@@ -108,20 +100,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         # Parse requests and prepare backends
-        uids_header, hidden_states = await self._gather_inputs(requests, context)
+        uids_header, inputs = await self._gather_inputs(requests, context)
         requested_uids = self._check_header_str(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs to backend dtype
-        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
-
-        # Run a chain of requested backends
-        for backend in requested_backends:
-            assert isinstance(hidden_states, (list, tuple))
-            assert (
-                len(hidden_states) == 1 and hidden_states[0].ndim == 3
-            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
+        hidden_states = await _rpc_forward(inputs, requested_backends)
 
         # Serialize the overall output
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
@@ -139,36 +122,17 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse requests and prepare backends
-        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        inputs, prompts, grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs & grad outputs to backend dtype
-        inputs = inputs.to(requested_backends[0].dtype)
-        grads = grads.to(requested_backends[-1].dtype)
-
-        # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its output
-        inter_inputs = [inputs]
-        for backend in requested_backends[:-1]:
-            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-            inputs = await backend.forward_pool.submit_task(inputs)
-            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-            inputs = inputs[0]
-            inter_inputs.append(inputs)
-
-        # Run a chain of requested backends
-        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
-            inputs_and_grads = [inp, grads]
-            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert isinstance(grads, (list, tuple)) and len(grads) == 1
-            grads = grads[0]
+        grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)
 
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+                for result, proto in zip(grads, nested_flatten(requested_backends[0].grad_inputs_schema))
             ]
         )
 
@@ -176,36 +140,16 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
 
-        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
-        inputs, grads = inputs_and_grads
+        uids_header, (inputs, prompts, grad_outputs) = await self._gather_inputs(requests, context)
         requested_uids = self._check_header_str(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs & grad outputs to backend dtype
-        inputs = inputs.to(requested_backends[0].dtype)
-        grads = grads.to(requested_backends[-1].dtype)
-
-        # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its outputs
-        inter_inputs = [inputs]
-        for backend in requested_backends[:-1]:
-            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-            inputs = await backend.forward_pool.submit_task(inputs)
-            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-            inputs = inputs[0]
-            inter_inputs.append(inputs)
-
-        # Run a backward chain for requested backends
-        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
-            inputs_and_grads = [inp, grads]
-            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert isinstance(grads, (list, tuple)) and len(grads) == 1
-            grads = grads[0]
+        grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)
 
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+            for result, proto in zip(grads, nested_flatten(requested_backends[0].grad_inputs_schema))
         ]
         # Split the serialized_grad_inputs for streaming and respond
         output_split = [
@@ -252,3 +196,57 @@ class TransformerConnectionHandler(ConnectionHandler):
                 handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
 
             yield handles
+
+
+async def _rpc_forward(inputs, requested_backends):
+    # Cast inputs to backend dtype
+    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in inputs]
+    assert len(hidden_states) == 2 and hidden_states[0].ndim == 3
+    hidden_states, prompts = hidden_states
+
+    if is_dummy(prompts):
+        prompts = [DUMMY] * len(requested_backends)
+
+    # Run a chain of requested backends
+    for backend, prompt in zip(requested_backends, prompts):
+        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, prompt)
+        assert isinstance(hidden_states, torch.Tensor)
+        assert (
+            hidden_states.ndim == 3
+        ), f"outputs of {type(backend)} must be a single 3d tensor of hidden states"
+
+    # Serialize the overall output
+    return [hidden_states]
+
+
+async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
+    # Cast inputs & grad outputs to backend dtype
+    inputs = inputs.to(requested_backends[0].dtype)
+    prompts = prompts.to(requested_backends[0].dtype)
+    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+
+    if is_dummy(prompts):
+        prompts = [DUMMY] * len(requested_backends)
+
+    # Run a forward chain to collect intermediate inputs
+    # Note that we do not forward for the last module since we do not need its output
+    inter_inputs = [inputs]
+    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
+        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+        inputs = await backend.forward_pool.submit_task(inputs, prompt)
+        assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
+        inputs = inputs[0]
+        inter_inputs.append(inputs)
+
+    grad_prompts = []
+    # Run a chain of requested backends
+    for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
+        grads = await backend.backward_pool.submit_task(inp, prompt, grad_outputs)
+        assert isinstance(grads, (list, tuple)) and len(grads) == 2
+        grad_outputs, grad_prompt = grads
+        grad_prompts.append(grad_prompt)
+
+    is_dummy_grad_prompts = [is_dummy(grad_param) for grad_param in grad_prompts]
+    grad_prompts = torch.cat(grad_prompts, dim=0) if not any(is_dummy_grad_prompts) else DUMMY
+    grads = [grad_outputs, grad_prompts]
+    return grads
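
The tail of _rpc_backward collapses the per-block prompt gradients into a single tensor, or into DUMMY if any block ran without prompts. A toy sketch of just that aggregation step (shapes are illustrative only, not prescribed by the code above):

    import torch

    DUMMY = torch.empty(0)
    def is_dummy(tensor): return tensor.numel() == 0

    batch, pre_seq_len, hidden = 2, 4, 16
    # one grad-prompt tensor per backend in the chain; DUMMY would mean "this block had no prompt"
    grad_prompts = [torch.randn(1, batch, pre_seq_len, hidden) for _ in range(3)]

    is_dummy_grad_prompts = [is_dummy(g) for g in grad_prompts]
    grad_prompts = torch.cat(grad_prompts, dim=0) if not any(is_dummy_grad_prompts) else DUMMY
    print(grad_prompts.shape)   # concatenated here; DUMMY if any entry had no prompt gradient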

+ 3 - 0
src/server/server.py

@@ -212,6 +212,9 @@ class Server(threading.Thread):
                     BatchTensorDescriptor(
                         1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
                     ),
+                    BatchTensorDescriptor(
+                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                    ),
                 ),
                 kwargs_schema={},
                 outputs_schema=(

+ 7 - 0
src/utils/misc.py

@@ -0,0 +1,7 @@
+import torch
+
+DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
+
+
+def is_dummy(tensor: torch.Tensor):
+    return tensor.numel() == 0
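
Quick sanity check of the new helpers (assumes the repo root is on PYTHONPATH): DUMMY is an ordinary zero-element tensor, so it travels through the existing tensor-serialization path, and is_dummy() lets both client and server treat "no prompts / no adapters" as a cheap no-op:

    import torch
    from src.utils.misc import DUMMY, is_dummy

    assert is_dummy(DUMMY)
    assert not is_dummy(torch.zeros(2, 4, 16))

    prompts = DUMMY                    # e.g. RemoteSequential.forward(..., prompts=DUMMY)
    if not is_dummy(prompts):
        pass                           # the prompt-tuning branch is skipped entirely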