Преглед на файлове

Fix beam search in GPU clients (#531)

Fixes #503.
Alexander Borzunov преди 1 година
родител
ревизия
82a97d6e9e
променени са 2 файла, в които са добавени 24 реда и са изтрити 36 реда
  1. 9 13
      .github/workflows/run-tests.yaml
  2. 15 23
      src/petals/client/inference_session.py

+ 9 - 13
.github/workflows/run-tests.yaml

@@ -48,7 +48,6 @@ jobs:
           export MODEL_NAME="${{ matrix.model }}"
           export REF_NAME="${{ matrix.model }}"
           export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
-          export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
 
           # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
 
@@ -61,27 +60,25 @@ jobs:
 
           until [ -s bootstrap.log ]; do sleep 5; done  # wait for DHT init
 
-          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
-            --mean_balance_check_period 10 \
-            --initial_peers $INITIAL_PEERS --throughput 1 &> server1.log &
+          export RUN_SERVER="python -m petals.cli.run_server $MODEL_NAME \
+            --device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS"
+          export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
+
+          $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 5 --throughput 1 --mean_balance_check_period 10 &> server1.log &
           SERVER1_PID=$!
           # ^-- rebalacing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there
 
           sleep 10  # wait for the 1st server to choose blocks
 
-          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \
-            --identity_path tests/server2.id \
-            --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
+          $RUN_SERVER --adapters $ADAPTER_NAME --block_indices 0:5 --throughput 1 --identity_path tests/server2.id &> server2.log &
           SERVER2_PID=$!
 
-          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \
-            --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
-            --initial_peers $INITIAL_PEERS --throughput auto &> server3.log &
+          $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 14 --throughput auto \
+            --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 &> server3.log &
           SERVER3_PID=$!
           # ^-- chunking test
 
-          python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \
-            --initial_peers $INITIAL_PEERS --throughput auto &> server4.log &
+          $RUN_SERVER $TENSOR_PARALLEL_ARGS --block_indices 0:2 --throughput auto &> server4.log &
           SERVER4_PID=$!
           # ^-- tensor parallelism test (not compatible with adapters yet)
 
@@ -121,4 +118,3 @@ jobs:
           # [Step 4] Clean up
 
           kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
-          echo "Done!"

+ 15 - 23
src/petals/client/inference_session.py

@@ -84,12 +84,7 @@ class _ServerInferenceSession:
                 break  # this message means "done sending"
 
     def step(
-        self,
-        inputs: torch.Tensor,
-        prompts: Optional[torch.Tensor] = None,
-        hypo_ids: Optional[torch.Tensor] = None,
-        *,
-        step_id: str,
+        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -114,21 +109,6 @@ class _ServerInferenceSession:
         else:
             inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
-        if prompts is None or is_dummy(prompts):
-            prompts = DUMMY
-        else:
-            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
-            assert prompts.shape[0] == self.num_blocks
-            assert prompts.shape[1] in (inputs.shape[0], 1)
-            assert prompts.shape[2] <= inputs.shape[1]
-            assert prompts.shape[3] == inputs.shape[2]
-
-        if hypo_ids is None or is_dummy(hypo_ids):
-            hypo_ids = DUMMY_INT64
-        else:
-            assert len(hypo_ids) == len(inputs)
-            assert hypo_ids.dtype == torch.int64
-
         # serialize inputs and put them into the queue
         input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
 
@@ -275,7 +255,9 @@ class InferenceSession:
         assert not self._closed and not self._server_sessions
         return self
 
-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+    def step(
+        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -285,11 +267,21 @@ class InferenceSession:
         else:
             assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
             assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (inputs.shape[0], 1)
+            assert prompts.shape[2] <= inputs.shape[1]
+            assert prompts.shape[3] == inputs.shape[2]
+
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = DUMMY_INT64
+        else:
+            assert len(hypo_ids) == len(inputs)
+            assert hypo_ids.dtype == torch.int64
 
         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()
         prompts = prompts.cpu()
+        hypo_ids = hypo_ids.cpu()
         step_id = str(uuid.uuid4())
 
         n_input_tokens = inputs.shape[1]
@@ -310,7 +302,7 @@ class InferenceSession:
 
                     server_session = self._server_sessions[server_idx]
                     inputs = server_session.step(
-                        inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
+                        inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
                     )
 
                     server_idx += 1