Procházet zdrojové kódy

integrate mixed-8bit model (#39)

* integrate mixed-8bit model
* Fix bug with model duplication in RAM
* set throughput=1.0 to fix zero throughput problem
* add revision support
* update hivemind and bitsandbytes
* update deploy scripts
* update installation instructions
Dmitry Baranchuk před 3 roky
rodič
revize
11a424837f

+ 7 - 0
.github/workflows/run-tests.yaml

@@ -62,6 +62,13 @@ jobs:
           python -m pip install --upgrade pip
           pip install -r requirements.txt
           pip install -r requirements-dev.txt
+      - name: Build bitsandbytes cpuonly
+        run: |
+          git clone https://github.com/TimDettmers/bitsandbytes.git
+          cd bitsandbytes
+          make cpuonly
+          pip install .
+          cd -
       - name: Test
         run: |
           export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")

+ 1 - 0
README.md

@@ -13,6 +13,7 @@ Roadmap: [__Issue #12__](https://github.com/learning-at-home/bloom-demo/issues/1
 conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
 pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
 pip install -r requirements.txt
+pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
 ```
 
 

+ 9 - 4
cli/deploy_server.sh

@@ -5,7 +5,8 @@
 #################
 
 instructions() {
-  echo "Usage: $0 [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
+  echo "Usage: $0 [-m] [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
+  echo " -m: model name"
   echo " -i: initial peer"
   echo " -d: device" >&2
   echo " -p: server identity path" >&2
@@ -19,8 +20,10 @@ if [ ! $# -ge 8 ]; then
     instructions
 fi
 
-while getopts ":i:d:p:b:a:t:" option; do
+while getopts ":m:i:d:p:b:a:t:" option; do
     case $option in
+        m)  MODEL_NAME=${OPTARG}
+            ;;
         i)  INITIAL_PEER=${OPTARG}
             ;;
         d)  DEVICE=${OPTARG}
@@ -42,6 +45,7 @@ done
 echo "=========="
 echo "= Config ="
 echo "=========="
+echo "Model name: ${MODEL_NAME}"
 echo "Initial peer: ${INITIAL_PEER}"
 echo "Device: ${DEVICE}"
 echo "Server name: ${SERVER_ID_PATH}"
@@ -64,11 +68,12 @@ else
     conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
     pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
     pip install -i https://pypi.org/simple -r requirements.txt
+    pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
 fi
 
 ##############
 # Run server #
 ##############
 
-python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
-  --block_indices ${BLOCK_IDS} --torch_dtype float32 --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} &> ${SERVER_ID_PATH}.log
+python -m cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
+  --block_indices ${BLOCK_IDS} --compression UNIFORM_8BIT --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} --load_in_8bit &> ${SERVER_ID_PATH}.log

+ 3 - 3
cli/run_local_servers.sh

@@ -41,6 +41,7 @@ else
     conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
     pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
     pip install -i https://pypi.org/simple -r requirements.txt
+    pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
 fi
 
 
@@ -49,7 +50,7 @@ fi
 #######################
 
 hivemind-dht &> tmp.out &
-sleep 3
+sleep 5
 INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
 echo "Initial peer: ${INITIAL_PEER}"
 
@@ -96,10 +97,9 @@ do
     # Run server #
     ##############
 
-    tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
+    tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -m "bigscience/test-bloomd" -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
 done
 
-
 #####################
 # Kill initial peer #
 #####################

+ 7 - 1
cli/run_server.py

@@ -27,12 +27,14 @@ def main():
 
     parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
 
-    parser.add_argument('--num_handlers', type=int, default=16, required=False,
+    parser.add_argument('--num_handlers', type=int, default=8, required=False,
                         help='server will use this many processes to handle incoming requests')
     parser.add_argument('--min_batch_size', type=int, default=1,
                         help='Minimum required batch size for all expert operations')
     parser.add_argument('--max_batch_size', type=int, default=16384,
                         help='The total number of examples in the same batch will not exceed this value')
+    parser.add_argument('--cache_dir', type=str, default=None, 
+                        help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
     parser.add_argument('--cache_size_bytes', type=int, default=None,
                         help='The size of memory cache for storing past attention keys/values between inference steps')
     parser.add_argument('--device', type=str, default=None, required=False,
@@ -40,6 +42,9 @@ def main():
     parser.add_argument("--torch_dtype", type=str, default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
+    parser.add_argument('--revision', type=str, default='main',
+                        help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
+                             "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
 
     parser.add_argument('--throughput',
                         type=lambda value: value if value in ['auto', 'eval'] else float(value),
@@ -64,6 +69,7 @@ def main():
                         help='Path of a file with custom nn.modules, wrapped into special decorator')
     parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
     parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
+    parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into mixed-8bit quantized model.')
 
     # fmt:on
     args = vars(parser.parse_args())

+ 2 - 3
requirements.txt

@@ -1,6 +1,5 @@
 torch==1.12.0
 accelerate==0.10.0
 huggingface-hub==0.7.0
-bitsandbytes-cuda113==0.26.0
-https://github.com/learning-at-home/hivemind/archive/28261470e44f2ae4157d08b563b4d2771f3a9549.zip
-https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
+https://github.com/learning-at-home/hivemind/archive/20b3b3d5f225ed525515a5383a008a8f9fad8173.zip
+https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip

+ 9 - 3
src/bloom/from_pretrained.py

@@ -34,12 +34,15 @@ def load_pretrained_block(
     config: Optional[BloomConfig] = None,
     torch_dtype: Union[torch.dtype, str] = "auto",
     use_auth_token: Optional[str] = None,
+    cache_dir: Optional[str] = None,
 ) -> BloomBlock:
     """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
     if config is None:
         config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
     block = BloomBlock(config, layer_number=block_index)
-    state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)
+    state_dict = _load_state_dict(
+        converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
+    )
     block.load_state_dict(state_dict)
 
     if torch_dtype == "auto":
@@ -57,7 +60,10 @@ def load_pretrained_block(
 
 
 def _load_state_dict(
-    pretrained_model_name_or_path: str, block_index: Optional[int] = None, use_auth_token: Optional[str] = None
+    pretrained_model_name_or_path: str,
+    block_index: Optional[int] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: Optional[str] = None,
 ) -> OrderedDict[str, torch.Tensor]:
     revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
     archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
@@ -65,7 +71,7 @@ def _load_state_dict(
     # Load from URL or cache if already cached
     resolved_archive_file = cached_path(
         archive_file,
-        cache_dir=None,
+        cache_dir=cache_dir,
         force_download=FORCE_DOWNLOAD,
         proxies=None,
         resume_download=RESUME_DOWNLOAD,

+ 3 - 4
src/bloom/model.py

@@ -156,9 +156,7 @@ class BloomModel(BloomPreTrainedModel):
         self.n_head = config.n_head
 
         # Embedding + LN Embedding
-
-        # TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!)
-        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)  # dtype=config.torch_dtype
+        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
         self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
         # Transformer blocks
@@ -229,7 +227,8 @@ class BloomModel(BloomPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+        # Note: it supports only float32 or bfloat16 inputs
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
 
         output_shape = input_shape + (hidden_states.size(-1),)
 

+ 1 - 1
src/client/inference_session.py

@@ -70,7 +70,7 @@ class RemoteTransformerBlockInferenceSession:
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
-                        serialize_torch_tensor(tensor, proto.compression)
+                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                         for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
                     ],
                 )

+ 2 - 1
src/client/remote_model.py

@@ -90,7 +90,8 @@ class DistributedBloomModel(BloomModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+        # Note: it supports only float32 or bfloat16 inputs
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
         hidden_states = self.h(hidden_states)
 

+ 5 - 2
src/server/handler.py

@@ -48,6 +48,9 @@ class TransformerConnectionHandler(ConnectionHandler):
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
 
+                    # Cast inputs to backend dtype
+                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+
                     # run request tensors through all requested modules, update caches
                     for backend, cache_handle in zip(requested_backends, cache_handles):
                         cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
@@ -62,7 +65,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                     # serialize and send last layer outputs
                     yield runtime_pb2.ExpertResponse(
                         tensors=[
-                            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                             for result, proto in zip(
                                 hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
                             )
@@ -242,7 +245,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 head_dim = backend.module.self_attention.head_dim
 
                 cache_descriptor = TensorDescriptor(
-                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32
+                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
                 )
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
 

+ 17 - 4
src/server/server.py

@@ -22,6 +22,7 @@ from src.server.block_selection import choose_best_blocks
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
 from src.server.throughput import get_host_throughput
+from src.utils.convert_8bit import replace_8bit_linear
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -35,7 +36,6 @@ class Server(threading.Thread):
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
         *,
-        device: torch.device,
         num_connection_handlers: int = 8,
         throughput: float,
         update_period: float = 30,
@@ -49,7 +49,7 @@ class Server(threading.Thread):
         self.conn_handlers = [
             TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
         ]
-        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
+        self.runtime = Runtime(self.module_backends, **kwargs)
         self.dht_handler_thread = ModuleAnnouncerThread(
             self.module_backends,
             dht,
@@ -101,10 +101,12 @@ class Server(threading.Thread):
         throughput: Union[float, str],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
-        num_handlers: Optional[int] = None,
+        num_handlers: int = 8,
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
         torch_dtype: str = "auto",
+        revision: str = "main",
+        cache_dir: Optional[str] = None,
         cache_size_bytes: Optional[int] = None,
         device: Optional[Union[str, torch.device]] = None,
         initial_peers: Sequence[str] = (),
@@ -115,6 +117,7 @@ class Server(threading.Thread):
         expiration: Optional[float] = None,
         max_block_selection_delay: float = 1,
         use_auth_token: Optional[str] = None,
+        load_in_8bit: bool = False,
         *,
         start: bool,
         **kwargs,
@@ -148,7 +151,9 @@ class Server(threading.Thread):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
 
-        block_config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
+        block_config = BloomConfig.from_pretrained(
+            converted_model_name_or_path, use_auth_token=use_auth_token, revision=revision
+        )
 
         if block_indices is not None:
             try:
@@ -186,7 +191,15 @@ class Server(threading.Thread):
                 block_config,
                 torch_dtype=torch_dtype,
                 use_auth_token=use_auth_token,
+                cache_dir=cache_dir,
             )
+
+            if load_in_8bit:
+                dtype = block.input_layernorm.weight.dtype
+                assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
+                block = replace_8bit_linear(block)
+
+            block = block.to(device)
             for param in block.parameters():
                 param.requires_grad = False
 

+ 34 - 0
src/utils/convert_8bit.py

@@ -0,0 +1,34 @@
+import bitsandbytes as bnb
+import torch
+
+
+def replace_8bit_linear(model, threshold=6.0):
+    """
+    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
+    library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
+    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
+    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
+    bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
+    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
+    be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
+    CPU/GPU memory is required to run this function.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        threshold (`float`, *optional*):
+            `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
+            `6.0` as described by the paper.
+    """
+    for n, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_8bit_linear(module, threshold)
+
+        if isinstance(module, torch.nn.Linear) and n != "lm_head":
+            model._modules[n] = bnb.nn.Linear8bitLt(
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+                has_fp16_weights=False,
+                threshold=threshold,
+            ).to(model._modules[n].weight.device)
+    return model