
Measure and cache throughput by default

Aleksandr Borzunov 3 years ago
parent
commit
c10e505324
4 changed files with 24 additions and 15 deletions
  1. + 7 - 2  cli/run_server.py
  2. + 5 - 5  src/dht_utils.py
  3. + 8 - 3  src/server/server.py
  4. + 4 - 5  src/server/throughput.py

+ 7 - 2
cli/run_server.py

@@ -41,8 +41,13 @@ def main():
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
 
-    parser.add_argument('--throughput', type=float, default=1.0,
-                        help='Expected server throughput')
+    parser.add_argument('--throughput',
+                        type=lambda value: value if value in ['auto', 'eval'] else float(value),
+                        default='auto',
+                        help='Expected server throughput (a float measured in RPS). '
+                             'If set to "auto" (default), the script evaluates network and compute throughput '
+                             'on the first run and uses these estimates for future runs. '
+                             'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
     parser.add_argument('--update_period', type=float, required=False, default=30,
                         help='Server will report experts to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
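
A minimal sketch of how the new --throughput argument parses its value (a standalone argparse example; the parser here is illustrative, not the project's actual CLI): the strings 'auto' and 'eval' pass through unchanged, while anything else must convert to a float or argparse rejects it.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--throughput',
                    type=lambda value: value if value in ['auto', 'eval'] else float(value),
                    default='auto')

print(parser.parse_args([]).throughput)                         # 'auto' (default)
print(parser.parse_args(['--throughput', 'eval']).throughput)   # 'eval'
print(parser.parse_args(['--throughput', '3.5']).throughput)    # 3.5 (a float)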

+ 5 - 5
src/dht_utils.py

@@ -3,6 +3,7 @@ Utilities for declaring and retrieving active model layers using a shared DHT.
 """
 from __future__ import annotations
 
+import math
 from functools import partial
 from typing import Dict, List, Optional, Sequence, Union
 
@@ -134,11 +135,10 @@ async def _get_remote_module_infos(
         for peer_id, server_info in metadata.value.items():
             try:
                 peer_id = PeerID.from_base58(peer_id)
-                server_info = server_info.value
-                if not (isinstance(server_info, tuple) and len(server_info) == 2 and
-                        isinstance(server_info[0], int) and isinstance(server_info[1], float)):
-                    raise ValueError(f"Invalid server info for uid={uid}, peer_id={peer_id}: {server_info}")
-                state, throughput = server_info
+                state, throughput = server_info.value
+                if not (isinstance(state, int) and isinstance(throughput, float) and
+                        math.isfinite(throughput) and throughput >= 0.0):
+                    raise ValueError(f"Invalid server info: {server_info}")
                 servers[peer_id] = ServerInfo(ServerState(state), throughput)
             except (TypeError, ValueError) as e:
                 logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")

+ 8 - 3
src/server/server.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import multiprocessing as mp
 import threading
-from typing import Dict, Optional, Sequence, Union
+from typing import Dict, Literal, Optional, Sequence, Union
 
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
@@ -19,6 +19,7 @@ from src.server.backend import TransformerBackend
 from src.server.block_selection import choose_best_blocks
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
+from src.server.throughput import get_host_throughput
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -95,7 +96,7 @@ class Server(threading.Thread):
         cls,
         prefix: Optional[str],
         converted_model_name_or_path: str,
-        throughput: float,
+        throughput: Union[float, Literal['auto', 'eval']],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: Optional[int] = None,
@@ -103,7 +104,7 @@ class Server(threading.Thread):
         max_batch_size: int = 4096,
         torch_dtype: str = "auto",
         cache_size_bytes: Optional[int] = None,
-        device: Union[str, torch.device] = None,
+        device: Optional[Union[str, torch.device]] = None,
         initial_peers: Sequence[str] = (),
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
@@ -136,6 +137,10 @@ class Server(threading.Thread):
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         memory_cache = MemoryCache(device, cache_size_bytes)
 
+        assert isinstance(throughput, float) or throughput in ['auto', 'eval']
+        if throughput in ['auto', 'eval']:
+            throughput = get_host_throughput(device, force_eval=(throughput == 'eval'))
+
         if isinstance(torch_dtype, str):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
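
A minimal sketch of the dispatch that Server.create() now performs (resolve_throughput is a hypothetical standalone helper; get_host_throughput is the function imported above): a float is used as-is, 'auto' reuses the cached measurement, and 'eval' forces a re-measurement.

from typing import Literal, Union

import torch

from src.server.throughput import get_host_throughput

def resolve_throughput(throughput: Union[float, Literal['auto', 'eval']],
                       device: Union[str, torch.device]) -> float:
    assert isinstance(throughput, float) or throughput in ['auto', 'eval']
    if throughput in ['auto', 'eval']:
        # 'eval' bypasses the on-disk cache and re-measures the host throughput
        throughput = get_host_throughput(device, force_eval=(throughput == 'eval'))
    return throughput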

+ 4 - 5
src/server/throughput.py

@@ -6,7 +6,7 @@ import tempfile
 import time
 from dataclasses import asdict, dataclass
 from pathlib import Path
-from typing import Dict
+from typing import Dict, Union
 
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -33,19 +33,19 @@ class ThroughputInfo:
 
 
 def get_host_throughput(
-    device: str,
+    device: Union[str, torch.device],
     force_eval: bool = False,
     cache_path: str = DEFAULT_CACHE_PATH,
     lock_path: str = DEFAULT_LOCK_PATH,
 ) -> float:
     # We only keep the device type, assuming that the throughput is similar among all host's GPUs
-    device = device.split(':')[0] if ':' in device else device
+    device = torch.device(device).type
 
     info = None
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
     with open(lock_path, 'wb') as lock_fd:
-        logger.info("Waiting for the throughput info")
+        logger.info("Loading throughput info")
         fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
         # The OS will release the lock when lock_fd is closed or the process is killed
 
@@ -81,7 +81,6 @@ def measure_throughput_info() -> ThroughputInfo:
     if torch.cuda.is_available():
         device_rps['cuda'] = measure_device_rps('cuda', config)
 
-    logger.info("Use `--throughput eval` if you'd like to re-evaluate the throughput later")
     return ThroughputInfo(network_rps=network_rps, device_rps=device_rps)
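
A minimal sketch of the cache-under-a-system-wide-lock pattern used by get_host_throughput (the paths and the load_or_measure helper are hypothetical; the real code stores a ThroughputInfo dataclass rather than raw JSON): an exclusive fcntl lock ensures only one process measures at a time, and later runs reuse the cached result unless force_eval is set.

import fcntl
import json
from pathlib import Path

CACHE_PATH = Path.home() / '.cache' / 'example' / 'throughput.json'  # hypothetical location
LOCK_PATH = Path.home() / '.cache' / 'example' / 'throughput.lock'   # hypothetical location

def load_or_measure(measure, force_eval: bool = False) -> dict:
    LOCK_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(LOCK_PATH, 'wb') as lock_fd:
        # Exclusive lock: the OS releases it when lock_fd is closed or the process dies
        fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)

        if not force_eval and CACHE_PATH.exists():
            return json.loads(CACHE_PATH.read_text())

        info = measure()  # expensive: runs the network and compute benchmarks
        CACHE_PATH.parent.mkdir(parents=True, exist_ok=True)
        CACHE_PATH.write_text(json.dumps(info))
        return info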