
re-apply black, isort

justheuristic, 3 years ago
parent
commit f0081edc30
2 files changed, 24 additions and 16 deletions

  1. src/bloom/model.py: +1 -1
  2. src/client/remote_model.py: +23 -15
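
For context, black is the Python code formatter and isort sorts import blocks; "re-apply" suggests the tree had drifted from their output after manual edits. A plausible way to reproduce a pass like this one, assuming both tools are run over src/ with the project's own settings (line length and isort profile are not visible in this commit), is:

    isort src/
    black src/

The parenthesized, trailing-comma import block introduced in src/client/remote_model.py below matches isort's black-compatible layout (e.g. isort --profile black, or an equivalent multi_line_output setting), so the project presumably carries such a setting in its packaging config.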

src/bloom/model.py: +1 -1

@@ -562,7 +562,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
                     f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                     f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                     "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                     "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                 )
                 )
-            
+
         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

         loss = None

src/client/remote_model.py: +23 -15

@@ -2,13 +2,19 @@
 import os
 from typing import Optional, Tuple

+import hivemind
 import torch
 import torch.nn as nn
-
-import hivemind
 from hivemind import get_logger, use_hivemind_log_handler

-from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHead, BloomForSequenceClassification
+from src.bloom.model import (
+    BloomConfig,
+    BloomForCausalLM,
+    BloomForSequenceClassification,
+    BloomModel,
+    BloomPreTrainedModel,
+    LMHead,
+)
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER

@@ -25,8 +31,8 @@ class DistributedBloomConfig(BloomConfig):
     initial_peers: Tuple[str, ...] = ()  # a list of initial peers for hivemind DHT
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
-    chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
-    num_prefix_tokens: int = 0 # a number of tokens for prompt tuning.
+    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
+    num_prefix_tokens: int = 0  # a number of tokens for prompt tuning.


 class DistributedBloomModel(BloomModel):
@@ -77,7 +83,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
         return prompts

     def forward(
-        self, 
+        self,
         input_ids: Optional[torch.LongTensor],
         inputs_embeds: Optional[torch.Tensor],
         attention_mask: Optional[torch.Tensor],
@@ -87,14 +93,16 @@ class DistributedBloomPrefix(DistributedBloomModel):
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
-        return_dict=None
+        return_dict=None,
     ):
-        assert input_ids is None or inputs_embeds is None, "You cannot specify both input_ids and inputs_embeds at the same time"
+        assert (
+            input_ids is None or inputs_embeds is None
+        ), "You cannot specify both input_ids and inputs_embeds at the same time"
         assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
-        
+
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
-    
+
         batch_size = inputs_embeds.shape[0]

         if attention_mask is not None:
@@ -105,20 +113,20 @@ class DistributedBloomPrefix(DistributedBloomModel):
         inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

         transformer_outputs = super().forward(
-            inputs_embeds=inputs_embeds, 
-            attention_mask=attention_mask, 
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
             past_key_values=past_key_values,
             position_ids=position_ids,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict
+            return_dict=return_dict,
         )

         # Remove prefix
-        last_hidden_state = transformer_outputs[0][:, self.prefix_length:]
-        transformer_outputs['last_hidden_state'] = last_hidden_state
+        last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
+        transformer_outputs["last_hidden_state"] = last_hidden_state
         return transformer_outputs