Pārlūkot izejas kodu

Add dry_run option to --throughput

Max Ryabinin 1 gadu atpakaļ
vecāks
revīzija
b2ab84cc33
2 mainītis faili ar 11 papildinājumiem un 5 dzēšanām
  1. 3 2
      src/petals/cli/run_server.py
  2. 8 3
      src/petals/server/server.py

+ 3 - 2
src/petals/cli/run_server.py

@@ -106,12 +106,13 @@ def main():
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
 
     parser.add_argument('--throughput',
-                        type=lambda value: value if value in ['auto', 'eval'] else float(value),
+                        type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value),
                         default='auto',
                         help='Expected server throughput (a float measured in RPS). '
                              'If set to "auto" (default), the script evaluates network and compute throughput '
                              'on the first run and uses these estimates for future runs. '
-                             'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
+                             'If set to "eval", the script re-evaluates the throughput and overrides the cache. '
+                             'If set to "dry_run", the script re-evaluates the throughput and exits.')
     parser.add_argument('--update_period', type=float, required=False, default=120,
                         help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,

+ 8 - 3
src/petals/server/server.py

@@ -5,6 +5,7 @@ import math
 import multiprocessing as mp
 import os
 import random
+import sys
 import threading
 import time
 from typing import Dict, List, Optional, Sequence, Union
@@ -234,8 +235,9 @@ class Server:
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
 
-        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
-        if throughput in ["auto", "eval"]:
+        assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
+        if throughput in ["auto", "eval", "dry_run"]:
+            force_eval = throughput in ["eval", "dry_run"]
             throughput_info = get_server_throughput(
                 converted_model_name_or_path,
                 self.block_config,
@@ -245,9 +247,12 @@ class Server:
                 quant_type=quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 reachable_via_relay=reachable_via_relay,
-                force_eval=(throughput == "eval"),
+                force_eval=force_eval,
                 cache_dir=cache_dir,
             )
+            if throughput == "dry_run":
+                logger.info("Finished estimating throughput, exiting")
+                sys.exit(0)
         else:
             throughput_info = {"throughput": throughput}
         self.server_info = ServerInfo(