
black-isort

justheuristic committed 3 years ago
commit 88c1bf9896
4 changed files with 16 additions and 16 deletions
  1. src/bloom/__init__.py (+1, -1)
  2. src/bloom/model.py (+1, -7)
  3. src/client/__init__.py (+1, -1)
  4. src/client/remote_model.py (+13, -7)

src/bloom/__init__.py (+1, -1)

@@ -1,2 +1,2 @@
 from src.bloom.block import BloomBlock
-from src.bloom.model import BloomConfig, BloomModel, BloomPreTrainedModel, BloomForCausalLM
+from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel

src/bloom/model.py (+1, -7)

@@ -359,13 +359,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         output_type=CausalLMOutputWithCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    def forward(
-        self,
-        input_ids=None,
-        labels=None,
-        return_dict=None,
-        **kwargs
-    ):
+    def forward(self, input_ids=None, labels=None, return_dict=None, **kwargs):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set

src/client/__init__.py (+1, -1)

@@ -1,4 +1,4 @@
 from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
-from src.client.remote_model import DistributedBloomConfig, DistributedBloomModel, DistributedBloomForCausalLM
+from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from src.client.remote_sequence_info import RemoteSequenceInfo
 from src.client.remote_sequential import RemoteSequential

src/client/remote_model.py (+13, -7)

@@ -1,13 +1,12 @@
 # this code is in active development, interfaces may change
 import os
-from typing import Optional, Union, Tuple
+from typing import Optional, Tuple, Union
 
 import hivemind
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 
-from src.bloom.model import BloomModel, BloomForCausalLM, BloomConfig
 from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
-from src.bloom.model import BloomPreTrainedModel
+from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
 
@@ -20,6 +19,7 @@ class DistributedBloomConfig(BloomConfig):
     A bloom config that contains information about DHT peers.
     To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
     """
+
     initial_peers: Tuple[str, ...] = ()  # a list of initial peers for hivemind DHT
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
@@ -27,26 +27,32 @@ class DistributedBloomConfig(BloomConfig):
 
 class DistributedBloomModel(BloomModel):
     """BloomModel, but all transformer layers are hosted by the swarm"""
+
     def __init__(self, config: DistributedBloomConfig):
         assert self.config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
-        assert self.config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
+        assert (
+            self.config.initial_peers or config.dht
+        ), "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
 
         n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
         super().__init__(config)
         assert len(self.h) == 0
         config.n_layer = n_layer
 
-        dht = config.dht if config.dht is not None else hivemind.DHT(
-            initial_peers=config.initial_peers, client_mode=True, start=True)
+        dht = (
+            config.dht
+            if config.dht is not None
+            else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
+        )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
         self.h = RemoteSequential(config, dht, config.dht_prefix)
 
 
 class DistributedBloomForCausalLM(BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
+
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel().__init__(config)
         self.transformer = DistributedBloomModel(config)
         # Initialize weights and apply final processing
         self.post_init()
-
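
The changes above are purely mechanical: isort reorders the names inside `from ... import ...` lists alphabetically, and black re-wraps long signatures and expressions to fit its configured line length (which is also what collapsed the multi-line forward(...) signature and produced the blank lines after the class docstrings). As a minimal sketch of reproducing the same kind of formatting programmatically, assuming isort >= 5 (isort.code) and black (black.format_str, black.Mode) are installed; the line_length=120 setting is an assumption, since the repository's actual formatter config is not shown in this diff:

    import black
    import isort

    # The long import line from src/bloom/__init__.py as it was before this commit.
    src = "from src.bloom.model import BloomConfig, BloomModel, BloomPreTrainedModel, BloomForCausalLM\n"

    # isort sorts the names in the `from ... import ...` list alphabetically,
    # which is the change shown in the first hunk above.
    # line_length=120 is an assumed setting, not taken from the repository config.
    sorted_src = isort.code(src, line_length=120)

    # black then normalizes spacing and wraps or joins lines to fit its line length.
    formatted = black.format_str(sorted_src, mode=black.Mode(line_length=120))
    print(formatted, end="")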