Bladeren bron

after style fixes

artek0chumak 3 jaren geleden
bovenliggende
commit
285080f3d8
3 gewijzigde bestanden met 8 toevoegingen en 10 verwijderingen
  1. 4 4
      src/client/remote_generation.py
  2. 0 2
      src/client/remote_sequential.py
  3. 4 4
      src/utils/generation_constraints.py

+ 4 - 4
src/client/remote_generation.py

@@ -32,16 +32,16 @@ class RemoteGenerationMixin(PreTrainedModel):
                     decoding_algorithm = NucleusAlgorithm(top_p, temperature)
             else:
                 decoding_algorithm = GreedyAlgorithm()
-        
+
         constraints = []
         constraints.extend(provided_constraints)
-                
+
         if max_new_tokens and eos_token_id:
             constraints.append(MaxNewTokensConstraint(max_new_tokens, eos_token_id))
-            
+
         for constraint in constraints:
             constraint.consume_prefix(inputs)
-       
+
         word_embeddings = self.transformer.word_embeddings.weight.t()
 
         with self.transformer.h.inference_session() as sess:

+ 0 - 2
src/client/remote_sequential.py

@@ -5,8 +5,6 @@ import logging
 import random
 from typing import Optional, Union
 
-from typing import Optional
-
 import torch
 from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker

+ 4 - 4
src/utils/generation_constraints.py

@@ -6,13 +6,13 @@ from abc import ABC
 class ABConstraint(ABC):
     def __init__(self) -> None:
         pass
-    
+
     def update(self, token_id: torch.Tensor, is_started: torch.Tensor) -> None:
         pass
-    
+
     def consume_prefix(self, prefix: torch.Tensor) -> None:
         pass
-    
+
     def calculate_transation(self, logits: torch.Tensor) -> torch.Tensor:
         pass
     
@@ -26,7 +26,7 @@ class MaxNewTokensConstraint(ABConstraint):
     
     def update(self, token_id: torch.Tensor, is_started: torch.Tensor) -> None:
         self.current_generated_tokens += 1
-        
+
     def calculate_transation(self, logits: torch.Tensor) -> torch.Tensor:
         if self.current_generated_tokens > self.max_new_tokens:
             mask = torch.zeros_like(logits)