
Merge branch 'main' into optimize_seq

justheuristic 3 years ago
parent
commit
0cb9af4374

+ 13 - 4
.github/workflows/run-tests.yaml

@@ -28,12 +28,12 @@ jobs:
           pip install -r requirements.txt
       - name: Delete previous model, if exists
         run: |
-          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_REF_NAME') or 'main')")
+          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
           python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \
           name='test-bloomd-350m-$HF_TAG', organization='bloom-testing')" || true
       - name: Convert model and push to hub
         run: |
-          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_BASE_REF') or os.environ.get('GITHUB_REF_NAME'))")
+          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
           python -m cli.convert_model --model bigscience/bloom-350m  --output_path ./converted_model \
             --output_repo bloom-testing/test-bloomd-350m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN
 
@@ -62,9 +62,16 @@ jobs:
           python -m pip install --upgrade pip
           pip install -r requirements.txt
           pip install -r requirements-dev.txt
+      - name: Build bitsandbytes cpuonly
+        run: |
+          git clone https://github.com/TimDettmers/bitsandbytes.git
+          cd bitsandbytes
+          make cpuonly
+          pip install .
+          cd -
       - name: Test
         run: |
-          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_REF_NAME') or 'main')")
+          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
           export MODEL_NAME=bloom-testing/test-bloomd-350m-$HF_TAG
           export REF_NAME=bigscience/bloom-350m
 
@@ -72,6 +79,8 @@ jobs:
             --torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 &
           SERVER1_PID=$!
           
+          sleep 5  # wait for the first server to initialize DHT
+          
           export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
           # ^-- server 1 multiaddr is determined by --identity and --host_maddrs
           
@@ -79,7 +88,7 @@ jobs:
             --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
           SERVER2_PID=$!
 
-          sleep 30  # wait for server to download layers
+          sleep 60  # wait for server to download layers
           
           PYTHONPATH=. pytest tests
           

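For reference, a minimal sketch of the tag resolution the workflow now uses (the environment variables are the standard GitHub Actions ones): on pull requests `GITHUB_HEAD_REF` holds the source branch, while on direct pushes it is empty, so the expression falls back to `GITHUB_REF_NAME`.

```python
# Sketch of the HF_TAG resolution above: prefer the PR source branch, fall back to the push ref.
import os

def resolve_hf_tag():
    # GITHUB_HEAD_REF is set (non-empty) only for pull_request events;
    # GITHUB_REF_NAME covers pushes to a branch or tag.
    return os.environ.get("GITHUB_HEAD_REF") or os.environ.get("GITHUB_REF_NAME")

print(resolve_hf_tag())
```
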
+ 4 - 3
README.md

@@ -1,11 +1,11 @@
-# Distributed BLOOM
+# PETALS: Collaborative Inference of Large Models
 
-Run the largest open language model by collaborating over the internet.
+Run BLOOM-176B, the largest open language model, by collaborating over the Internet.
 
 __[EARLY PROTOTYPE]__ - this project is a work in progress. Stuff breaks and gets fixed every day. Docs are nonexistent.
 If you want us to wake you up when it's ready, click Watch -> Custom and tick "Releases".
 
-Roadmap: [__issue #12__](https://github.com/learning-at-home/bloom-demo/issues/12)
+Roadmap: [__Issue #12__](https://github.com/learning-at-home/bloom-demo/issues/12)
 
 ### Installation
 
@@ -13,6 +13,7 @@ Roadmap: [__issue #12__](https://github.com/learning-at-home/bloom-demo/issues/1
 conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
 pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
 pip install -r requirements.txt
+pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
 ```
 
 

+ 11 - 19
cli/deploy_server.sh

@@ -5,7 +5,8 @@
 #################
 
 instructions() {
-  echo "Usage: $0 [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
+  echo "Usage: $0 [-m] [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
+  echo " -m: model name"
   echo " -i: initial peer"
   echo " -d: device" >&2
   echo " -p: server identity path" >&2
@@ -19,8 +20,10 @@ if [ ! $# -ge 8 ]; then
     instructions
 fi
 
-while getopts ":i:d:p:b:a:t:" option; do
+while getopts ":m:i:d:p:b:a:t:" option; do
     case $option in
+        m)  MODEL_NAME=${OPTARG}
+            ;;
         i)  INITIAL_PEER=${OPTARG}
             ;;
         d)  DEVICE=${OPTARG}
@@ -42,6 +45,7 @@ done
 echo "=========="
 echo "= Config ="
 echo "=========="
+echo "Model name: ${MODEL_NAME}"
 echo "Initial peer: ${INITIAL_PEER}"
 echo "Device: ${DEVICE}"
 echo "Server name: ${SERVER_ID_PATH}"
@@ -62,26 +66,14 @@ else
     conda activate bloom-demo
 
     conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
-    pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
-    pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
-    pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
-    pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
-fi
-
-
-##############
-# Local test #
-##############
-
-if [ "$RUN_LOCAL_TESTS" = true ] ; then
-    echo "Run test on your local machine"
-    python -m cli.inference_one_block --config cli/config.json --device ${DEVICE} # see other args
+    pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
+    pip install -i https://pypi.org/simple -r requirements.txt
+    pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
 fi
 
-
 ##############
 # Run server #
 ##############
 
-python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
-  --block_indices ${BLOCK_IDS} --torch_dtype float32 --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} &> ${SERVER_ID_PATH}.log
+python -m cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
+  --block_indices ${BLOCK_IDS} --compression UNIFORM_8BIT --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} --load_in_8bit &> ${SERVER_ID_PATH}.log

+ 7 - 9
cli/run_local_servers.sh

@@ -32,17 +32,16 @@ done
 ###########################
 
 source ~/miniconda3/etc/profile.d/conda.sh
-if conda env list | grep ".*bloom-demo.*"  &>/dev/null; then
+if conda env list | grep ".*bloom-demo.*"  >/dev/null 2>/dev/null; then
     conda activate bloom-demo
 else
     conda create -y --name bloom-demo python=3.8.12 pip
     conda activate bloom-demo
 
     conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
-    pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
-    pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
-    pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
-    pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
+    pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
+    pip install -i https://pypi.org/simple -r requirements.txt
+    pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
 fi
 
 
@@ -51,7 +50,7 @@ fi
 #######################
 
 hivemind-dht &> tmp.out &
-sleep 3
+sleep 5
 INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
 echo "Initial peer: ${INITIAL_PEER}"
 
@@ -88,7 +87,7 @@ do
     done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
     
     echo "=== Server #${SERVER_ID} ==="
-    echo "Server ID: ${id_path}"
+    echo "Server ID: ${cfg[id_path]}"
     echo "Device: ${cfg[device]}"
     echo "Bloom block ids: ${cfg[block_ids]}"
     echo "Host maddr: ${cfg[maddr]}"
@@ -98,10 +97,9 @@ do
     # Run server #
     ##############
 
-    tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
+    tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -m "bigscience/test-bloomd" -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
 done
 
-
 #####################
 # Kill initial peer #
 #####################

+ 4 - 6
cli/run_remote_servers.sh

@@ -37,17 +37,15 @@ done
 ###########################
 
 source ~/miniconda3/etc/profile.d/conda.sh
-if conda env list | grep ".*bloom-demo.*"  &>/dev/null; then
+if conda env list | grep ".*bloom-demo.*"  >/dev/null 2>/dev/null; then
     conda activate bloom-demo
 else
     conda create -y --name bloom-demo python=3.8.12 pip
     conda activate bloom-demo
 
     conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
-    pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
-    pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
-    pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
-    pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
+    pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
+    pip install -i https://pypi.org/simple -r requirements.txt
 fi
 
 
@@ -57,7 +55,7 @@ fi
 
 hivemind-dht &> tmp.out &
 
-sleep 3
+sleep 5
 INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
 rm tmp.out
 echo "Initial peer: ${INITIAL_PEER}"

+ 7 - 1
cli/run_server.py

@@ -27,12 +27,14 @@ def main():
 
     parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
 
-    parser.add_argument('--num_handlers', type=int, default=16, required=False,
+    parser.add_argument('--num_handlers', type=int, default=8, required=False,
                         help='server will use this many processes to handle incoming requests')
     parser.add_argument('--min_batch_size', type=int, default=1,
                         help='Minimum required batch size for all expert operations')
     parser.add_argument('--max_batch_size', type=int, default=16384,
                         help='The total number of examples in the same batch will not exceed this value')
+    parser.add_argument('--cache_dir', type=str, default=None, 
+                        help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
     parser.add_argument('--cache_size_bytes', type=int, default=None,
                         help='The size of memory cache for storing past attention keys/values between inference steps')
     parser.add_argument('--device', type=str, default=None, required=False,
@@ -40,6 +42,9 @@ def main():
     parser.add_argument("--torch_dtype", type=str, default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
+    parser.add_argument('--revision', type=str, default='main',
+                        help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models "
+                             "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
 
     parser.add_argument('--throughput',
                         type=lambda value: value if value in ['auto', 'eval'] else float(value),
@@ -64,6 +69,7 @@ def main():
                         help='Path of a file with custom nn.modules, wrapped into special decorator')
     parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
     parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
+    parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into mixed-8bit quantized model.')
 
     # fmt:on
     args = vars(parser.parse_args())

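A quick, standalone illustration of how the new flags parse (a sketch mirroring the argparse definitions above, not the full parser; the example values are illustrative):

```python
# Standalone sketch of the options added to cli/run_server.py above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--revision", type=str, default="main")
parser.add_argument("--load_in_8bit", action="store_true")

args = parser.parse_args(["--load_in_8bit", "--cache_dir", "./converted_blocks"])
assert args.load_in_8bit and args.revision == "main" and args.cache_dir == "./converted_blocks"
```
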
+ 2 - 3
requirements.txt

@@ -1,6 +1,5 @@
 torch==1.12.0
 accelerate==0.10.0
 huggingface-hub==0.7.0
-bitsandbytes-cuda113==0.26.0
-https://github.com/learning-at-home/hivemind/archive/d42c70331da43667da6d9020666df54806d8b561.zip
-https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
+https://github.com/learning-at-home/hivemind/archive/20b3b3d5f225ed525515a5383a008a8f9fad8173.zip
+https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip

+ 9 - 3
src/bloom/from_pretrained.py

@@ -34,12 +34,15 @@ def load_pretrained_block(
     config: Optional[BloomConfig] = None,
     torch_dtype: Union[torch.dtype, str] = "auto",
     use_auth_token: Optional[str] = None,
+    cache_dir: Optional[str] = None,
 ) -> BloomBlock:
     """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
     if config is None:
         config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
     block = BloomBlock(config, layer_number=block_index)
-    state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)
+    state_dict = _load_state_dict(
+        converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
+    )
     block.load_state_dict(state_dict)
 
     if torch_dtype == "auto":
@@ -57,7 +60,10 @@ def load_pretrained_block(
 
 
 def _load_state_dict(
-    pretrained_model_name_or_path: str, block_index: Optional[int] = None, use_auth_token: Optional[str] = None
+    pretrained_model_name_or_path: str,
+    block_index: Optional[int] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: Optional[str] = None,
 ) -> OrderedDict[str, torch.Tensor]:
     revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
     archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
@@ -65,7 +71,7 @@ def _load_state_dict(
     # Load from URL or cache if already cached
     resolved_archive_file = cached_path(
         archive_file,
-        cache_dir=None,
+        cache_dir=cache_dir,
         force_download=FORCE_DOWNLOAD,
         proxies=None,
         resume_download=RESUME_DOWNLOAD,

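A hedged usage sketch of the new `cache_dir` plumbing (the repo name and path are illustrative and follow the converted-model naming used elsewhere in this diff):

```python
# Sketch: load one converted block while overriding the download cache location.
import torch
from src.bloom.from_pretrained import load_pretrained_block

block = load_pretrained_block(
    "bigscience/test-bloomd-6b3",    # converted model repo (illustrative)
    block_index=3,                   # resolved to the per-block revision BLOCK_BRANCH_PREFIX + "3"
    torch_dtype=torch.float16,
    cache_dir="./converted_blocks",  # None keeps the default Hugging Face cache
)
```
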
+ 3 - 5
src/bloom/model.py

@@ -156,9 +156,7 @@ class BloomModel(BloomPreTrainedModel):
         self.n_head = config.n_head
 
         # Embedding + LN Embedding
-
-        # TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!)
-        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)  # dtype=config.torch_dtype
+        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
         self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
         # Transformer blocks
@@ -229,7 +227,8 @@ class BloomModel(BloomPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+        # Note: it supports only float32 or bfloat16 inputs
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
 
         output_shape = input_shape + (hidden_states.size(-1),)
 
@@ -584,7 +583,6 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
                 )
 
         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
-
         loss = None
         if labels is not None:
             if self.config.problem_type is None:

+ 1 - 1
src/bloom/ops.py

@@ -101,7 +101,7 @@ def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor)
         attention_mask: ([`torch.tensor`], *required*):
             attention mask to pre-process
     """
-    assert attention_mask.shape.ndim == 2, "mask should be [batch_size, seq_length]"
+    assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"
     unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
     # ^-- [batch, max_len], values correspond to element indices after removing padding
     # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0

+ 1 - 1
src/client/inference_session.py

@@ -70,7 +70,7 @@ class RemoteTransformerBlockInferenceSession:
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
-                        serialize_torch_tensor(tensor, proto.compression)
+                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                         for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
                     ],
                 )

+ 6 - 0
src/client/remote_generation.py

@@ -17,6 +17,7 @@ class RemoteGenerationMixin:
     This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences.
     """
 
+    @torch.no_grad()
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
@@ -27,6 +28,7 @@ class RemoteGenerationMixin:
         bos_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
+        max_length: Optional[int] = None,
         max_new_tokens: Optional[int] = None,
         decoding_algorithm: Optional[DecodingAlgorithm] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
@@ -63,6 +65,10 @@ class RemoteGenerationMixin:
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
 
+        if max_length is not None and max_new_tokens is None:
+            max_new_tokens = max_length - inputs.size(1)
+            assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
+
         if inputs is None:
             assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
             inputs = torch.tensor([[bos_token_id]])

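The `max_length` support added above reduces to a small conversion; a minimal standalone sketch of that arithmetic (dummy input, no model required):

```python
# Sketch of the max_length -> max_new_tokens conversion added above.
import torch

inputs = torch.ones(1, 6, dtype=torch.long)  # a prefix of 6 tokens
max_length, max_new_tokens = 10, None

if max_length is not None and max_new_tokens is None:
    max_new_tokens = max_length - inputs.size(1)  # tokens left to generate
    assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"

assert max_new_tokens == 4
```
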
+ 55 - 67
src/client/remote_model.py

@@ -1,11 +1,11 @@
 # this code is in active development, interfaces may change
-import os
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
 
 import hivemind
 import torch
 import torch.nn as nn
 from hivemind import get_logger, use_hivemind_log_handler
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 
 from src.bloom.model import (
     BloomConfig,
@@ -17,8 +17,6 @@ from src.bloom.model import (
 )
 from src.client.remote_generation import RemoteGenerationMixin
 from src.client.remote_sequential import RemoteSequential
-from src.utils.generation_algorithms import DecodingAlgorithm
-from src.utils.generation_constraints import ABCBloomConstraint
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -34,7 +32,7 @@ class DistributedBloomConfig(BloomConfig):
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
-    num_prefix_tokens: int = 0  # a number of tokens for prompt tuning.
+    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
 
 
 class DistributedBloomModel(BloomModel):
@@ -66,13 +64,46 @@ class DistributedBloomModel(BloomModel):
         for p in self.parameters():
             p.requires_grad = value
 
-    def forward(self, *args, use_cache=None, **kwargs):
-        if use_cache:
-            raise ValueError(
-                "Distributed forward does not support use_cache; for efficient cache-aware generation, "
-                "please use model.transformer.inference_session() or model.generate(...)"
-            )
-        return super().forward(*args, use_cache=False, **kwargs)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
+
+        for k, v in kwargs.items():
+            if not (v is None or v is False):
+                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        # Note: it supports only float32 or bfloat16 inputs
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        output_shape = input_shape + (hidden_states.size(-1),)
+        hidden_states = self.h(hidden_states)
+
+        # Add last hidden state
+        hidden_states = self.ln_f(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=None,
+            hidden_states=None,
+            attentions=None,
+        )
 
 
 class DistributedBloomPrefix(DistributedBloomModel):
@@ -80,11 +111,11 @@ class DistributedBloomPrefix(DistributedBloomModel):
 
     def __init__(self, config):
         super().__init__(config)
-        assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0"
-        self.prefix_length = config.num_prefix_tokens
+        assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
+        self.pre_seq_len = config.pre_seq_len
 
-        self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size)
-        self.prefix_tokens = torch.arange(self.prefix_length).long()
+        self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
+        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
 
     def get_prompt(self, batch_size):
         prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
@@ -94,16 +125,10 @@ class DistributedBloomPrefix(DistributedBloomModel):
 
     def forward(
         self,
-        input_ids: Optional[torch.LongTensor],
-        inputs_embeds: Optional[torch.Tensor],
-        attention_mask: Optional[torch.Tensor],
-        past_key_values=None,
-        position_ids=None,
-        head_mask=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ):
         assert (
             input_ids is None or inputs_embeds is None
@@ -122,17 +147,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
         prompts = self.get_prompt(batch_size)
         inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
 
-        transformer_outputs = super().forward(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
+        transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
 
         # Remove prefix
         last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
@@ -140,14 +155,14 @@ class DistributedBloomPrefix(DistributedBloomModel):
         return transformer_outputs
 
 
-class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
+class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
-        if config.num_prefix_tokens > 0:
+        if config.pre_seq_len > 0:
             self.transformer = DistributedBloomPrefix(config)
         else:
             self.transformer = DistributedBloomModel(config)
@@ -174,40 +189,13 @@ class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
             self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
             self.lm_head.bias[...] = new_lm_head.bias
 
-    def generate(
-        self,
-        inputs: Optional[torch.Tensor] = None,
-        do_sample: Optional[bool] = None,
-        temperature: float = 1.0,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        eos_token_id: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        decoding_algorithm: Optional[DecodingAlgorithm] = None,
-        provided_constraints: List[ABCBloomConstraint] = [],
-        **model_kwargs,
-    ) -> torch.Tensor:
-        return RemoteGenerationMixin.generate(
-            self,
-            inputs=inputs,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            eos_token_id=eos_token_id,
-            max_new_tokens=max_new_tokens,
-            decoding_algorithm=decoding_algorithm,
-            provided_constraints=provided_constraints,
-            **model_kwargs,
-        )
-
 
 class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
         super().__init__(config)
-        if config.num_prefix_tokens > 0:
+        if config.pre_seq_len > 0:
             self.transformer = DistributedBloomPrefix(config)
         else:
             self.transformer = DistributedBloomModel(config)

+ 3 - 15
src/client/remote_sequential.py

@@ -12,6 +12,7 @@ import src
 from src.client.inference_session import RemoteSequentialInferenceSession
 from src.client.remote_block import RemoteTransformerBlock
 from src.client.routing.sequence_manager import RemoteSequenceManager
+from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
 
@@ -52,21 +53,8 @@ class RemoteSequential(nn.Module):
             self.is_subsequence = self.sequence_manager.block_uids != block_uids
 
     def forward(self, inputs: torch.Tensor):
-        assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
-        for block in iter(self):
-            for retry_index in range(self.sequence_manager.max_retries):
-                try:
-                    (outputs,) = block(inputs)
-                    assert isinstance(outputs, torch.Tensor)
-                    assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
-                    inputs = outputs
-                    break
-                except Exception as e:
-                    if retry_index == self.sequence_manager.max_retries - 1:
-                        raise e
-                    else:
-                        logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
-        return inputs
+        outputs = _RemoteSequentialAutogradFunction.apply(inputs, self.sequence_manager)
+        return outputs
 
     def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
         assert isinstance(ix, (int, slice))

+ 228 - 0
src/client/sequential_autograd.py

@@ -0,0 +1,228 @@
+import asyncio
+import logging
+from typing import List, Optional, Sequence, Tuple
+
+import torch
+from hivemind import serialize_torch_tensor
+from hivemind.moe.client.expert import expert_backward, expert_forward
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import StubBase
+from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
+
+from src.client.sequence_manager import RemoteSequenceManager
+from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
+from src.server.handler import TransformerConnectionHandler
+
+MAX_TOKENS_IN_BATCH = 1024
+
+
+async def run_expert_forward(
+    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
+) -> Tuple[torch.Tensor, ...]:
+    """
+    Serializes input tensors and calls "expert_forward".
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
+    """
+
+    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
+    # detach to avoid pickling the computation graph
+    assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
+    kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
+
+    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
+    forward_inputs = (inputs, kwargs)
+
+    if not nested_compare(forward_inputs, rpc_info["forward_schema"]):
+        raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
+
+    forward_inputs = nested_flatten(forward_inputs)
+    inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
+
+    # Asynchronous serialization
+    loop = asyncio.get_running_loop()
+    serialized_tensors = await asyncio.gather(
+        *(
+            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+            for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
+        )
+    )
+
+    deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
+    flat_outputs = tuple(deserialized_outputs)
+    return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])
+
+
+async def run_expert_backward(
+    uid: ModuleUID,
+    stub: StubBase,
+    rpc_info: RPCInfo,
+    intemediate_inputs: List[torch.Tensor],
+    grad_outputs: List[torch.Tensor],
+) -> Sequence[torch.Tensor]:
+    """
+    Serializes grad outputs and calls "expert_backward".
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
+    but without the RemoteExpertWorker.run_coroutine() call, which would lead to a deadlock here.
+    """
+
+    grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
+    inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
+    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
+
+    # Asynchronous serialization
+    loop = asyncio.get_running_loop()
+    serialized_tensors = await asyncio.gather(
+        *(
+            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        )
+    )
+
+    deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
+    return deserialized_grad_inputs
+
+
+async def sequential_forward(
+    inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
+) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
+    """
+    Constructs a routing path from <start_index> to <end_index>.
+    Performs chained forward for each subsequence of blocks on the path.
+    If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
+    """
+
+    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3
+
+    end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
+    assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
+
+    sequences = sequence_manager.make_sequence(start_index, end_index)
+    intermediate_inputs = []
+    done_sequences = []
+
+    while len(sequences) > 0:
+        while True:
+            try:
+                span = sequences.pop(0)
+                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
+
+                assert isinstance(outputs, torch.Tensor)
+                assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
+
+                # Save intermediate inputs and subsequences if the forward is already done for them
+                intermediate_inputs.append(inputs)
+                done_sequences.append(span)
+
+                inputs = outputs
+                break
+            except Exception as e:
+                logging.debug(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
+                backup_sequences = sequence_manager.make_sequence(span.start)
+                assert backup_sequences[0].start == span.start
+                sequences = backup_sequences
+
+    return outputs, intermediate_inputs, done_sequences
+
+
+async def sequential_backward(
+    grad_outputs: Sequence[torch.Tensor],
+    intermediate_inputs: Sequence[torch.Tensor],
+    forward_sequences: Sequence[RemoteSpanInfo],
+    sequence_manager: RemoteSequenceManager,
+) -> Sequence[torch.Tensor]:
+    """
+    Performs chained backward for each forward subsequence.
+    If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
+    """
+
+    assert len(intermediate_inputs) == len(forward_sequences)
+    # TODO think about grads w.r.t. deep prompts
+
+    while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
+        while True:
+            try:
+                inputs = intermediate_inputs.pop(-1)
+                span = forward_sequences.pop(-1)
+
+                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+
+                grad_outputs = await run_expert_backward(
+                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
+                )
+                break
+            except Exception as e:
+                logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
+                _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
+                    inputs, sequence_manager, start_index=span.start, end_index=span.end
+                )
+
+                assert len(intermediate_inputs) == len(forward_sequences)
+                assert backup_forward_sequences[0].start == span.start
+                assert backup_forward_sequences[-1].end == span.end
+
+                forward_sequences.extend(backup_forward_sequences)
+                intermediate_inputs.extend(backup_intermediate_inputs)
+    return grad_outputs
+
+
+async def _gather_forward(input_batches, sequence_manager):
+    """Wrapper for asyncio.gather to perform parallel sequential forwards"""
+    return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
+
+
+async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
+    """Wrapper for asyncio.gather to perform parallel sequential backwards"""
+    return await asyncio.gather(
+        *[
+            sequential_backward((grad_output,), input_batch, spans, sequence_manager)
+            for grad_output, input_batch, spans in zip(
+                grad_output_batches, intermediate_input_batches, forward_sequences
+            )
+        ]
+    )
+
+
+class _RemoteSequentialAutogradFunction(torch.autograd.Function):
+    """
+    PyTorch autograd function that provides forward and backward calls for the entire sequence of remote transformer blocks.
+    This function splits input data into batches of at most <MAX_TOKENS_IN_BATCH> tokens and processes them in parallel.
+    """
+
+    @staticmethod
+    def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
+        batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
+        input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
+
+        sequence_manager.rpc_info  # lazy init
+        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
+        assert len(outputs) == len(input_batches)
+
+        output_batches = [output[0] for output in outputs]
+        intemediate_input_batches = [output[1] for output in outputs]
+        sequences_for_batches = [output[2] for output in outputs]
+
+        ctx.sequence_manager = sequence_manager
+        ctx.intemediate_input_batches = intemediate_input_batches
+        ctx.sequences_for_batches = sequences_for_batches
+        return torch.cat(output_batches, dim=0)
+
+    @staticmethod
+    def backward(ctx, grad_outputs: torch.Tensor):
+        intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intemediate_input_batches
+        forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
+        ctx.sequence_manager.rpc_info  # lazy init
+
+        batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
+        grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
+        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)
+
+        grad_input_batches = RemoteExpertWorker.run_coroutine(
+            _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
+        )
+        grad_inputs = [grad_input_batch[0] for grad_input_batch in grad_input_batches]
+        grad_inputs = torch.cat(grad_inputs, dim=0)
+        return (grad_inputs, None)

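To make the batching rule in `_RemoteSequentialAutogradFunction` concrete, here is a standalone sketch of the arithmetic (shapes are illustrative):

```python
# Sketch of the token-bounded batching used by _RemoteSequentialAutogradFunction.forward.
import torch

MAX_TOKENS_IN_BATCH = 1024
inputs = torch.randn(8, 512, 64)  # [batch, seq_len, hidden]; hidden size is illustrative
batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)  # 1024 // 512 = 2 sequences per sub-batch
input_batches = inputs.detach().split(batch_size)

assert len(input_batches) == 4 and input_batches[0].shape == (2, 512, 64)
```
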
+ 36 - 3
src/server/backend.py

@@ -1,20 +1,50 @@
 """Code for serving bloom blocks via hivemind-server"""
-from typing import Sequence, Tuple
+from queue import Empty
+from typing import Optional, Sequence, Tuple
 
 import torch
+from hivemind import use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
+from hivemind.utils import InvalidStateError, get_logger
 
 from src.bloom.from_pretrained import BloomBlock
 from src.server.cache import MemoryCache
 
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
 MAX_LENGTH = 2048
 
 
+class InferenceTaskPool(TaskPool):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool cannot be greater than 1"
+
+    def iterate_minibatches(self, *args, **kwargs):
+        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
+
+        while True:
+            try:
+                logger.debug(f"{self.name} getting next task")
+                task = self.tasks.get(timeout=self.timeout)
+            except Empty:
+                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
+                continue
+
+            try:
+                if task.future.set_running_or_notify_cancel():
+                    yield [task]
+            except InvalidStateError as e:
+                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
+
+
 class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
-    def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
+    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: Optional[torch.dtype] = None, **kwargs):
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, BloomBlock)
         self.memory_cache = memory_cache
@@ -23,7 +53,10 @@ class TransformerBackend(ModuleBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
-        self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
+        self.inference_pool = InferenceTaskPool(
+            self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
+        )
+        self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():

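A small illustration of the new dtype bookkeeping (a sketch with a stand-in module, not the server code itself): the backend dtype defaults to the dtype of the block's own weights unless `backend_dtype` is passed explicitly.

```python
# Sketch of the backend dtype default introduced above: fall back to the block's weight dtype.
import torch
import torch.nn as nn

module = nn.ModuleDict({"input_layernorm": nn.LayerNorm(8)}).to(torch.bfloat16)

def pick_backend_dtype(module, backend_dtype=None):
    # mirrors: backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
    return backend_dtype if backend_dtype else module["input_layernorm"].weight.dtype

assert pick_backend_dtype(module) == torch.bfloat16
assert pick_backend_dtype(module, torch.float32) == torch.float32
```
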
+ 24 - 6
src/server/handler.py

@@ -48,6 +48,9 @@ class TransformerConnectionHandler(ConnectionHandler):
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
 
+                    # Cast inputs to backend dtype
+                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+
                     # run request tensors through all requested modules, update caches
                     for backend, cache_handle in zip(requested_backends, cache_handles):
                         cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
@@ -62,7 +65,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                     # serialize and send last layer outputs
                     yield runtime_pb2.ExpertResponse(
                         tensors=[
-                            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                             for result, proto in zip(
                                 hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
                             )
@@ -81,6 +84,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs to backend dtype
+        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+
         # Run a chain of requested backends
         for backend in requested_backends:
             assert isinstance(hidden_states, (list, tuple))
@@ -93,7 +99,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         return runtime_pb2.ExpertResponse(
             tensors=[
-                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                 for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
             ]
         )
@@ -106,6 +112,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header_str(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs to backend dtype
+        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+
         # Run a chain of requested backends
         for backend in requested_backends:
             assert isinstance(hidden_states, (list, tuple))
@@ -117,7 +126,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         # Serialize the overall output
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         serialized_output = [
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
             for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
         ]
 
@@ -134,6 +143,10 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs & grad outputs to backend dtype
+        inputs = inputs.to(requested_backends[0].dtype)
+        grads = grads.to(requested_backends[-1].dtype)
+
         # Run a forward chain to collect intermediate inputs
         # Note that we do not forward for the last module since we do not need its output
         inter_inputs = [inputs]
@@ -154,7 +167,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
             tensors=[
-                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                 for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
             ]
         )
@@ -162,11 +175,16 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+
         uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
         inputs, grads = inputs_and_grads
         requested_uids = self._check_header_str(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs & grad outputs to backend dtype
+        inputs = inputs.to(requested_backends[0].dtype)
+        grads = grads.to(requested_backends[-1].dtype)
+
         # Run a forward chain to collect intermediate inputs
         # Note that we do not forward for the last module since we do not need its outputs
         inter_inputs = [inputs]
@@ -186,7 +204,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
             for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
         ]
         # Split the serialized_grad_inputs for streaming and respond
@@ -227,7 +245,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 head_dim = backend.module.self_attention.head_dim
 
                 cache_descriptor = TensorDescriptor(
-                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32
+                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
                 )
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
 

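The recurring pattern in these handlers, shown as a pure-torch sketch with stand-in dtypes: incoming tensors are cast to the first backend's dtype before compute, and results are cast back to the dtype declared in the outputs schema before serialization, so the wire format stays float32 while servers may run in fp16/bf16.

```python
# Pure-torch sketch of the cast-in / cast-out pattern used by the handlers above.
import torch

backend_dtype = torch.float16  # dtype the server computes in
schema_dtype = torch.float32   # dtype declared in the outputs schema (wire format)

hidden_states = torch.randn(1, 4, 8)             # arrives as float32
hidden_states = hidden_states.to(backend_dtype)  # cast to backend dtype before running the blocks
result = hidden_states * 2                       # stand-in for the actual forward pass
result = result.to(schema_dtype)                 # cast back before serialization
assert result.dtype == torch.float32
```
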
+ 28 - 6
src/server/server.py

@@ -22,6 +22,7 @@ from src.server.block_selection import choose_best_blocks
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
 from src.server.throughput import get_host_throughput
+from src.utils.convert_8bit import replace_8bit_linear
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -35,7 +36,6 @@ class Server(threading.Thread):
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
         *,
-        device: torch.device,
         num_connection_handlers: int = 8,
         throughput: float,
         update_period: float = 30,
@@ -49,7 +49,7 @@ class Server(threading.Thread):
         self.conn_handlers = [
             TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
         ]
-        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
+        self.runtime = Runtime(self.module_backends, **kwargs)
         self.dht_handler_thread = ModuleAnnouncerThread(
             self.module_backends,
             dht,
@@ -101,10 +101,12 @@ class Server(threading.Thread):
         throughput: Union[float, str],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
-        num_handlers: Optional[int] = None,
+        num_handlers: int = 8,
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
         torch_dtype: str = "auto",
+        revision: str = "main",
+        cache_dir: Optional[str] = None,
         cache_size_bytes: Optional[int] = None,
         device: Optional[Union[str, torch.device]] = None,
         initial_peers: Sequence[str] = (),
@@ -115,6 +117,7 @@ class Server(threading.Thread):
         expiration: Optional[float] = None,
         max_block_selection_delay: float = 1,
         use_auth_token: Optional[str] = None,
+        load_in_8bit: bool = False,
         *,
         start: bool,
         **kwargs,
@@ -148,7 +151,9 @@ class Server(threading.Thread):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
 
-        block_config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
+        block_config = BloomConfig.from_pretrained(
+            converted_model_name_or_path, use_auth_token=use_auth_token, revision=revision
+        )
 
         if block_indices is not None:
             try:
@@ -186,7 +191,15 @@ class Server(threading.Thread):
                 block_config,
                 torch_dtype=torch_dtype,
                 use_auth_token=use_auth_token,
+                cache_dir=cache_dir,
             )
+
+            if load_in_8bit:
+                dtype = block.input_layernorm.weight.dtype
+                assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
+                block = replace_8bit_linear(block)
+
+            block = block.to(device)
             for param in block.parameters():
                 param.requires_grad = False
 
@@ -194,9 +207,18 @@ class Server(threading.Thread):
                 module_uid,
                 block,
                 memory_cache=memory_cache,
-                args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
+                backend_dtype=None if torch_dtype == "auto" else torch_dtype,
+                args_schema=(
+                    BatchTensorDescriptor(
+                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                    ),
+                ),
                 kwargs_schema={},
-                outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
+                outputs_schema=(
+                    BatchTensorDescriptor(
+                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                    ),
+                ),
                 min_batch_size=min_batch_size,
                 max_batch_size=max_batch_size,
             )

+ 34 - 0
src/utils/convert_8bit.py

@@ -0,0 +1,34 @@
+import bitsandbytes as bnb
+import torch
+
+
+def replace_8bit_linear(model, threshold=6.0):
+    """
+    A helper function that replaces all `torch.nn.Linear` modules with `bnb.nn.Linear8bitLt` modules from the `bitsandbytes`
+    library. This enables running your models in mixed int8 precision as described in the paper `LLM.int8():
+    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` is compiled for the correct CUDA
+    version of your hardware before running this function: `pip install -i https://test.pypi.org/simple/
+    bitsandbytes-cudaXXX`, where `XXX` is your CUDA version (e.g., 11.6 -> 116).
+    The function runs recursively and replaces all `torch.nn.Linear` modules except for the `lm_head`, which should
+    be kept as a `torch.nn.Linear` module. The replacement is done under the `init_empty_weights` context manager, so no
+    CPU/GPU memory is required to run this function.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        threshold (`float`, *optional*):
+            `int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to
+            `6.0` as described by the paper.
+    """
+    for n, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_8bit_linear(module, threshold)
+
+        if isinstance(module, torch.nn.Linear) and n != "lm_head":
+            model._modules[n] = bnb.nn.Linear8bitLt(
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+                has_fp16_weights=False,
+                threshold=threshold,
+            ).to(model._modules[n].weight.device)
+    return model

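A hedged usage sketch of the 8-bit path as wired up in `server.py` above (the repo name and device are illustrative; a CUDA build of `bitsandbytes` is required for the actual quantization):

```python
# Sketch: quantize a converted block before serving, mirroring the load_in_8bit branch in server.py.
import torch
from src.bloom.from_pretrained import load_pretrained_block
from src.utils.convert_8bit import replace_8bit_linear

block = load_pretrained_block("bigscience/test-bloomd-6b3", block_index=0, torch_dtype=torch.float16)
assert block.input_layernorm.weight.dtype == torch.float16  # load_in_8bit currently requires fp16 weights
block = replace_8bit_linear(block)  # nn.Linear -> bnb.nn.Linear8bitLt
block = block.to("cuda:0")          # Linear8bitLt weights are quantized to int8 when moved to the GPU
for param in block.parameters():
    param.requires_grad = False
```
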
+ 40 - 9
tests/test_full_model.py

@@ -4,6 +4,7 @@ import transformers
 from hivemind import get_logger, use_hivemind_log_handler
 from test_utils import *
 
+from src.bloom.model import BloomForCausalLM
 from src.client.remote_model import DistributedBloomForCausalLM
 
 use_hivemind_log_handler("in_root_logger")
@@ -13,13 +14,15 @@ logger = get_logger(__file__)
 @pytest.mark.forked
 def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    model = DistributedBloomForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.n_layer
 
     test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
 
-    with torch.no_grad():
+    with torch.inference_mode():
         parallel_outputs = model.forward(test_inputs).logits
         assert torch.all(torch.isfinite(parallel_outputs))
         logger.info("Forward outputs are finite")
@@ -32,24 +35,52 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
                 recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
         recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
         recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
-
-        dictionary = model.transformer.word_embeddings.weight.t()
-        recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
-        recurrent_outputs = (recurrent_outputs @ dictionary).float()
+        recurrent_outputs = model.lm_head(recurrent_outputs)
         assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
         logger.info("Inference is consistent with forward")
 
-        del model, recurrent_outputs
+        del model, embs, recurrent_outputs
 
         if REF_NAME:
-            ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
+            ref_model = transformers.BloomForCausalLM.from_pretrained(
+                REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+            )
             dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
             # note: this creates a dummy mask to make the test compatible with older transformer versions
             # prior to https://github.com/huggingface/transformers/pull/17837
-            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
+            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
             assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
             logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
             del ref_model, ref_outputs, dummy_mask
         else:
             logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
             assert False
+
+
+@pytest.mark.forked
+def test_greedy_generation(max_new_tokens=4):
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+    remote_outputs = model.generate(
+        inputs,
+        max_new_tokens=max_new_tokens,
+    )
+    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
+    assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
+
+    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
+        "input_ids"
+    ]
+    remote_outputs_batch = model.generate(
+        inputs_batch,
+        max_new_tokens=max_new_tokens,
+    )
+    hf_outputs_batch = BloomForCausalLM.greedy_search(
+        model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
+    )
+    assert torch.allclose(
+        remote_outputs_batch, hf_outputs_batch
+    ), "Greedy search results are not identical to HF in multibatch mode"