Преглед на файлове

Sequential and parallel forward / backward (#36)

Dmitry Baranchuk преди 3 години
родител
ревизия
6573076883
променени са 6 файла, в които са добавени 279 реда и са изтрити 72 реда
  1. 2 15
      cli/deploy_server.sh
  2. 4 6
      cli/run_local_servers.sh
  3. 4 6
      cli/run_remote_servers.sh
  4. 46 30
      src/client/remote_model.py
  5. 3 15
      src/client/remote_sequential.py
  6. 220 0
      src/client/sequential_autograd.py

+ 2 - 15
cli/deploy_server.sh

@@ -62,23 +62,10 @@ else
     conda activate bloom-demo
 
     conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
-    pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
-    pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
-    pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
-    pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
+    pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
+    pip install -i https://pypi.org/simple -r requirements.txt
 fi
 
-
-##############
-# Local test #
-##############
-
-if [ "$RUN_LOCAL_TESTS" = true ] ; then
-    echo "Run test on your local machine"
-    python -m cli.inference_one_block --config cli/config.json --device ${DEVICE} # see other args
-fi
-
-
 ##############
 # Run server #
 ##############

+ 4 - 6
cli/run_local_servers.sh

@@ -32,17 +32,15 @@ done
 ###########################
 
 source ~/miniconda3/etc/profile.d/conda.sh
-if conda env list | grep ".*bloom-demo.*"  &>/dev/null; then
+if conda env list | grep ".*bloom-demo.*"  >/dev/null 2>/dev/null; then
     conda activate bloom-demo
 else
     conda create -y --name bloom-demo python=3.8.12 pip
     conda activate bloom-demo
 
     conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
-    pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
-    pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
-    pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
-    pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
+    pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
+    pip install -i https://pypi.org/simple -r requirements.txt
 fi
 
 
@@ -88,7 +86,7 @@ do
     done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
     
     echo "=== Server #${SERVER_ID} ==="
-    echo "Server ID: ${id_path}"
+    echo "Server ID: ${cfg[id_path]}"
     echo "Device: ${cfg[device]}"
     echo "Bloom block ids: ${cfg[block_ids]}"
     echo "Host maddr: ${cfg[maddr]}"

+ 4 - 6
cli/run_remote_servers.sh

@@ -37,17 +37,15 @@ done
 ###########################
 
 source ~/miniconda3/etc/profile.d/conda.sh
-if conda env list | grep ".*bloom-demo.*"  &>/dev/null; then
+if conda env list | grep ".*bloom-demo.*"  >/dev/null 2>/dev/null; then
     conda activate bloom-demo
 else
     conda create -y --name bloom-demo python=3.8.12 pip
     conda activate bloom-demo
 
     conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
-    pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
-    pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
-    pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
-    pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
+    pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
+    pip install -i https://pypi.org/simple -r requirements.txt
 fi
 
 
@@ -57,7 +55,7 @@ fi
 
 hivemind-dht &> tmp.out &
 
-sleep 3
+sleep 5
 INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
 rm tmp.out
 echo "Initial peer: ${INITIAL_PEER}"

+ 46 - 30
src/client/remote_model.py

@@ -1,11 +1,11 @@
 # this code is in active development, interfaces may change
-import os
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
 
 import hivemind
 import torch
 import torch.nn as nn
 from hivemind import get_logger, use_hivemind_log_handler
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 
 from src.bloom.model import (
     BloomConfig,
@@ -66,13 +66,45 @@ class DistributedBloomModel(BloomModel):
         for p in self.parameters():
             p.requires_grad = value
 
-    def forward(self, *args, use_cache=None, **kwargs):
-        if use_cache:
-            raise ValueError(
-                "Distributed forward does not support use_cache; for efficient cache-aware generation, "
-                "please use model.transformer.inference_session() or model.generate(...)"
-            )
-        return super().forward(*args, use_cache=False, **kwargs)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
+
+        for k, v in kwargs.items():
+            if not (v is None or v is False):
+                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+        output_shape = input_shape + (hidden_states.size(-1),)
+        hidden_states = self.h(hidden_states)
+
+        # Add last hidden state
+        hidden_states = self.ln_f(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=None,
+            hidden_states=None,
+            attentions=None,
+        )
 
 
 class DistributedBloomPrefix(DistributedBloomModel):
@@ -94,16 +126,10 @@ class DistributedBloomPrefix(DistributedBloomModel):
 
     def forward(
         self,
-        input_ids: Optional[torch.LongTensor],
-        inputs_embeds: Optional[torch.Tensor],
-        attention_mask: Optional[torch.Tensor],
-        past_key_values=None,
-        position_ids=None,
-        head_mask=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ):
         assert (
             input_ids is None or inputs_embeds is None
@@ -122,17 +148,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
         prompts = self.get_prompt(batch_size)
         inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
 
-        transformer_outputs = super().forward(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
+        transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
 
         # Remove prefix
         last_hidden_state = transformer_outputs[0][:, self.prefix_length :]

+ 3 - 15
src/client/remote_sequential.py

@@ -12,6 +12,7 @@ import src
 from src.client.inference_session import RemoteSequentialInferenceSession
 from src.client.remote_block import RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager
+from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
 
@@ -52,21 +53,8 @@ class RemoteSequential(nn.Module):
             self.is_subsequence = self.sequence_manager.block_uids != block_uids
 
     def forward(self, inputs: torch.Tensor):
-        assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
-        for block in iter(self):
-            for retry_index in range(self.sequence_manager.max_retries):
-                try:
-                    (outputs,) = block(inputs)
-                    assert isinstance(outputs, torch.Tensor)
-                    assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
-                    inputs = outputs
-                    break
-                except Exception as e:
-                    if retry_index == self.sequence_manager.max_retries - 1:
-                        raise e
-                    else:
-                        logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
-        return inputs
+        outputs = _RemoteSequentialAutogradFunction.apply(inputs, self.sequence_manager)
+        return outputs
 
     def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
         assert isinstance(ix, (int, slice))

+ 220 - 0
src/client/sequential_autograd.py

@@ -0,0 +1,220 @@
+import asyncio
+import logging
+from typing import List, Optional, Sequence, Tuple
+
+import torch
+from hivemind import serialize_torch_tensor
+from hivemind.moe.client.expert import expert_backward, expert_forward
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import StubBase
+from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
+
+from src.client.sequence_manager import RemoteSequenceManager
+from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
+from src.server.handler import TransformerConnectionHandler
+
+MAX_TOKENS_IN_BATCH = 1024
+
+
+async def run_expert_forward(
+    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
+) -> Tuple[torch.Tensor, ...]:
+    """
+    Serializes input tensors and calls "expert_forward".
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
+    """
+
+    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
+    # detach to avoid pickling the computation graph
+    assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
+    kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
+
+    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
+    forward_inputs = (inputs, kwargs)
+
+    if not nested_compare(forward_inputs, rpc_info["forward_schema"]):
+        raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
+
+    forward_inputs = nested_flatten(forward_inputs)
+    inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
+
+    # TODO: figure out whether we should use run_in_executor here
+    serialized_tensors = (
+        serialize_torch_tensor(tensor, proto.compression)
+        for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
+    )
+    deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
+    flat_outputs = tuple(deserialized_outputs)
+
+    return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])
+
+
+async def run_expert_backward(
+    uid: ModuleUID,
+    stub: StubBase,
+    rpc_info: RPCInfo,
+    intemediate_inputs: List[torch.Tensor],
+    grad_outputs: List[torch.Tensor],
+) -> Sequence[torch.Tensor]:
+    """
+    Serializes grad outputs and calls "expert_backward".
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
+    """
+
+    grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
+    inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
+    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
+
+    serialized_tensors = (
+        serialize_torch_tensor(tensor, proto.compression)
+        for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+    )
+    deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
+    return deserialized_grad_inputs
+
+
+async def sequential_forward(
+    inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
+) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
+    """
+    Constructs a routing path from <start_index> to <end_index>.
+    Performs chained forward for each subsequence of blocks on the path.
+    If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
+    """
+
+    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3
+
+    end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
+    assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
+
+    sequences = sequence_manager.make_sequence(start_index, end_index)
+    intermediate_inputs = []
+    done_sequences = []
+
+    while len(sequences) > 0:
+        while True:
+            try:
+                span = sequences.pop(0)
+                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
+
+                assert isinstance(outputs, torch.Tensor)
+                assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
+
+                # Save intermediate inputs and subsequences if the forward is already done for them
+                intermediate_inputs.append(inputs)
+                done_sequences.append(span)
+
+                inputs = outputs
+                break
+            except Exception as e:
+                logging.debug(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
+                backup_sequences = sequence_manager.make_sequence(span.start)
+                assert backup_sequences[0].start == span.start
+                sequences = backup_sequences
+
+    return outputs, intermediate_inputs, done_sequences
+
+
+async def sequential_backward(
+    grad_outputs: Sequence[torch.Tensor],
+    intermediate_inputs: Sequence[torch.Tensor],
+    forward_sequences: Sequence[RemoteSpanInfo],
+    sequence_manager: RemoteSequenceManager,
+) -> Sequence[torch.Tensor]:
+    """
+    Performs chained backward for each forward subsequence.
+    If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
+    """
+
+    assert len(intermediate_inputs) == len(forward_sequences)
+    # TODO think about grads w.r.t. deep prompts
+
+    while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
+        while True:
+            try:
+                inputs = intermediate_inputs.pop(-1)
+                span = forward_sequences.pop(-1)
+
+                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+
+                grad_outputs = await run_expert_backward(
+                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
+                )
+                break
+            except Exception as e:
+                logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
+                _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
+                    inputs, sequence_manager, start_index=span.start, end_index=span.end
+                )
+
+                assert len(intermediate_inputs) == len(forward_sequences)
+                assert backup_forward_sequences[0].start == span.start
+                assert backup_forward_sequences[-1].end == span.end
+
+                forward_sequences.extend(backup_forward_sequences)
+                intermediate_inputs.extend(backup_intermediate_inputs)
+    return grad_outputs
+
+
+async def _gather_forward(input_batches, sequence_manager):
+    """Wrapper for asyncio.gather to perform parallel sequential forwards"""
+    return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
+
+
+async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
+    """Wrapper for asyncio.gather to perform parallel sequential backwards"""
+    return await asyncio.gather(
+        *[
+            sequential_backward((grad_output,), input_batch, spans, sequence_manager)
+            for grad_output, input_batch, spans in zip(
+                grad_output_batches, intermediate_input_batches, forward_sequences
+            )
+        ]
+    )
+
+
+class _RemoteSequentialAutogradFunction(torch.autograd.Function):
+    """
+    PyTorch autograd function that provides forward and backward calls for the entire sequence of remote transformer blocks.
+    This function splits input data into batches with <MAX_TOKENS_IN_BATCH> and performs efficient parallel processing.
+    """
+
+    @staticmethod
+    def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
+        batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
+        input_batches: Sequence[torch.Tensor] = inputs.split(batch_size)
+
+        sequence_manager.rpc_info  # lazy init
+        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
+        assert len(outputs) == len(input_batches)
+
+        output_batches = [output[0] for output in outputs]
+        intemediate_input_batches = [output[1] for output in outputs]
+        sequences_for_batches = [output[2] for output in outputs]
+
+        ctx.sequence_manager = sequence_manager
+        ctx.intemediate_input_batches = intemediate_input_batches
+        ctx.sequences_for_batches = sequences_for_batches
+        return torch.cat(output_batches, dim=0)
+
+    @staticmethod
+    def backward(ctx, grad_outputs: torch.Tensor):
+        intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intemediate_input_batches
+        forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
+        ctx.sequence_manager.rpc_info  # lazy init
+
+        batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
+        grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
+        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)
+
+        grad_input_batches = RemoteExpertWorker.run_coroutine(
+            _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
+        )
+        grad_inputs = [grad_input_batch[0] for grad_input_batch in grad_input_batches]
+        grad_inputs = torch.cat(grad_inputs, dim=0)
+        return (grad_inputs, None)