
fetch a specific bloom block without downloading the entire model

justheuristic, 3 years ago
commit 1ab5fb1630
7 changed files with 44 additions and 29 deletions
  1. README.md (+6 -6)
  2. cli/run_server.py (+6 -2)
  3. src/bloom/from_pretrained.py (+18 -9)
  4. src/bloom/model.py (+1 -1)
  5. src/server/backend.py (+2 -2)
  6. src/server/cache.py (+1 -1)
  7. src/server/server.py (+10 -8)

README.md (+6 -6)

@@ -37,8 +37,8 @@ python -m cli.inference_one_block --config cli/config.json  # see other args
 First, run one or more servers like this:
 ```bash
 # minimalistic server with non-trained bloom blocks
-python -m cli.run_server --prefix bloom6b3 --block_config bigscience/bloom-6b3 --num_blocks 2 \
-  --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
+python -m cli.run_server --prefix bloom6b3 --converted_model_name_or_path bigscience/test-bloomd-6b3 \
+  --block_indices 3:5 --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
 # when running multiple servers:
 # - give each server a unique --identity_path (or remove --identity_path arg when debugging)
 # - if running multiple servers on the same machine, give each a unique port (last integer in --host_maddrs, 0 means random port)
@@ -57,15 +57,15 @@ dht = hivemind.DHT(
     client_mode=True, start=True,
 )
 
-layer0, layer1 = get_remote_module(dht, ['bloom6b3.0', 'bloom6b3.1'])
-
+layer3, layer4 = get_remote_module(dht, ['bloom6b3.3', 'bloom6b3.4'])
+assert layer3 is not None and layer4 is not None, "one or both layers were not found in DHT"
 # test forward/backward, two blocks
-outputs, = layer1(*layer0(torch.randn(1, 64, 4096)))
+outputs, = layer4(*layer3(torch.randn(1, 64, 4096)))
 loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()
 
 # test inference, one block
-with layer0.begin_inference_session() as sess:
+with layer3.begin_inference_session() as sess:
     for i in range(10):
         res = sess.step(torch.ones(1, 1, 4096))
 ```
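
Note (an inference from this diff, not text from the commit): with `--prefix bloom6b3 --block_indices 3:5`, the server announces module UIDs `bloom6b3.3` and `bloom6b3.4`, which is why the client snippet above now asks for exactly those two layers. A minimal sketch of that mapping:

```python
# Sketch: how served module UIDs follow from --prefix and --block_indices "start:end"
prefix, block_indices = "bloom6b3", "3:5"
start, end = map(int, block_indices.split(":"))             # same parsing as in src/server/server.py
module_uids = [f"{prefix}.{i}" for i in range(start, end)]  # uid pattern from src/server/server.py
assert module_uids == ["bloom6b3.3", "bloom6b3.4"]
```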

cli/run_server.py (+6 -2)

@@ -6,7 +6,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from src.server.server import Server
 
 use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__name__)
+logger = get_logger(__file__)
 
 
 def main():
@@ -15,7 +15,8 @@ def main():
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
 
     parser.add_argument('--prefix', type=str, required=True, help="Announce all blocks with this prefix")
-    parser.add_argument('--block_config', type=str, default='bigscience/bloom-6b3', help="name or path of model config")
+    parser.add_argument('--converted_model_name_or_path', type=str, default='bigscience/test-bloomd-6b3',
+                        help="path or name of a pretrained model, converted with cli/convert_model.py (see README.md)")
     parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
     parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
     parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
@@ -35,6 +36,9 @@ def main():
                         help='The size of memory cache for storing past attention keys/values between inference steps')
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all experts will use this device in torch notation; default: cuda if available else cpu')
+    parser.add_argument("--torch_dtype", type=str, default="auto",
+                        help="Use this dtype to store block weights and do computations. "
+                             "By default, respect the dtypes in the pre-trained state dict.")
 
     parser.add_argument('--update_period', type=float, required=False, default=30,
                         help='Server will report experts to DHT once in this many seconds')

src/bloom/from_pretrained.py (+18 -9)

@@ -1,12 +1,14 @@
 """
-Utils for fetching pre-trained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
+Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
 If necessary, one can rewrite this to implement a different behavior, such as:
  - loading files from a local data source (e.g. S3)
  - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
  - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
 
 """
-from typing import Optional, OrderedDict
+from __future__ import annotations
+
+from typing import Optional, OrderedDict, Union
 
 import torch
 from hivemind.utils.logging import use_hivemind_log_handler, get_logger
@@ -28,16 +30,24 @@ LOCAL_FILES_ONLY = False
 
 
 def load_pretrained_block(
-        converted_model_name_or_path: str, block_index: int, config: Optional[DistributedBloomConfig] = None):
+        converted_model_name_or_path: str, block_index: int,
+        config: Optional[DistributedBloomConfig] = None, torch_dtype: Union[torch.dtype, str] = 'auto') -> BloomBlock:
     """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
     if config is None:
         config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path)
     block = BloomBlock(config, layer_number=block_index)
     state_dict = _load_state_dict(converted_model_name_or_path, block_index)
-    with torch.no_grad():
-        for name, param in block.named_parameters():
-            assert name in state_dict, f"{name} not in state dict"
-            param.data = param.data.to(state_dict[name].dtype)
+    block.load_state_dict(state_dict)
+
+    if torch_dtype == 'auto':
+        with torch.no_grad():
+            for name, param in block.named_parameters():
+                assert name in state_dict, f"{name} not in state dict"
+                param.data = param.data.to(state_dict[name].dtype)
+    else:
+        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        block = block.to(dtype=torch_dtype)
+
     report = block.load_state_dict(state_dict, strict=True)
     logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
     return block
@@ -63,5 +73,4 @@ def _load_state_dict(
     return state_dict
 
 
-
-
+DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
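
For context, a minimal sketch (not part of the commit) of how the new `load_pretrained_block` signature is meant to be used, assuming the converted `bigscience/test-bloomd-6b3` repo from the README is reachable:

```python
import torch
from src.bloom.from_pretrained import load_pretrained_block, DTYPE_MAP

# Fetch only block 3 of the converted checkpoint; no other shards are downloaded.
# torch_dtype="auto" keeps the dtypes stored in the state dict, while an explicit
# dtype (any value of DTYPE_MAP) casts the whole block after loading.
block = load_pretrained_block(
    "bigscience/test-bloomd-6b3",
    block_index=3,
    torch_dtype=torch.float32,  # same as DTYPE_MAP["float32"]
)
print(sum(p.numel() for p in block.parameters()))  # parameters of this one block only
```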

src/bloom/model.py (+1 -1)

@@ -28,7 +28,7 @@ use_hivemind_log_handler("in_root_logger")
 logger = logging.get_logger(__file__)
 
 _CHECKPOINT_FOR_DOC = "bigscience/Bloom"
-_CONFIG_FOR_DOC = "MemoryEfficientBloomConfig"
+_CONFIG_FOR_DOC = "DistributedBloomConfig"
 _TOKENIZER_FOR_DOC = "BloomTokenizer"
 
 

src/server/backend.py (+2 -2)

@@ -5,7 +5,7 @@ import torch
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 
-from src.bloom.block import BloomBlock
+from src.bloom.from_pretrained import BloomBlock
 from src.server.cache import MemoryCache
 
 MAX_LENGTH = 2048
@@ -38,7 +38,7 @@ class TransformerBackend(ModuleBackend):
             print(past_k.shape, past_v.shape)
             hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
 
-            # todo remove these debugprints
+            # todo remove these asserts once we pass all tests
             new_length = new_v.shape[1]
             assert new_length > prefix_length
             assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]

src/server/cache.py (+1 -1)

@@ -16,7 +16,7 @@ from hivemind import use_hivemind_log_handler
 from hivemind.utils import TensorDescriptor, get_logger
 
 use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__name__)
+logger = get_logger(__file__)
 
 Handle = int
 

src/server/server.py (+10 -8)

@@ -12,12 +12,12 @@ from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src import BloomForCausalLM, DistributedBloomConfig
-from src.bloom.block import BloomBlock
+from src.bloom.from_pretrained import load_pretrained_block, DistributedBloomConfig, DTYPE_MAP
 from src.server.backend import TransformerBackend
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
 
+
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
@@ -83,12 +83,13 @@ class Server(threading.Thread):
     def create(
         cls,
         prefix: str,
-        block_config: str,
+        converted_model_name_or_path: str,
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: Optional[int] = None,
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
+        torch_dtype: str = 'auto',
         cache_size_bytes: Optional[int] = None,
         device: Union[str, torch.device] = None,
         initial_peers: Sequence[str] = (),
@@ -112,6 +113,10 @@ class Server(threading.Thread):
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         memory_cache = MemoryCache(device, cache_size_bytes)
 
+        if isinstance(torch_dtype, str):
+            torch_dtype = DTYPE_MAP[torch_dtype]
+        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+
         if block_indices is not None:
             try:
                 start, end = block_indices.split(":")
@@ -124,16 +129,13 @@ class Server(threading.Thread):
             assert num_blocks is not None
             block_indices = range(num_blocks)  # TODO replace with proper load balancing
 
-        ## TODO: the code below will load the entire model in RAM. Please replace with sliced model
-        block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
-        # model = BloomForCausalLM.from_pretrained(model, use_auth_token=True)
-        ## /TODO
+        block_config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=True)
 
         # initialize modules
         blocks = {}
         for block_index in block_indices:
             module_uid = f"{prefix}.{block_index}"
-            block = BloomBlock(block_config, layer_number=block_index)
+            block = load_pretrained_block(converted_model_name_or_path, block_index, block_config, torch_dtype)
             for param in block.parameters():
                 param.requires_grad = False
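
Taken together, a rough sketch (not from the diff; unlisted arguments are assumed to keep their defaults) of how the renamed and added `Server.create` arguments line up with the README command:

```python
from src.server.server import Server

# Mirrors `python -m cli.run_server --prefix bloom6b3 ...` from the README above;
# networking arguments (identity path, host multiaddrs, initial peers) are omitted here.
server = Server.create(
    prefix="bloom6b3",
    converted_model_name_or_path="bigscience/test-bloomd-6b3",
    block_indices="3:5",     # parsed into range(3, 5), i.e. blocks 3 and 4
    torch_dtype="float32",   # looked up in DTYPE_MAP before any block is loaded
)
```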