
WIP try loading OPT-sized layers using BLOOM code

Max Ryabinin, 2 years ago
parent · commit 23e3f9a332
2 changed files with 21 additions and 26 deletions
  1. src/petals/bloom/from_pretrained.py (+0 / -20)
  2. src/petals/server/server.py (+21 / -6)

+ 0 - 20
src/petals/bloom/from_pretrained.py

@@ -45,26 +45,6 @@ def load_pretrained_block(
         cache_dir = DEFAULT_CACHE_DIR
 
     block = WrappedBloomBlock(config)
-    state_dict = _load_state_dict(
-        converted_model_name_or_path,
-        block_index,
-        config,
-        use_auth_token=use_auth_token,
-        cache_dir=cache_dir,
-        max_disk_space=max_disk_space,
-    )
-
-    if torch_dtype == "auto":
-        with torch.no_grad():
-            for name, param in block.named_parameters():
-                assert name in state_dict, f"{name} not in state dict"
-                param.data = param.data.to(state_dict[name].dtype)
-    else:
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        block = block.to(dtype=torch_dtype)
-
-    report = block.load_state_dict(state_dict, strict=True)
-    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
     return block
 
 

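For context, with the state-dict loading and dtype handling deleted above, `load_pretrained_block` now just builds a block from the config and returns it with freshly initialized weights; no BLOOM checkpoint is read. A rough sketch of what the function reduces to after this change — the signature is assumed from the surrounding file, and `WrappedBloomBlock` / `DEFAULT_CACHE_DIR` are defined or imported elsewhere in `from_pretrained.py`:

```python
def load_pretrained_block(
    converted_model_name_or_path: str,
    block_index: int,
    config: Optional[BloomConfig] = None,
    torch_dtype: Union[torch.dtype, str] = "auto",
    use_auth_token: Optional[str] = None,
    cache_dir: Optional[str] = None,
    max_disk_space: Optional[int] = None,
) -> WrappedBloomBlock:
    # Signature assumed from the file; only the body below is visible in the diff.
    if config is None:
        config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
    if cache_dir is None:
        cache_dir = DEFAULT_CACHE_DIR

    # With the lines removed in this commit, no checkpoint weights are loaded:
    # the returned block keeps its randomly initialized parameters, and
    # torch_dtype / max_disk_space are currently unused.
    block = WrappedBloomBlock(config)
    return block
```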
+ 21 - 6
src/petals/server/server.py

@@ -15,7 +15,7 @@ from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
-from transformers import BloomConfig
+from transformers import BloomConfig, OPTConfig
 
 from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from petals.constants import PUBLIC_INITIAL_PEERS
@@ -112,11 +112,26 @@ class Server:
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
 
-        self.block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path,
-            use_auth_token=use_auth_token,
-            revision=revision,
-        )
+        if "bloom" in converted_model_name_or_path:
+            self.block_config = BloomConfig.from_pretrained(
+                converted_model_name_or_path,
+                use_auth_token=use_auth_token,
+                revision=revision,
+            )
+        else:
+            def _patch_bloom_config(bloom_config: BloomConfig, opt_config: OPTConfig):
+                bloom_config.hidden_size = opt_config.hidden_size
+                bloom_config.n_head = opt_config.num_attention_heads
+                bloom_config.n_layer = opt_config.num_hidden_layers
+                bloom_config.vocab_size = opt_config.vocab_size
+
+            opt_config = OPTConfig.from_pretrained(converted_model_name_or_path)
+            bloom_config = BloomConfig.from_pretrained(
+                "bigscience/bloom-petals"
+            )
+            _patch_bloom_config(bloom_config, opt_config)
+            self.block_config = bloom_config
+
         self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
 
         if dht_client_mode is None:
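The config patch above can be exercised on its own. Below is a minimal, hedged sketch of the same idea: copy an OPT checkpoint's dimensions onto a BLOOM config so that BLOOM-style blocks can be instantiated at OPT size. The helper name `make_opt_sized_bloom_config` and the example checkpoint `facebook/opt-6.7b` are illustrative only; the field mapping mirrors `_patch_bloom_config` from the diff.

```python
from transformers import BloomConfig, OPTConfig


def make_opt_sized_bloom_config(opt_model_name: str) -> BloomConfig:
    """Return a BloomConfig whose dimensions match the given OPT checkpoint."""
    opt_config = OPTConfig.from_pretrained(opt_model_name)
    # Donor config taken from the diff; any BLOOM config would work as a base.
    bloom_config = BloomConfig.from_pretrained("bigscience/bloom-petals")

    # Copy only the sizes; blocks built from this config still use
    # BLOOM's attention/MLP code, just with OPT dimensions.
    bloom_config.hidden_size = opt_config.hidden_size
    bloom_config.n_head = opt_config.num_attention_heads
    bloom_config.n_layer = opt_config.num_hidden_layers
    bloom_config.vocab_size = opt_config.vocab_size

    # BLOOM attention splits hidden_size evenly across heads.
    assert bloom_config.hidden_size % bloom_config.n_head == 0
    return bloom_config


if __name__ == "__main__":
    config = make_opt_sized_bloom_config("facebook/opt-6.7b")
    print(config.hidden_size, config.n_head, config.n_layer)
```

Note that only the sizes are copied: the resulting blocks keep BLOOM's ALiBi attention and MLP, so actually loading OPT weights into them would still require a separate parameter mapping, which this WIP commit does not yet include.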