artek0chumak 2 years ago
Parent
Current commit
98b962631a

+ 7 - 3
src/petals/client/inference_session.py

@@ -103,7 +103,7 @@ class _ServerInferenceSession:
         else:
             assert len(hypo_ids) == len(new_hidden_states)
             assert hypo_ids.dtype == torch.int64
-            
+
         if attention_mask is None:
             attention_mask = DUMMY
 
@@ -217,7 +217,11 @@ class InferenceSession:
         return self
 
     def step(
-        self, inputs: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, prompts: Optional[torch.Tensor] = None, **kwargs
+        self,
+        inputs: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        prompts: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
@@ -228,7 +232,7 @@ class InferenceSession:
             prompts = DUMMY
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
-            
+
         if attention_mask is None:
             attention_mask = DUMMY
 

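A minimal sketch (not part of this diff) of driving the reworked `step()` signature from client code. The `InferenceSession` is assumed to be already open, and the tensor sizes are illustrative; shapes follow the asserts in this commit (3-D hidden states, 2-D mask with one entry per token):

```python
import torch

def run_one_step(session, hidden_states: torch.Tensor) -> torch.Tensor:
    """Feed one chunk of hidden states through an open InferenceSession.

    hidden_states: (batch, new_tokens, hidden_size). Omitting the mask
    entirely falls back to the DUMMY placeholder handled above.
    """
    attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.int64)
    return session.step(hidden_states, attention_mask=attention_mask)
```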
+ 1 - 3
src/petals/client/remote_generation.py

@@ -179,9 +179,7 @@ class RemoteGenerationMixin:
                     hidden_state = torch.cat([prompts, hidden_state], dim=1)
                 hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
 
-                hidden_state = session.step(
-                    hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids
-                )[:, -1]
+                hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
 
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)

+ 1 - 1
src/petals/client/remote_model.py

@@ -191,7 +191,7 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
 
         hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
-        
+
         if attention_mask is None:
             attention_mask = torch.ones((batch_size, hidden_states.size(1)), device=hidden_states.device)
 

+ 36 - 8
src/petals/client/sequential_autograd.py

@@ -38,7 +38,9 @@ async def sequential_forward(
     """
 
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
-    assert isinstance(attention_masks, torch.Tensor) and attention_masks.ndim == 2, f"{type(attention_masks)}: {attention_masks.ndim}"
+    assert (
+        isinstance(attention_masks, torch.Tensor) and attention_masks.ndim == 2
+    ), f"{type(attention_masks)}: {attention_masks.ndim}"
 
     inputs_device = inputs.device
     inputs_dtype = inputs.dtype
@@ -202,20 +204,33 @@ async def _gather_forward(input_batches, attention_mask_batches, prompt_batches,
     return await asyncio.gather(
         *[
             sequential_forward(input_batch, attention_mask_batch, prompt_batch, sequence_manager)
-            for input_batch, attention_mask_batch, prompt_batch in zip(input_batches, attention_mask_batches, prompt_batches)
+            for input_batch, attention_mask_batch, prompt_batch in zip(
+                input_batches, attention_mask_batches, prompt_batches
+            )
         ]
     )
 
 
 async def _gather_backward(
-    grad_output_batches, intermediate_input_batches, attention_mask_batches, prompt_batches, forward_sequences, sequence_manager
+    grad_output_batches,
+    intermediate_input_batches,
+    attention_mask_batches,
+    prompt_batches,
+    forward_sequences,
+    sequence_manager,
 ):
     """Wrapper for asyncio.gather to perform parallel sequential backwards"""
     return await asyncio.gather(
         *[
-            sequential_backward((grad_output,), input_batch, attention_mask_batch, prompt_batch, spans, sequence_manager)
+            sequential_backward(
+                (grad_output,), input_batch, attention_mask_batch, prompt_batch, spans, sequence_manager
+            )
             for grad_output, input_batch, attention_mask_batch, prompt_batch, spans in zip(
-                grad_output_batches, intermediate_input_batches, attention_mask_batches, prompt_batches, forward_sequences
+                grad_output_batches,
+                intermediate_input_batches,
+                attention_mask_batches,
+                prompt_batches,
+                forward_sequences,
             )
         ]
     )
@@ -228,7 +243,13 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, inputs: torch.Tensor, attention_mask: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
+    def forward(
+        ctx,
+        inputs: torch.Tensor,
+        attention_mask: torch.Tensor,
+        prompts: torch.Tensor,
+        sequence_manager: RemoteSequenceManager,
+    ):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
         attention_mask_batches: Sequence[torch.Tensor] = attention_mask.detach().split(batch_size)
@@ -238,7 +259,9 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
             prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
 
         sequence_manager.rpc_info  # lazy init
-        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, attention_mask_batches, prompt_batches, sequence_manager))
+        outputs = RemoteExpertWorker.run_coroutine(
+            _gather_forward(input_batches, attention_mask_batches, prompt_batches, sequence_manager)
+        )
         assert len(outputs) == len(input_batches)
 
         output_batches = [output[0] for output in outputs]
@@ -261,7 +284,12 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
 
         batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
         grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
-        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences) == len(attention_mask_batches)
+        assert (
+            len(intermediate_input_batches)
+            == len(grad_output_batches)
+            == len(forward_sequences)
+            == len(attention_mask_batches)
+        )
 
         outputs = RemoteExpertWorker.run_coroutine(
             _gather_backward(

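For reference, a standalone sketch of the batching wrapped above in `_RemoteSequentialAutogradFunction.forward`: inputs and the 2-D attention mask are split along the batch dimension so that each chunk stays under `MAX_TOKENS_IN_BATCH` tokens. The constant's value and tensor sizes here are illustrative, not the ones defined in petals:

```python
import torch

MAX_TOKENS_IN_BATCH = 1024  # illustrative value

inputs = torch.randn(8, 512, 64)     # (batch, seq_len, hidden_size)
attention_mask = torch.ones(8, 512)  # (batch, seq_len), one entry per token

batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
input_batches = inputs.detach().split(batch_size)
attention_mask_batches = attention_mask.detach().split(batch_size)

# Each pair of chunks is forwarded together, so the counts must match,
# mirroring the length assert added in backward() above.
assert len(input_batches) == len(attention_mask_batches)
print(batch_size, len(input_batches))  # 2 4
```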
+ 7 - 3
src/petals/server/handler.py

@@ -134,7 +134,9 @@ class TransformerConnectionHandler(ConnectionHandler):
                 async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
                     assert len(cache_handles) == len(requested_backends)
                     while request.tensors:  # iterate while user is willing to supply tensors
-                        hidden_states, attention_mask, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+                        hidden_states, attention_mask, prompts, hypo_ids = map(
+                            deserialize_torch_tensor, request.tensors
+                        )
 
                         # Cast inputs to backend dtype
                         hidden_states = hidden_states.to(requested_backends[0].dtype)
@@ -156,9 +158,11 @@ class TransformerConnectionHandler(ConnectionHandler):
                                 f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                                 f" exceeds pre-allocated maximum {max_length}"
                             )
-                            
+
                         if is_dummy(attention_mask):
-                            attention_mask = torch.ones((hidden_states.shape[0], prefix_length + length_increment), dtype=hypo_ids.dtype)
+                            attention_mask = torch.ones(
+                                (hidden_states.shape[0], prefix_length + length_increment), dtype=hypo_ids.dtype
+                            )
 
                         priority = self._prioritizer.prioritize(
                             hidden_states,

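A runnable sketch of the dummy-mask fallback added to the handler above: when the client sends a placeholder instead of a real mask, the server builds an all-ones mask covering both the cached prefix and the newly supplied tokens. `is_dummy` is approximated here by an empty-tensor check; the real helper lives in petals:

```python
import torch

def default_attention_mask(
    hidden_states: torch.Tensor, prefix_length: int, attention_mask: torch.Tensor
) -> torch.Tensor:
    if attention_mask.numel() == 0:  # stand-in for petals' is_dummy() check
        length_increment = hidden_states.shape[1]
        attention_mask = torch.ones(
            (hidden_states.shape[0], prefix_length + length_increment), dtype=torch.int64
        )
    return attention_mask

mask = default_attention_mask(torch.zeros(2, 3, 16), prefix_length=5, attention_mask=torch.empty(0))
print(mask.shape)  # torch.Size([2, 8])
```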
+ 1 - 3
src/petals/server/server.py

@@ -420,9 +420,7 @@ class ModuleContainer(threading.Thread):
                         BatchTensorDescriptor(
                             1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
                         ),
-                        BatchTensorDescriptor(
-                            1, 2048, dtype=backend_dtype, compression=compression
-                        ),
+                        BatchTensorDescriptor(1, 2048, dtype=backend_dtype, compression=compression),
                     ),
                     kwargs_schema={},
                     outputs_schema=(

+ 3 - 1
tests/test_chained_calls.py

@@ -72,7 +72,9 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
         hidden_states = inputs[:, i : i + 1, :]
         for ref_block, cache in zip(ref_blocks, caches):
             with torch.no_grad():
-                hidden_states, new_cache = ref_block.forward(hidden_states, attention_masks[:, :i+1], use_cache=True, layer_past=cache)
+                hidden_states, new_cache = ref_block.forward(
+                    hidden_states, attention_masks[:, : i + 1], use_cache=True, layer_past=cache
+                )
                 new_caches.append(new_cache)
 
         outputs_ref.append(hidden_states)
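And a standalone sketch of the incremental masking pattern exercised by the test above: at step `i`, a single new token is fed to the block together with the mask slice covering the prefix plus that token:

```python
import torch

seq_len = 4
inputs = torch.randn(1, seq_len, 16)                         # (batch, seq_len, hidden_size)
attention_masks = torch.ones(1, seq_len, dtype=torch.int64)  # (batch, seq_len)

for i in range(seq_len):
    step_hidden = inputs[:, i : i + 1, :]    # one token at a time
    step_mask = attention_masks[:, : i + 1]  # mask over cached prefix + new token
    assert step_mask.shape[1] == i + 1
```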