justheuristic 3 years ago
Parent
Commit
19ae71e8fc
2 changed files with 29 additions and 3 deletions
  1. src/client/remote_model.py  +28 -2
  2. src/client/remote_sequential.py  +1 -1

+ 28 - 2
src/client/remote_model.py

@@ -1,9 +1,12 @@
 # this code is in active development, interfaces may change
-from typing import Optional
+import os
+from typing import Optional, Union
 
+import hivemind
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 
 from src.bloom import BloomForCausalLM, DistributedBloomConfig
+from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
 
@@ -14,9 +17,32 @@ logger = get_logger(__file__)
 class DistributedBloomForCausalLM(BloomForCausalLM):
     """BloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
-    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: Optional[str] = None):
+    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
         n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
         super().__init__(config)
         assert len(self.transformer.h) == 0
         config.n_layer = n_layer
         self.transformer.h = RemoteSequential(config, dht, prefix)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
+        assert 'initial_peers' in kwargs
+        dht = hivemind.DHT(
+            initial_peers=kwargs.pop('initial_peers'), client_mode=kwargs.pop('client_mode', True),
+            start=True)
+
+        if 'prefix' not in kwargs:
+            logger.warning(f"No prefix specified; setting prefix to {pretrained_model_name_or_path}")
+            assert UID_DELIMITER not in pretrained_model_name_or_path, \
+                f"Cannot infer prefix automatically from {pretrained_model_name_or_path}; please specify prefix=..."
+        prefix = kwargs.pop("prefix", pretrained_model_name_or_path)
+
+        config = DistributedBloomConfig.from_pretrained(pretrained_model_name_or_path, revision=CLIENT_BRANCH, **kwargs)
+        model = cls(config, dht, prefix)
+        model.load_state_dict(_load_state_dict(
+            pretrained_model_name_or_path, use_auth_token=kwargs.get('use_auth_token')
+        ), strict=True)
+        return model
+
+
+

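For orientation, a minimal usage sketch of the new from_pretrained entry point added above. The checkpoint name and peer multiaddress are placeholders rather than values from this commit, and the call assumes a running swarm to connect to.

    from src.client.remote_model import DistributedBloomForCausalLM

    # Placeholder checkpoint name and bootstrap peer; substitute real values for an actual swarm.
    model = DistributedBloomForCausalLM.from_pretrained(
        "bigscience/test-bloomd-6b3",
        initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/PLACEHOLDER_PEER_ID"],  # required: DHT bootstrap addresses
        prefix="bigscience/test-bloomd-6b3",  # optional: inferred from the checkpoint name if it contains no UID_DELIMITER
    )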
+ 1 - 1
src/client/remote_sequential.py

@@ -20,7 +20,7 @@ class RemoteSequential(nn.Sequential):
     A sequence of transformer blocks hosted by the swarm.
     """
 
-    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: Optional[str] = None, max_retries: int = 3):
+    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
         logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
         if prefix.endswith(UID_DELIMITER):
             logger.warning(