Răsfoiți Sursa

black-isort

justheuristic 2 ani în urmă
părinte
comite
28971dcedd
3 a modificat fișierele cu 19 adăugiri și 8 ștergeri
  1. 7 4
      src/client/inference_session.py
  2. 6 2
      src/server/backend.py
  3. 6 2
      src/server/handler.py

+ 7 - 4
src/client/inference_session.py

@@ -71,9 +71,12 @@ class RemoteTransformerBlockInferenceSession:
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
 
-    def step(self,
-             new_hidden_states: torch.Tensor,
-             prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None):
+    def step(
+        self,
+        new_hidden_states: torch.Tensor,
+        prompts: Optional[torch.Tensor] = None,
+        hypo_ids: Optional[torch.Tensor] = None,
+    ):
         """
         Inference step: send a chunk of input tesors and receive a chunk of outputs
         :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
@@ -193,7 +196,7 @@ class RemoteSequentialInferenceSession:
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
         for session in self.inference_sessions:
-            outputs = session.step(inputs, prompts[self.chosen_spans[0].start: self.chosen_spans[0].end], **kwargs)
+            outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
             assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
             inputs = outputs
         return inputs

+ 6 - 2
src/server/backend.py

@@ -70,7 +70,9 @@ class TransformerBackend(ModuleBackend):
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
             (hidden_states, hypo_ids) = inputs
-            assert (hidden_states.ndim == 3), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+            assert (
+                hidden_states.ndim == 3
+            ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
@@ -78,7 +80,9 @@ class TransformerBackend(ModuleBackend):
                     cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
                 layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
                 print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
-                hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+                hidden_states, (new_k, new_v) = self.module.forward(
+                    hidden_states, layer_past=layer_past, use_cache=True
+                )
 
                 # todo remove these asserts once we pass all tests
                 new_length = new_v.shape[1]

+ 6 - 2
src/server/handler.py

@@ -95,8 +95,12 @@ class TransformerConnectionHandler(ConnectionHandler):
                         assert isinstance(
                             hidden_states, torch.Tensor
                         ), f"hidden states must be tensor, got {type(hidden_states)}"
-                        assert hidden_states.ndim == 3, f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-                        (hidden_states,) = await backend.inference_pool.submit_task(cache_metadata, hidden_states, hypo_ids)
+                        assert (
+                            hidden_states.ndim == 3
+                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+                        (hidden_states,) = await backend.inference_pool.submit_task(
+                            cache_metadata, hidden_states, hypo_ids
+                        )
 
                     # serialize and send last layer outputs
                     yield runtime_pb2.ExpertResponse(