
integrate mixed-8bit model

dbaranchuk committed 3 years ago
commit a621549c06
5 files changed, 65 insertions(+), 8 deletions(-)
  1. cli/deploy_server.sh (+15, -4)
  2. cli/run_local_servers.sh (+2, -3)
  3. cli/run_server.py (+3, -0)
  4. src/server/server.py (+11, -1)
  5. src/utils/convert_8bit.py (+34, -0)

cli/deploy_server.sh (+15, -4)

@@ -5,7 +5,8 @@
 #################
 
 instructions() {
-  echo "Usage: $0 [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
+  echo "Usage: $0 [-m] [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
+  echo " -m: model name"
   echo " -i: initial peer"
   echo " -d: device" >&2
   echo " -p: server identity path" >&2
@@ -19,8 +20,10 @@ if [ ! $# -ge 8 ]; then
     instructions
 fi
 
-while getopts ":i:d:p:b:a:t:" option; do
+while getopts ":m:i:d:p:b:a:t:" option; do
     case $option in
+        m)  MODEL_NAME=${OPTARG}
+            ;;
         i)  INITIAL_PEER=${OPTARG}
             ;;
         d)  DEVICE=${OPTARG}
@@ -42,6 +45,7 @@ done
 echo "=========="
 echo "= Config ="
 echo "=========="
+echo "Model name: ${MODEL_NAME}"
 echo "Initial peer: ${INITIAL_PEER}"
 echo "Device: ${DEVICE}"
 echo "Server name: ${SERVER_ID_PATH}"
@@ -70,5 +74,12 @@ fi
 # Run server #
 ##############
 
-python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
-  --block_indices ${BLOCK_IDS} --torch_dtype float32 --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} &> ${SERVER_ID_PATH}.log
+# [('NONE', 0),
+#  ('MEANSTD_16BIT', 1),
+#  ('FLOAT16', 2),
+#  ('QUANTILE_8BIT', 3),
+#  ('UNIFORM_8BIT', 4),
+#  ('BLOCKWISE_8BIT', 5)]
+
+python -m cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
+  --block_indices ${BLOCK_IDS} --compression UNIFORM_8BIT --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} --load_in_8bit # &> ${SERVER_ID_PATH}.log

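The compression options listed in the comment above map to hivemind's `CompressionType` enum. A minimal sketch to print that mapping on your own install (assumes the `hivemind` package is available and exposes the enum at this protobuf path):

    from hivemind.proto.runtime_pb2 import CompressionType

    # Protobuf enums expose (name, value) pairs; this should print the pairs
    # listed in the comment above, e.g. ('UNIFORM_8BIT', 4).
    print(sorted(CompressionType.items(), key=lambda kv: kv[1]))
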
cli/run_local_servers.sh (+2, -3)

@@ -49,7 +49,7 @@ fi
 #######################
 
 hivemind-dht &> tmp.out &
-sleep 3
+sleep 5
 INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
 echo "Initial peer: ${INITIAL_PEER}"
 
@@ -96,10 +96,9 @@ do
     # Run server #
     ##############
 
-    tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
+    tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -m "bigscience/test-bloomd" -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
 done
 
-
 #####################
 # Kill initial peer #
 #####################

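The `INITIAL_PEER` extraction above relies on an embedded Python one-liner. Expanded for readability, it does the following (same logic, same assumption about the format of `hivemind-dht`'s log):

    # The second line of hivemind-dht's log ends with the DHT node's
    # multiaddress; that token becomes INITIAL_PEER.
    with open("tmp.out") as f:
        initial_peer = f.readlines()[1].split()[-1]
    print(initial_peer)
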
cli/run_server.py (+3, -0)

@@ -33,6 +33,8 @@ def main():
                         help='Minimum required batch size for all expert operations')
     parser.add_argument('--max_batch_size', type=int, default=16384,
                         help='The total number of examples in the same batch will not exceed this value')
+    parser.add_argument('--cache_dir', type=str, default=None, 
+                        help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
     parser.add_argument('--cache_size_bytes', type=int, default=None,
                         help='The size of memory cache for storing past attention keys/values between inference steps')
     parser.add_argument('--device', type=str, default=None, required=False,
@@ -64,6 +66,7 @@ def main():
                         help='Path of a file with custom nn.modules, wrapped into special decorator')
     parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
     parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
+    parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into mixed-8bit quantized model.')
 
     # fmt:on
     args = vars(parser.parse_args())

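A quick sketch of how the two new options behave once parsed (mirrors the `add_argument` calls above; the command line below is hypothetical):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--load_in_8bit", action="store_true")

    # --load_in_8bit is a boolean switch; --cache_dir stays None unless given.
    args = parser.parse_args(["--load_in_8bit"])
    assert args.load_in_8bit and args.cache_dir is None
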
src/server/server.py (+11, -1)

@@ -22,6 +22,7 @@ from src.server.block_selection import choose_best_blocks
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
 from src.server.throughput import get_host_throughput
+from src.utils.convert_8bit import replace_8bit_linear
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -105,6 +106,7 @@ class Server(threading.Thread):
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
         torch_dtype: str = "auto",
+        cache_dir: Optional[str] = None,
         cache_size_bytes: Optional[int] = None,
         device: Optional[Union[str, torch.device]] = None,
         initial_peers: Sequence[str] = (),
@@ -115,6 +117,7 @@ class Server(threading.Thread):
         expiration: Optional[float] = None,
         max_block_selection_delay: float = 1,
         use_auth_token: Optional[str] = None,
+        load_in_8bit: bool = False,
         *,
         start: bool,
         **kwargs,
@@ -148,7 +151,9 @@ class Server(threading.Thread):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
 
-        block_config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
+        block_config = BloomConfig.from_pretrained(
+            converted_model_name_or_path, use_auth_token=use_auth_token, revision="client"
+        )
 
         if block_indices is not None:
             try:
@@ -186,7 +191,12 @@ class Server(threading.Thread):
                 block_config,
                 torch_dtype=torch_dtype,
                 use_auth_token=use_auth_token,
+                cache_dir=cache_dir,
             )
+
+            if load_in_8bit:
+                block = replace_8bit_linear(block)
+
             for param in block.parameters():
                 param.requires_grad = False
 

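The new load path in `Server`, reduced to its essentials: load a block, optionally convert its linear layers to 8-bit, then freeze it for inference. A minimal sketch with a stand-in module in place of a real Bloom block (requires `bitsandbytes`, which `replace_8bit_linear` imports):

    import torch
    from src.utils.convert_8bit import replace_8bit_linear

    load_in_8bit = True
    block = torch.nn.Sequential(torch.nn.Linear(32, 32))  # stand-in for a Bloom block

    if load_in_8bit:
        block = replace_8bit_linear(block)

    # Hosted blocks are inference-only, so gradients are never needed.
    for param in block.parameters():
        param.requires_grad = False
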
src/utils/convert_8bit.py (+34, -0)

@@ -0,0 +1,34 @@
+import torch
+import bitsandbytes as bnb
+
+
+def replace_8bit_linear(model, threshold=6.0):
+    """
+    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
+    library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
+    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
+    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
+    bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
+    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
+    be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
+    CPU/GPU memory is required to run this function.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        threshold (`float`, *optional*):
+            `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
+            `6.0` as described by the paper.
+    """
+    for n, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_8bit_linear(module, threshold)
+
+        if isinstance(module, torch.nn.Linear) and n != "lm_head":
+            model._modules[n] = bnb.nn.Linear8bitLt(
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+                has_fp16_weights=False,
+                threshold=threshold,
+            ).to(model._modules[n].weight.device)
+    return model