Bläddra i källkod

Fix routing through relay, default network RPS, --token, logging, readme (#399)

* Hide GeneratorExit in _iterate_inference_steps()
* Update README.md about `--public_name`
* Use .from_pretrained(..., use_auth_token=token) instead of token=token
until it's fully supported across HF libs
* Use default network speed 25 Mbit/s
* Apply relay penalty in max-throughput routing
* Replace RPS with "tokens/sec per block" in logs
* Increase default expiration
Alexander Borzunov 2 år sedan
förälder
incheckning
8666653cf5

+ 6 - 2
README.md

@@ -34,11 +34,13 @@ print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
 
 ### Connect your GPU and increase Petals capacity
 
+Petals is a community-run system — we rely on people sharing their GPUs. You can check out available servers on our [swarm monitor](https://health.petals.dev) and connect your GPU to help serving one of the models!
+
 Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.8+):
 
 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
-pip install --upgrade petals
+pip install git+https://github.com/bigscience-workshop/petals
 python -m petals.cli.run_server enoch/llama-65b-hf --adapters timdettmers/guanaco-65b
 ```
 
@@ -55,6 +57,8 @@ This will host a part of LLaMA-65B with optional [Guanaco](https://huggingface.c
 
 💬 See [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues or feedback, ping us in [our Discord](https://discord.gg/D9MwApKgWa)!
 
+🏆 If you host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks! You can specify them with `--public_name YOUR_NAME`. We will show them once your server loads all blocks.
+
 ### Check out tutorials, examples, and more
 
 Basic tutorials:
@@ -97,7 +101,7 @@ Here's how to install Petals with [Anaconda](https://www.anaconda.com/products/d
 
 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
-pip install --upgrade petals
+pip install git+https://github.com/bigscience-workshop/petals
 ```
 
 If you don't use Anaconda, you can install PyTorch in [any other way](https://pytorch.org/get-started/locally/). If you want to run models with 8-bit weights, please install PyTorch with CUDA 11.x or newer for compatility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).

+ 10 - 2
src/petals/client/routing/sequence_manager.py

@@ -291,7 +291,9 @@ class RemoteSequenceManager:
         # This is okay since false positives are more costly than false negatives here.
         return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
 
-    def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
+    def _make_sequence_with_max_throughput(
+        self, start_index: int, end_index: int, *, relay_penalty: float = 0.5
+    ) -> List[RemoteSpanInfo]:
         span_sequence = []
         current_index = start_index
         while current_index < end_index:
@@ -299,7 +301,13 @@ class RemoteSequenceManager:
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
 
-            span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
+            span_weights = np.array(
+                [
+                    span.server_info.throughput * (1 if not span.server_info.using_relay else relay_penalty)
+                    for span in candidate_spans
+                ],
+                dtype=np.float64,
+            )
             chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
 
             assert chosen_span.start <= current_index < chosen_span.end

+ 1 - 1
src/petals/server/from_pretrained.py

@@ -40,7 +40,7 @@ def load_pretrained_block(
     max_disk_space: Optional[int] = None,
 ) -> nn.Module:
     if config is None:
-        config = AutoDistributedConfig.from_pretrained(model_name, token=token)
+        config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=token)
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
 

+ 1 - 1
src/petals/server/handler.py

@@ -347,7 +347,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         anext_task.cancel()
                         get_push_task.cancel()
                         return
-        except:
+        except Exception:
             logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True)
             raise
 

+ 2 - 2
src/petals/server/server.py

@@ -104,7 +104,7 @@ class Server:
 
         self.block_config = AutoDistributedConfig.from_pretrained(
             converted_model_name_or_path,
-            token=token,
+            use_auth_token=token,
             revision=revision,
         )
 
@@ -117,7 +117,7 @@ class Server:
         self.dht_prefix = dht_prefix
 
         if expiration is None:
