artek0chumak 2 years ago
parent
commit
a22ecc524d

+ 2 - 2
src/petals/client/inference_session.py

@@ -103,7 +103,7 @@ class _ServerInferenceSession:
         else:
             assert len(hypo_ids) == len(new_hidden_states)
             assert hypo_ids.dtype == torch.int64
-            
+
         if attention_mask is None:
             attention_mask = DUMMY
 
@@ -235,7 +235,7 @@ class InferenceSession:
             prompts = DUMMY
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
-            
+
         if attention_mask is None:
             attention_mask = DUMMY
 
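Note: below is a minimal sketch of the DUMMY placeholder pattern used in this hunk, assuming (as in hivemind-style code) that DUMMY is an empty tensor and is_dummy() checks for zero elements; the names and shapes are illustrative, not the library's exact API.

    import torch

    DUMMY = torch.empty(0)  # serializable stand-in for "no attention mask"

    def is_dummy(tensor: torch.Tensor) -> bool:
        return tensor.numel() == 0

    attention_mask = None
    if attention_mask is None:
        attention_mask = DUMMY  # client sends the placeholder instead of None

    # the receiving side can later replace the placeholder with an all-ones mask
    hidden_states = torch.zeros(2, 5, 8)  # (batch, seq_len, hidden)
    if is_dummy(attention_mask):
        attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.int64)
    print(attention_mask.shape)  # torch.Size([2, 5])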

+ 1 - 3
src/petals/client/remote_generation.py

@@ -179,9 +179,7 @@ class RemoteGenerationMixin:
                     hidden_state = torch.cat([prompts, hidden_state], dim=1)
                 hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
 
-                hidden_state = session.step(
-                    hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids
-                )[:, -1]
+                hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
 
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)
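For context, a rough sketch of the generation step reformatted above, under assumed shapes: session.step runs the remote transformer blocks, and only the last position's hidden state is kept because each step needs a single next-token distribution per sequence. The stub modules below are illustrative stand-ins, not the model's real layers.

    import torch

    batch, seq, hidden, vocab = 2, 3, 8, 50
    hidden_state = torch.randn(batch, seq, hidden)

    def session_step_stub(x: torch.Tensor) -> torch.Tensor:
        # stand-in for session.step(...): returns hidden states of the same shape
        return x

    hidden_state = session_step_stub(hidden_state)[:, -1]   # (batch, hidden)
    ln_f = torch.nn.LayerNorm(hidden)                        # stand-in for transformer.ln_f
    lm_head = torch.nn.Linear(hidden, vocab, bias=False)     # stand-in for self.lm_head
    lm_logits = lm_head(ln_f(hidden_state))
    print(lm_logits.shape)  # torch.Size([2, 50])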

+ 1 - 1
src/petals/client/remote_model.py

@@ -191,7 +191,7 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
 
         hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
-        
+
         if attention_mask is None:
             attention_mask = torch.ones((batch_size, hidden_states.size(1)), device=hidden_states.device)
 

+ 4 - 2
src/petals/server/handler.py

@@ -158,9 +158,11 @@ class TransformerConnectionHandler(ConnectionHandler):
                                 f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                                 f" exceeds pre-allocated maximum {max_length}"
                             )
-                            
+
                         if is_dummy(attention_mask):
-                            attention_mask = torch.ones((hidden_states.shape[0], prefix_length + length_increment), dtype=hypo_ids.dtype)
+                            attention_mask = torch.ones(
+                                (hidden_states.shape[0], prefix_length + length_increment), dtype=hypo_ids.dtype
+                            )
 
                         if is_dummy(attention_mask):
                             attention_mask = torch.ones(
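For reference, a hedged sketch of the server-side default reformatted in this hunk: when the client sent the DUMMY placeholder, the handler builds an all-ones mask spanning the cached prefix plus the newly submitted tokens. The values below are made up for illustration; only the torch.ones call mirrors the diff.

    import torch

    prefix_length = 4                          # tokens already held in the attention cache
    hidden_states = torch.zeros(1, 2, 8)       # (batch, new tokens, hidden)
    hypo_ids = torch.zeros(1, dtype=torch.int64)
    length_increment = hidden_states.shape[1]  # number of new tokens in this request

    attention_mask = torch.ones(
        (hidden_states.shape[0], prefix_length + length_increment), dtype=hypo_ids.dtype
    )
    print(attention_mask.shape)  # torch.Size([1, 6]) -- prefix + new tokens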