@@ -15,7 +15,7 @@ from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
-from transformers import BloomConfig
+from transformers import BloomConfig, OPTConfig
 
 from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from petals.constants import PUBLIC_INITIAL_PEERS
@@ -112,11 +112,26 @@ class Server:
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
 
-        self.block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path,
-            use_auth_token=use_auth_token,
-            revision=revision,
-        )
+        if "bloom" in converted_model_name_or_path:
+            self.block_config = BloomConfig.from_pretrained(
+                converted_model_name_or_path,
+                use_auth_token=use_auth_token,
+                revision=revision,
+            )
+        else:
+            # For OPT checkpoints, reuse the Petals BLOOM config as a template
+            # and override the fields where the two architectures differ.
+            def _patch_bloom_config(bloom_config: BloomConfig, opt_config: OPTConfig):
+                bloom_config.hidden_size = opt_config.hidden_size
+                bloom_config.n_head = opt_config.num_attention_heads
+                bloom_config.n_layer = opt_config.num_hidden_layers
+                bloom_config.vocab_size = opt_config.vocab_size
+
+            opt_config = OPTConfig.from_pretrained(converted_model_name_or_path)
+            bloom_config = BloomConfig.from_pretrained("bigscience/bloom-petals")
+            _patch_bloom_config(bloom_config, opt_config)
+            self.block_config = bloom_config
+
         self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
 
         if dht_client_mode is None:
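
As a sanity check of the mapping above, here is a minimal standalone sketch (not part of this patch) that builds an OPT-shaped BloomConfig the same way the server does. It assumes `transformers` is installed; "facebook/opt-6.7b" is only an example checkpoint name, not one this patch depends on.

# sketch.py -- standalone illustration of the config patching in this diff
from transformers import BloomConfig, OPTConfig

opt_config = OPTConfig.from_pretrained("facebook/opt-6.7b")
bloom_config = BloomConfig.from_pretrained("bigscience/bloom-petals")

# Same field mapping as _patch_bloom_config above: copy the dimensions
# that differ between the two architectures onto the BLOOM template.
bloom_config.hidden_size = opt_config.hidden_size
bloom_config.n_head = opt_config.num_attention_heads
bloom_config.n_layer = opt_config.num_hidden_layers
bloom_config.vocab_size = opt_config.vocab_size

# The server derives one module UID per transformer block from n_layer,
# so this is the value that must match the OPT checkpoint.
print(bloom_config.n_layer)  # expected to print 32 for opt-6.7b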