-            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+            expiration = max(3 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
         self.expiration = expiration
 
         self.request_timeout = request_timeout

+ 32 - 33
src/petals/server/throughput.py

@@ -96,7 +96,7 @@ def get_server_throughput(
     throughput = throughput_info["forward_rps"] / average_blocks_used
     throughput = min(throughput, throughput_info.get("network_rps", math.inf))
     throughput_info["throughput"] = throughput
-    logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks")
+    logger.info(f"Reporting throughput: {throughput:.1f} tokens/sec for {num_blocks} blocks")
 
     return throughput_info
 
@@ -109,13 +109,10 @@ def measure_throughput_info(
     quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
 ) -> Dict[str, float]:
-    """Measure network and compute throughput in forward pass tokens per second"""
-
     logger.info(
         "Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
     )
-
-    throughput_info = {
+    return {
         "inference_rps": measure_compute_rps(
             config,
             device,
@@ -136,37 +133,39 @@ def measure_throughput_info(
             n_steps=10,
             inference=False,
         ),
+        "network_rps": measure_network_rps(config),
     }
-    try:
-        throughput_info["network_rps"] = measure_network_rps(config)
-    except Exception as e:
-        logger.info(f"Network throughput is not available: {e}")
-    return throughput_info
-
 
-def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Optional[float]:
-    pipe_recv, pipe_send = mp.Pipe(duplex=False)
-    process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
-    process.start()
-
-    if not pipe_recv.poll(timeout):
-        process.terminate()
-        raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
-    network_info = pipe_recv.recv()
-    if "exception" in network_info:
-        raise RuntimeError(f"speedtest failed: {network_info['exception']}")
 
+def measure_network_rps(
+    config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 25e6
+) -> Optional[float]:
     bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
-    network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
-    if network_rps == 0:
-        raise RuntimeError("speedtest has returned network_rps == 0")
-
-    logger.info(
-        f"Network throughput: {network_rps:.1f} RPS "
-        f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
-        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
-    )
-    return network_rps
+    try:
+        pipe_recv, pipe_send = mp.Pipe(duplex=False)
+        process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
+        process.start()
+
+        if not pipe_recv.poll(timeout):
+            process.terminate()
+            raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
+        network_info = pipe_recv.recv()
+        if "exception" in network_info:
+            raise RuntimeError(f"speedtest failed: {network_info['exception']}")
+
+        network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
+        if network_rps == 0:
+            raise RuntimeError("speedtest has returned network_rps == 0")
+
+        logger.info(
+            f"Network throughput: {network_rps:.1f} tokens/sec "
+            f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
+            f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
+        )
+        return network_rps
+    except RuntimeError as e:
+        logger.info(f"Network throughput is not available: {e}. Using default of {default_speed / 1e6:.2f} Mbit/s")
+        return default_speed / bits_per_request
 
 
 def _measure_bits_per_second(pipe_send: mp.Pipe):
@@ -215,7 +214,7 @@ def measure_compute_rps(
         devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
 
     logger.info(
-        f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} RPS per block "
+        f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} tokens/sec per block "
         f"({n_tokens} tokens/batch, {devices_repr}, {get_dtype_name(dtype, quant_type)})"
     )
     return device_rps

+ 6 - 2
src/petals/utils/auto_config.py

@@ -31,8 +31,12 @@ class _AutoDistributedBase:
 
     @classmethod
     def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig:
-        if always_needs_auth(model_name_or_path) and "token" not in kwargs and "use_auth_token" not in kwargs:
-            kwargs["token"] = True
+        if (
+            always_needs_auth(model_name_or_path)
+            and kwargs.get("token") is None
+            and kwargs.get("use_auth_token") is None
+        ):
+            kwargs["use_auth_token"] = True
 
         config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs)
         if config.model_type not in _CLASS_MAPPING:

+ 0 - 25
tests/scripts/remove_old_models.py

@@ -1,25 +0,0 @@
-import argparse
-from datetime import datetime
-
-from huggingface_hub import delete_repo, list_models
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Remove old testing models from HF hub")
-    parser.add_argument("--author", type=str, default="bloom-testing", help="auth token for from_pretrained")
-    parser.add_argument("--seconds_since_last_updated", type=int, default=7 * 24 * 60 * 60)
-    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
-    parser.add_argument("--dry_run", action="store_true")
-
-    args = parser.parse_args()
-
-    for model in list_models(author=args.author, full=True):
-        last_modified = datetime.strptime(model.lastModified, "%Y-%m-%dT%H:%M:%S.%fZ")
-
-        if model.modelId.endswith("-main") or "/test-" not in model.modelId:
-            continue  # remove only test models
-
-        if (datetime.now() - last_modified).total_seconds() > args.seconds_since_last_updated:
-            if args.dry_run:
-                print(f"{model.modelId} can be deleted")
-            else:
-                delete_repo(repo_id=model.modelId, token=args.use_auth_token)