Forráskód Böngészése

Merge branch 'main' into forward_kwargs

justheuristic 1 éve
szülő
commit
ce89b649b5

+ 18 - 20
.github/workflows/run-tests.yaml

@@ -7,20 +7,21 @@ on:
 
 jobs:
   run-tests:
-    runs-on: ubuntu-latest
     strategy:
       matrix:
         include:
-          - { model: 'bigscience/bloom-560m', python-version: '3.8' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.9' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.10' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.11' }
-          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.8' }
-          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.11' }
+          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
       fail-fast: false
+    runs-on: ${{ matrix.os }}-latest
     timeout-minutes: 15
     steps:
       - name: Increase swap space
+        if: ${{ matrix.os == 'ubuntu' }}
         uses: pierotofy/set-swap-space@master
         with:
           swap-size-gb: 10
@@ -47,12 +48,7 @@ jobs:
           export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
           export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
 
-          # [Step 1] Watch free RAM (lack of RAM is a common issue in CI)
-
-          bash -c 'while true; do free -h && sleep 30s; done' &
-          RAM_WATCH_PID=$!
-
-          # [Step 2] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
+          # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
 
           python -m petals.cli.run_dht \
             --identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &
@@ -61,7 +57,7 @@ jobs:
           export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
           # ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs
 
-          sleep 5  # wait for DHT init
+          until [ -s bootstrap.log ]; do sleep 5; done  # wait for DHT init
 
           python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
             --mean_balance_check_period 10 \
@@ -95,11 +91,15 @@ jobs:
           sleep 30  # wait for servers to eval throughput, download layers, and rebalance
           kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID  # ensure all peers survived init
 
-          # [Step 3] Run PyTest
+          # [Step 2] Run PyTest
+
+          # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
+          export no_proxy=*
+          export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
 
           pytest tests --durations=0 --durations-min=1.0 -v
 
-          # [Step 4] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
+          # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
 
           python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
             --seq_len 3
@@ -110,9 +110,7 @@ jobs:
           python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
             --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm
 
-          # [Step 5] Clean up
-
-          kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID  # ensure all peers survived tests
+          # [Step 4] Clean up
 
-          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID $RAM_WATCH_PID
+          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
           echo "Done!"

+ 43 - 53
README.md

@@ -8,20 +8,20 @@
     <br>
 </p>
 
-Generate text with distributed **LLaMA 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:
+Generate text with distributed **Llama 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:
 
 ```python
 from transformers import AutoTokenizer
 from petals import AutoDistributedModelForCausalLM
 
-model_name = "stabilityai/StableBeluga2"
-# You can also use "meta-llama/Llama-2-70b-hf", "meta-llama/Llama-2-70b-chat-hf",
-# repos with LLaMA-65B, "bigscience/bloom", or "bigscience/bloomz"
+# Choose any model available at https://health.petals.dev
+model_name = "petals-team/StableBeluga2"
 
+# Connect to a distributed network hosting model layers
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
-# Embeddings & prompts are on your device, transformer blocks are distributed across the Internet
 
+# Run the model as if it were on your computer
 inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"]
 outputs = model.generate(inputs, max_new_tokens=5)
 print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
@@ -31,73 +31,58 @@ print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
     🚀 &nbsp;<b><a href="https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing">Try now in Colab</a></b>
 </p>
 
-🦙 **Want to run LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
-
-📋 **Terms of use.** Make sure you follow the model license (see [LLaMA 2](https://bit.ly/llama2-license), [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2/blob/main/LICENSE.txt), [LLaMA](https://bit.ly/llama-license), and [BLOOM](https://bit.ly/bloom-license)).
+🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
 
 🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
 
 💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
 
-### Connect your GPU and increase Petals capacity
+## Connect your GPU and increase Petals capacity
 
-Petals is a community-run system &mdash; we rely on people sharing their GPUs. You can check out available servers on our [swarm monitor](https://health.petals.dev) and connect your GPU to help serving one of the models!
+Petals is a community-run system &mdash; we rely on people sharing their GPUs. You can check out [available models](https://health.petals.dev) and help serving one of them! As an example, here is how to host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your GPU:
 
-🐍 **Linux + Anaconda.** Run these commands:
+🐧 **Linux + Anaconda.** Run these commands for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
 
 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
 pip install git+https://github.com/bigscience-workshop/petals
-python -m petals.cli.run_server stabilityai/StableBeluga2
+python -m petals.cli.run_server petals-team/StableBeluga2
 ```
 
-🪟 **Windows + WSL.** Follow the guide on our [Wiki](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows).
+🪟 **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki.
 
-🐋 **Any OS + Docker.** Run our [Docker](https://www.docker.com) image:
+🐋 **Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
 
 ```bash
-sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
-    python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2
+sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \
+    learningathome/petals:main \
+    python -m petals.cli.run_server --port 31330 petals-team/StableBeluga2
 ```
 
-These commands will host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, repos with LLaMA-65B, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures.
-
-🦙 **Want to host LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then use this command for `petals.cli.run_server`:
+🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands:
 
 ```bash
-python -m petals.cli.run_server meta-llama/Llama-2-70b-chat-hf --token YOUR_TOKEN_HERE
+brew install python
+python3 -m pip install git+https://github.com/bigscience-workshop/petals
+python3 -m petals.cli.run_server petals-team/StableBeluga2
 ```
 
-💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
-
-🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
-
-🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`.
-
-### Check out tutorials, examples, and more
-
-Basic tutorials:
-
-- Getting started: [tutorial](https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing)
-- Prompt-tune LLaMA-65B for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)
-- Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)
+<p align="center">
+    📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (how to use multiple GPUs, start the server on boot, etc.)
+</p>
 
-Useful tools and advanced guides:
+💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
 
-- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/petals-infra/chat.petals.dev)
-- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/petals-infra/health.petals.dev)
-- Launch your own swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
-- Run a custom foundation model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)
+🦙 **Want to host Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then add `--token YOUR_TOKEN_HERE` to the `python -m petals.cli.run_server` command.
 
-Learning more:
+🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
 
-- Frequently asked questions: [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions)
-- In-depth system description: [paper](https://arxiv.org/abs/2209.01188)
+🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`.
 
 ## How does it work?
 
-- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning.
-- Single-batch inference runs at **up to 6 steps/sec** for **LLaMA 2** (70B) and &approx; 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
+- Petals runs large language models like [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning.
+- Single-batch inference runs at **up to 6 steps/sec** for **Llama 2** (70B) and &approx; 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
 - Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.
 
 <p align="center">
@@ -105,23 +90,28 @@ Learning more:
 </p>
 
 <p align="center">
-    📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions">See FAQ</a></b>
-    &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
     📜 &nbsp;<b><a href="https://arxiv.org/pdf/2209.01188.pdf">Read paper</a></b>
+    &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
+    📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions">See FAQ</a></b>
 </p>
 
-## Installation
+## 📚 Tutorials, examples, and more
 
-Here's how to install Petals with [Anaconda](https://www.anaconda.com/products/distribution) on Linux:
+Basic tutorials:
 
-```bash
-conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
-pip install git+https://github.com/bigscience-workshop/petals
-```
+- Getting started: [tutorial](https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing)
+- Prompt-tune Llama-65B for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)
+- Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)
+
+Useful tools:
+
+- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/petals-infra/chat.petals.dev)
+- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/petals-infra/health.petals.dev)
 
-If you don't use Anaconda, you can install PyTorch in [any other way](https://pytorch.org/get-started/locally/). If you want to run models with 8-bit weights, please install PyTorch with CUDA 11.x or newer for compatility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).
+Advanced guides:
 
-See the instructions for macOS and Windows, the full requirements, and troubleshooting advice in our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-client).
+- Launch a private swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
+- Run a custom model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)
 
 ## Benchmarks
 

+ 4 - 3
setup.cfg

@@ -18,6 +18,7 @@ classifiers =
     Programming Language :: Python :: 3.8
     Programming Language :: Python :: 3.9
     Programming Language :: Python :: 3.10
+    Programming Language :: Python :: 3.11
     Topic :: Scientific/Engineering
     Topic :: Scientific/Engineering :: Mathematics
     Topic :: Scientific/Engineering :: Artificial Intelligence
@@ -36,14 +37,14 @@ install_requires =
     accelerate>=0.22.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers>=4.31.0,<5.0.0  # if you change this, please also change version assert in petals/__init__.py
+    transformers>=4.32.0,<5.0.0  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
-    hivemind==1.1.9
+    hivemind @ git+https://github.com/learning-at-home/hivemind
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2
-    cpufeature>=0.2.0
+    cpufeature>=0.2.0; platform_machine == "x86_64"
     packaging>=20.9
     sentencepiece>=0.1.99
     peft>=0.5.0

+ 9 - 3
src/petals/__init__.py

@@ -1,7 +1,13 @@
 import os
+import platform
 
 os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
 
+if platform.system() == "Darwin":
+    # Necessary for forks to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
+    os.environ.setdefault("no_proxy", "*")
+    os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES")
+
 import hivemind
 import transformers
 from packaging import version
@@ -11,13 +17,13 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "2.0.1.post2"
+__version__ = "2.1.0"
 
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.31.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.31.0,<5.0.0"
+        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0"
 
 
 def _override_bfloat16_mode_default():

+ 17 - 9
src/petals/cli/run_server.py

@@ -1,8 +1,10 @@
 import argparse
+import logging
 
 import configargparse
+import torch
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.limits import increase_file_limit
+from hivemind.utils import limits
 from hivemind.utils.logging import get_logger
 from humanfriendly import parse_size
 
@@ -96,9 +98,9 @@ def main():
     parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--alloc_timeout', type=float, default=1,
-                        help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
-                             'before rejecting the request')
+    parser.add_argument('--max_alloc_timeout', type=float, default=600,
+                        help="If the cache is full, the server will wait for memory to be freed up to this many seconds"
+                             " before rejecting the request")
     parser.add_argument('--revision', type=str, default=None,
                         help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
@@ -127,9 +129,9 @@ def main():
     group.add_argument('--new_swarm', action='store_true',
                        help='Start a new private swarm (i.e., do not connect to any initial peers)')
 
-    parser.add_argument('--increase_file_limit', action='store_true',
-                        help='On *nix, this will increase the max number of processes '
-                             'a server can spawn before hitting "Too many open files"; Use at your own risk.')
+    parser.add_argument('--increase_file_limit', type=int, default=4096,
+                        help='On *nix, increase the max number of files a server can open '
+                             'before hitting "Too many open files" (set to zero to keep the system limit)')
     parser.add_argument('--stats_report_interval', type=int, required=False,
                         help='Interval between two reports of batch processing performance statistics')
 
@@ -185,8 +187,10 @@ def main():
 
     args["startup_timeout"] = args.pop("daemon_startup_timeout")
 
-    if args.pop("increase_file_limit"):
-        increase_file_limit()
+    file_limit = args.pop("increase_file_limit")
+    if file_limit:
+        limits.logger.setLevel(logging.WARNING)
+        limits.increase_file_limit(file_limit, file_limit)
 
     compression_type = args.pop("compression").upper()
     compression = getattr(CompressionType, compression_type)
@@ -207,6 +211,10 @@ def main():
 
     validate_version()
 
+    if not torch.backends.openmp.is_available():
+        # Necessary to prevent the server from freezing after forks
+        torch.set_num_threads(1)
+
     server = Server(
         **args,
         host_maddrs=host_maddrs,

+ 1 - 1
src/petals/client/inference_session.py

@@ -343,7 +343,7 @@ class InferenceSession:
         n_prev_spans = len(self._server_sessions)
         update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
         if attempt_no >= 1:
-            logger.info(
+            logger.debug(
                 f"Due to a server failure, remote attention caches "
                 f"from block {block_idx} to {update_end} will be regenerated"
             )

+ 3 - 3
src/petals/client/remote_generation.py

@@ -69,6 +69,8 @@ class RemoteGenerationMixin(_SkipTokensMixin):
         self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
     ):
         self._fix_generate_kwargs(kwargs)
+        if inputs is None:
+            inputs = kwargs.pop("input_ids", None)
 
         if session is not None:
             # If a session specified explicitly, use it
@@ -125,7 +127,7 @@ class RemoteGenerationMixin(_SkipTokensMixin):
         return result
 
     @staticmethod
-    def _fix_generate_kwargs(kwargs: dict) -> dict:
+    def _fix_generate_kwargs(kwargs: dict):
         # Suppress inappropriate "Both max_new_tokens and max_length" HF warning
         if "max_length" in kwargs and kwargs["max_length"] is None:
             del kwargs["max_length"]
@@ -135,8 +137,6 @@ class RemoteGenerationMixin(_SkipTokensMixin):
         if isinstance(do_sample, int):
             kwargs["do_sample"] = bool(do_sample)
 
-        return kwargs
-
     @staticmethod
     def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
         return dataclasses.replace(past_key_values, hypo_ids=beam_idx)

+ 13 - 0
src/petals/data_structures.py

@@ -20,6 +20,19 @@ class ServerState(Enum):
 RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
 
 
+@pydantic.dataclasses.dataclass
+class ModelInfo:
+    num_blocks: int
+    repository: Optional[str] = None
+
+    def to_dict(self) -> dict:
+        return dataclasses.asdict(self)
+
+    @classmethod
+    def from_dict(cls, source: dict):
+        return cls(**source)
+
+
 @pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState

+ 1 - 0
src/petals/models/bloom/config.py

@@ -30,5 +30,6 @@ class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfi
         if loading_from_repo and dht_prefix is None:
             # We need "-petals" for backward compatibility with Petals < 1.2.0
             dht_prefix = str(model_name_or_path) + "-petals"
+            dht_prefix = dht_prefix.replace(".", "-")
             logger.info(f"Using DHT prefix: {dht_prefix}")
         return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)

+ 1 - 0
src/petals/models/llama/config.py

@@ -35,6 +35,7 @@ class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfi
         if loading_from_repo and dht_prefix is None:
             dht_prefix = str(model_name_or_path)
             dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
+            dht_prefix = dht_prefix.replace(".", "-")
             if not dht_prefix.endswith("-hf"):
                 dht_prefix += "-hf"
             logger.info(f"Using DHT prefix: {dht_prefix}")

+ 3 - 3
src/petals/server/backend.py

@@ -16,7 +16,7 @@ from transformers import PretrainedConfig
 from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
-from petals.utils.misc import is_dummy
+from petals.utils.misc import get_size_in_bytes, is_dummy
 
 logger = get_logger(__name__)
 
@@ -72,7 +72,7 @@ class TransformerBackend(ModuleBackend):
         )
 
         self.dtype = backend_dtype
-        self.dtype_bytes = torch.finfo(self.dtype).bits // 8
+        self.dtype_bytes = get_size_in_bytes(self.dtype)
         self.shard_num_heads = []
         for shard in self.module.module_shards:
             for submodule in shard.modules():
@@ -92,7 +92,7 @@ class TransformerBackend(ModuleBackend):
 
         self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
         for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
-            self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
+            self.cache_bytes_per_token[descr.device] += descr.numel() * get_size_in_bytes(descr.dtype)
 
     def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
         """Create tensor descriptors for attention cache tensors used during inference_step"""

+ 3 - 2
src/petals/server/block_utils.py

@@ -5,6 +5,7 @@ from accelerate import init_empty_weights
 from transformers import PretrainedConfig
 
 from petals.utils.convert_block import QuantType
+from petals.utils.misc import get_size_in_bytes
 
 
 def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
@@ -37,7 +38,7 @@ def get_block_size(
     if location == "memory":
         if quant_type == QuantType.NONE:
             dtype = resolve_block_dtype(config, dtype)
-            bytes_per_value = torch.finfo(dtype).bits // 8
+            bytes_per_value = get_size_in_bytes(dtype)
         elif quant_type == QuantType.INT8:
             bytes_per_value = 1
         elif quant_type == QuantType.NF4:
@@ -46,6 +47,6 @@ def get_block_size(
             raise ValueError(f"Unsupported quant_type={quant_type}")
     elif location == "disk":
         dtype = resolve_block_dtype(config, "auto")
-        bytes_per_value = torch.finfo(dtype).bits // 8
+        bytes_per_value = get_size_in_bytes(dtype)
 
     return round(n_params * bytes_per_value * (1 + eps))

+ 11 - 3
src/petals/server/handler.py

@@ -150,6 +150,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 max_length = metadata.get("max_length")
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
+                alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
                 args_structure = metadata.get("args_structure")
                 if not requested_uids:
                     raise ValueError("User must specify at least one block for inference, but got none")
@@ -166,7 +167,9 @@ class TransformerConnectionHandler(ConnectionHandler):
 
                 batch_size = request.tensors[0].size[0] if request.tensors else 1
 
-                async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
+                async with self._allocate_cache(
+                    requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
+                ) as cache_handles:
                     background_tasks = set()
                     async for output_tensors, can_push in iterate_rpc_inference(
                         requested_uids=requested_uids,
@@ -535,14 +538,19 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     @contextlib.asynccontextmanager
     async def _allocate_cache(
-        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
+        self,
+        backends: Sequence[TransformerBackend],
+        *,
+        batch_size: int,
+        max_length: int,
+        timeout: Optional[float],
     ) -> Sequence[Sequence[Handle]]:
         """
         Allocate memory cache for all transformer blocks, return cache handle
         :returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend
         """
         descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends]
-        async with backends[0].memory_cache.allocate_cache(*chain(*descriptors)) as handles:
+        async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:
             yield nested_pack(handles, descriptors)
 
     def _log_request(

+ 71 - 21
src/petals/server/memory_cache.py

@@ -12,12 +12,13 @@ import os
 import time
 from typing import AsyncContextManager, Dict, Optional, Sequence
 
-import hivemind
+import async_timeout
 import torch
-from hivemind.utils import TensorDescriptor, get_logger
+from hivemind.utils import TensorDescriptor, enter_asynchronously, get_logger
 
 from petals.data_structures import Handle
 from petals.utils.asyncio import shield_and_wait
+from petals.utils.misc import get_size_in_bytes
 
 logger = get_logger(__name__)
 
@@ -25,11 +26,12 @@ logger = get_logger(__name__)
 class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
 
-    def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
+    def __init__(self, max_size_bytes: Optional[int], max_alloc_timeout: Optional[float] = None):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
-        self.alloc_timeout = alloc_timeout
+        self.max_alloc_timeout = max_alloc_timeout
         self._lock_metadata = mp.Lock()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=True)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
         self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
         self.runtime_pid = os.getpid()
@@ -46,6 +48,14 @@ class MemoryCache:
     def current_size_bytes(self, value: int):
         self._current_size.value = value
 
+    @property
+    def enqueued_size_bytes(self) -> int:
+        return self._enqueued_size.value
+
+    @enqueued_size_bytes.setter
+    def enqueued_size_bytes(self, value: int):
+        self._enqueued_size.value = value
+
     @property
     def bytes_left(self) -> int:
         return self.max_size_bytes - self.current_size_bytes
@@ -59,11 +69,14 @@ class MemoryCache:
         self._handle_counter.value = value
 
     @contextlib.asynccontextmanager
-    async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
+    async def allocate_cache(
+        self, *descriptors: TensorDescriptor, timeout: float
+    ) -> AsyncContextManager[Sequence[Handle]]:
         """
         Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
 
         :param descriptors: one or more tensors tensor of this size, dtype, etc
+        :param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit
 
         :note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
           if not, it will count maximum tensor allocation across devices for the purposes of size limit
@@ -73,6 +86,8 @@ class MemoryCache:
         """
         assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
         assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
+        if self.max_alloc_timeout is not None:
+            timeout = min(timeout, self.max_alloc_timeout)
         max_alloc_size = self.get_allocation_size(*descriptors)
 
         gib = 1024**3
@@ -83,10 +98,10 @@ class MemoryCache:
             f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
         )
 
-        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
+        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))
         try:
             handles = await shield_and_wait(alloc_task)
-            logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
+            logger.info(f"rpc_inference.alloc_done(size={max_alloc_size / gib:.2f} GiB)")
             yield handles
         finally:
             self._free(max_alloc_size, alloc_task)
@@ -96,28 +111,62 @@ class MemoryCache:
         """Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
         alloc_size_by_device = {}
         for descr in descriptors:
-            tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+            tensor_size = descr.numel() * get_size_in_bytes(descr.dtype)
             alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
         return max(alloc_size_by_device.values())
 
-    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
+    async def _schedule_alloc(
+        self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]
+    ) -> Sequence[Handle]:
         """
         This method should be called inside asyncio.shield() because:
             - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
         """
+        try:
+            async with self._wait_for_free_memory(alloc_size, timeout):
+                with self._lock_metadata:
+                    handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
+                    self.current_size_bytes += alloc_size
+                    self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
+                    self._pipe_send.send((handles, descriptors))
+                    return handles
+        except TimeoutError:
+            raise AllocationFailed(f"Could not allocate {alloc_size} (timeout={timeout})")
 
+    @contextlib.asynccontextmanager
+    async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float]):
+        start_time = time.perf_counter()
         loop = asyncio.get_event_loop()
-        async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
-            if self.current_size_bytes + alloc_size > self.max_size_bytes:
-                await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
-            with self._lock_metadata:
-                handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
-                self.current_size_bytes += alloc_size
-                self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
-                self._pipe_send.send((handles, descriptors))
-                return handles
-
-    def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None:
+
+        with self._enqueued_size.get_lock():
+            self._enqueued_size.value += alloc_size
+        allocated = False
+        try:
+            context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack()
+            # contextlib.AsyncExitStack() is used as a null context here
+            async with context_manager:
+                if timeout == 0 and self.current_size_bytes + self.enqueued_size_bytes > self.max_size_bytes:
+                    raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
+                async with enter_asynchronously(self._lock_acquire_memory):
+                    if self.current_size_bytes + alloc_size > self.max_size_bytes:
+                        if timeout == 0:
+                            raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
+                        elapsed_time = time.perf_counter() - start_time
+                        remaining_timeout = max(0.0, timeout - elapsed_time) if timeout is not None else None
+                        await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout)
+
+                allocated = True
+                with self._enqueued_size.get_lock():
+                    self._enqueued_size.value -= alloc_size
+                yield
+        except asyncio.TimeoutError:
+            raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds")
+        finally:
+            if not allocated:
+                with self._enqueued_size.get_lock():
+                    self._enqueued_size.value -= alloc_size
+
+    def _free(self, alloc_size: int, alloc_task: asyncio.Task):
         if alloc_task.exception() is not None:
             return
         handles = alloc_task.result()
@@ -133,9 +182,10 @@ class MemoryCache:
             raise AllocationFailed(
                 f"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes"
             )
+        timeout = timeout if timeout != float("inf") else None
         deadline = None if timeout is None else time.perf_counter() + timeout
         while self.current_size_bytes + allocated_size > self.max_size_bytes:
-            remaining_time = deadline - time.perf_counter() if timeout is not None else None
+            remaining_time = None if timeout is None else deadline - time.perf_counter()
             if not self._memory_freed_event.wait(remaining_time):
                 raise AllocationFailed(
                     f"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds"

+ 1 - 1
src/petals/server/reachability.py

@@ -140,7 +140,7 @@ class ReachabilityProtocol(ServicerBase):
                 protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)
 
                 ready.set_result(True)
-                logger.info("Reachability service started")
+                logger.debug("Reachability service started")
 
                 async with protocol.serve(common_p2p):
                     await protocol._stop.wait()

+ 54 - 16
src/petals/server/server.py

@@ -3,13 +3,16 @@ from __future__ import annotations
 import gc
 import math
 import multiprocessing as mp
+import os
 import random
 import threading
 import time
 from typing import Dict, List, Optional, Sequence, Union
 
 import hivemind
+import psutil
 import torch
+import torch.mps
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
@@ -19,7 +22,7 @@ from transformers import PretrainedConfig
 
 import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -31,6 +34,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.dht import declare_active_modules, get_remote_module_infos
+from petals.utils.misc import get_size_in_bytes
 from petals.utils.ping import PingAggregator
 from petals.utils.random import sample_up_to
 from petals.utils.version import get_compatible_model_repo
@@ -59,12 +63,12 @@ class Server:
         min_batch_size: int = 1,
         max_batch_size: Optional[int] = None,
         max_chunk_size_bytes: int = 256 * 1024 * 1024,
+        max_alloc_timeout: float = 600,
         attn_cache_tokens: Optional[int] = None,
         torch_dtype: str = "auto",
         revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
@@ -153,13 +157,25 @@ class Server:
         self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
 
         if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.backends.mps.is_available():
+                device = "mps"
+            else:
+                device = "cpu"
         device = torch.device(device)
         if device.type == "cuda" and device.index is None:
             device = torch.device(device.type, index=0)
         self.device = device
 
         torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
+        if device.type == "cpu" and torch_dtype == torch.float16:
+            raise ValueError(
+                f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16"
+            )
+        if device.type == "mps" and torch_dtype == torch.bfloat16:
+            logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead")
+            torch_dtype = torch.float16
         self.torch_dtype = torch_dtype
 
         if tensor_parallel_devices is None:
@@ -185,13 +201,14 @@ class Server:
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.inference_max_length = inference_max_length
         self.max_chunk_size_bytes = max_chunk_size_bytes
+        self.max_alloc_timeout = max_alloc_timeout
 
         # For attention cache in GPU or RAM
         if attn_cache_tokens is None:
             attn_cache_tokens = 32768 if is_multiquery_attn else 8192
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         cache_values_per_block //= self.block_config.num_key_value_groups
-        self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
+        self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
 
         # For disk cache
         self.cache_dir = cache_dir
@@ -217,8 +234,6 @@ class Server:
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
 
-        self.alloc_timeout = alloc_timeout
-
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput_info = get_server_throughput(
@@ -245,21 +260,26 @@ class Server:
             using_relay=reachable_via_relay,
             **throughput_info,
         )
+        self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers)
+        if not os.path.isdir(converted_model_name_or_path):
+            self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path
 
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
 
+        self.module_container = None
         self.stop = threading.Event()
 
     def _choose_num_blocks(self) -> int:
-        assert self.device.type == "cuda", (
+        assert self.device.type in ("cuda", "mps"), (
             "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
             "CPU-only servers in the public swarm are discouraged since they are much slower"
         )
         num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
 
         if num_devices > 1:
+            assert self.device.type == "cuda", f"Tensor parallelism is not supported on {self.device.type.upper()}"
             memory_per_device = tuple(
                 torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
             )
@@ -270,8 +290,10 @@ class Server:
                     "Please launch individual servers on each GPU or set --num_blocks manually to "
                     "override this exception."
                 )
-        else:
+        elif self.device.type == "cuda":
             total_memory = torch.cuda.get_device_properties(self.device).total_memory
+        else:
+            total_memory = psutil.virtual_memory().total
 
         gib = 1024**3
         # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
@@ -311,13 +333,14 @@ class Server:
                 converted_model_name_or_path=self.converted_model_name_or_path,
                 block_config=self.block_config,
                 attn_cache_bytes=self.attn_cache_bytes,
-                alloc_timeout=self.alloc_timeout,
                 server_info=self.server_info,
+                model_info=self.model_info,
                 block_indices=block_indices,
                 num_handlers=self.num_handlers,
                 min_batch_size=self.min_batch_size,
                 max_batch_size=self.max_batch_size,
                 max_chunk_size_bytes=self.max_chunk_size_bytes,
+                max_alloc_timeout=self.max_alloc_timeout,
                 inference_max_length=self.inference_max_length,
                 torch_dtype=self.torch_dtype,
                 cache_dir=self.cache_dir,
@@ -360,7 +383,7 @@ class Server:
             self._clean_memory_and_fds()
 
     def _clean_memory_and_fds(self):
-        del self.module_container
+        self.module_container = None
         gc.collect()  # In particular, this closes unused file descriptors
 
         if self.device.type == "cuda":
@@ -373,6 +396,8 @@ class Server:
                 f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
                 f"{reserved_vram / gib:.1f} GiB reserved memory"
             )
+        elif self.device.type == "mps":
+            torch.mps.empty_cache()
 
     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
@@ -391,8 +416,10 @@ class Server:
         module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)
         return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)
 
-    def shutdown(self):
+    def shutdown(self, timeout: Optional[float] = 5):
         self.stop.set()
+        if self.module_container is not None and self.module_container.is_alive():
+            self.module_container.join(timeout)
 
         if self.reachability_protocol is not None:
             self.reachability_protocol.shutdown()
@@ -413,12 +440,13 @@ class ModuleContainer(threading.Thread):
         converted_model_name_or_path: str,
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
-        alloc_timeout: float,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
         max_chunk_size_bytes: int,
+        max_alloc_timeout: float,
         torch_dtype: torch.dtype,
         cache_dir: str,
         max_disk_space: int,
@@ -434,13 +462,14 @@ class ModuleContainer(threading.Thread):
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
-        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
 
         server_info.state = ServerState.JOINING
         dht_announcer = ModuleAnnouncerThread(
             module_uids,
             dht,
             server_info,
+            model_info,
             block_config=block_config,
             memory_cache=memory_cache,
             update_period=update_period,
@@ -649,6 +678,7 @@ class ModuleAnnouncerThread(threading.Thread):
         module_uids: List[str],
         dht: DHT,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         *,
         block_config: PretrainedConfig,
         memory_cache: MemoryCache,
@@ -661,9 +691,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self.module_uids = module_uids
         self.dht = dht
         self.server_info = server_info
+        self.model_info = model_info
         self.memory_cache = memory_cache
 
-        self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
+        self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
         self.bytes_per_token //= block_config.num_key_value_groups
 
         self.update_period = update_period
@@ -671,10 +702,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self.trigger = threading.Event()
 
         self.max_pinged = max_pinged
-        dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
+        self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
         block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
         start_block, end_block = min(block_indices), max(block_indices) + 1
-        self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
         self.ping_aggregator = PingAggregator(self.dht)
 
     def run(self) -> None:
@@ -698,6 +729,13 @@ class ModuleAnnouncerThread(threading.Thread):
             )
             if self.server_info.state == ServerState.OFFLINE:
                 break
+            if not self.dht_prefix.startswith("_"):  # Not private
+                self.dht.store(
+                    key="_petals.models",
+                    subkey=self.dht_prefix,
+                    value=self.model_info.to_dict(),
+                    expiration_time=get_dht_time() + self.expiration,
+                )
 
             delay = self.update_period - (time.perf_counter() - start_time)
             if delay < 0:

+ 14 - 28
src/petals/server/task_pool.py

@@ -32,7 +32,7 @@ class Task:
         return self.future._uid
 
 
-class PrioritizedTaskPool(TaskPoolBase):
+class PrioritizedTaskPool(threading.Thread):
     """
     Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
     returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
@@ -62,52 +62,41 @@ class PrioritizedTaskPool(TaskPoolBase):
         daemon=True,
         start=False,
     ):
-        super().__init__(process_func, daemon=daemon, name=name)
+        super().__init__(daemon=daemon, name=name)
+        self.process_func = process_func
+        # the lower the priority is, the more urgent it is to process this pool
+        self._priority = mp.Value(ctypes.c_double, 1.0)
+
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.device = device
 
         self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
         self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
 
-        self._prioritizer_thread = threading.Thread(
-            name=self.name + "_prioritizer",
-            target=self._prioritize_tasks,
-            args=[self.submitted_tasks, self._ordered_tasks],
-            daemon=True,
-        )
         self._dispatched_tasks = {}
         self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
         self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
         self.priority = float("inf"), float("inf")  # (first task priority, first task timestamp)
 
-        self._stop = mp.Event()
         if start:
             self.start()
 
-    @staticmethod
-    def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
+    def run(self):
         """Read tasks from incoming queue and put them into a local priority queue"""
         while True:
-            task = submitted_tasks.get()
+            task = self.submitted_tasks.get()
             if task is None:
                 logger.debug("Shutting down prioritizer thread")
                 break
 
-            ordered_tasks.put(task, block=True)
-
-    def start(self):
-        assert not self.is_alive() and not self._prioritizer_thread.is_alive()
-        self._prioritizer_thread.start()
-        super().start()
+            self._ordered_tasks.put(task, block=True)
 
-    def shutdown(self, timeout: float = 3):
-        self.submitted_tasks.put(None)  # Shuts down self._prioritizer_thread
-        self._stop.set()
+    def terminate(self):
+        """An alias for hivemind.Runtime that assumes that each TaskPool is a process"""
+        self.shutdown()
 
-        self.join(timeout)
-        if self.is_alive():
-            logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
-            self.terminate()
+    def shutdown(self):
+        self.submitted_tasks.put(None)  # Shuts down self.run()
 
     def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1, **kwargs: Any) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
@@ -161,9 +150,6 @@ class PrioritizedTaskPool(TaskPoolBase):
         else:
             task.future.set_exception(exception)
 
-    def run(self, *args, **kwargs):
-        self._stop.wait()
-
     @property
     def empty(self):
         return not self.batch_receiver.poll()

+ 12 - 6
src/petals/server/throughput.py

@@ -9,6 +9,7 @@ from pathlib import Path
 from typing import Dict, Optional, Sequence, Union
 
 import torch
+import torch.mps
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
@@ -207,14 +208,12 @@ def measure_compute_rps(
         elapsed = 0
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
         _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
-        if device.type == "cuda":
-            torch.cuda.synchronize(device)
+        synchronize(device)
 
         start_time = time.perf_counter()
-        for step in range(n_steps):
+        for _ in range(n_steps):
             _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
-        if device.type == "cuda":
-            torch.cuda.synchronize(device)
+        synchronize(device)
         elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
 
@@ -230,8 +229,15 @@ def measure_compute_rps(
     return device_rps
 
 
+def synchronize(device: torch.device):
+    if device.type == "cuda":
+        torch.cuda.synchronize(device)
+    elif device.type == "mps":
+        torch.mps.synchronize()
+
+
 def get_device_name(device: torch.device) -> str:
-    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU"
+    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else device.type.upper()
 
 
 def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:

+ 10 - 0
src/petals/utils/misc.py

@@ -9,6 +9,16 @@ def is_dummy(tensor: torch.Tensor) -> bool:
     return tensor.numel() == 0
 
 
+SPECIAL_DTYPE_SIZES = {torch.bool: 1, torch.qint8: 1, torch.qint32: 4}
+
+
+def get_size_in_bytes(dtype: torch.dtype) -> int:
+    if dtype in SPECIAL_DTYPE_SIZES:
+        return SPECIAL_DTYPE_SIZES[dtype]
+    get_info = torch.finfo if dtype.is_floating_point else torch.iinfo
+    return (get_info(dtype).bits * (1 + dtype.is_complex)) // 8
+
+
 def docstring_from(source):
     def add_docstring(dest):
         dest.__doc__ = source.__doc__

+ 2 - 1
src/petals/utils/peft.py

@@ -20,6 +20,7 @@ from transformers.utils import get_file_from_repo
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_block import QuantType
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.misc import get_size_in_bytes
 
 logger = get_logger(__name__)
 
@@ -285,5 +286,5 @@ def estimate_adapter_memory_per_block(
                 block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
             )
         adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
-    bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
+    bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype))
     return adapter_parameters * bytes_per_parameter

+ 184 - 0
tests/test_cache.py

@@ -0,0 +1,184 @@
+import asyncio
+import multiprocessing as mp
+import random
+import time
+from typing import Optional
+
+import pytest
+import pytest_asyncio  # make sure the module exists; otherwise the test will be skipped
+import torch
+from hivemind import TensorDescriptor
+
+from petals.server.memory_cache import AllocationFailed, MemoryCache
+from petals.utils.misc import get_size_in_bytes
+
+
+def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None):
+    if dtype is None:
+        dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool))
+    elem_size_bytes = get_size_in_bytes(dtype)
+    descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype))
+    return descr
+
+
+@pytest.mark.asyncio
+async def test_cache_timeout():
+    cache = MemoryCache(max_size_bytes=1024, max_alloc_timeout=0.5)
+    cache.runtime_pid += 1  # pretend we're another process
+    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0):
+        pass
+
+    async with cache.allocate_cache(_make_tensor_descriptor(100), timeout=999):
+        async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
+            async with cache.allocate_cache(_make_tensor_descriptor(128), _make_tensor_descriptor(32), timeout=1):
+                t_start = time.perf_counter()
+                with pytest.raises(AllocationFailed):
+                    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0.1):
+                        pass
+                assert 0.1 < time.perf_counter() - t_start < 0.2, "wait time exceeds alloc timeout"
+                async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")):
+                    pass
+
+                t_start = time.perf_counter()
+                with pytest.raises(AllocationFailed):
+                    async with cache.allocate_cache(_make_tensor_descriptor(384), timeout=1.0):  # exceeds max timeout
+                        pass
+                assert 0.5 < time.perf_counter() - t_start < 0.6, "wait time exceeds max alloc timeout"
+
+            # test memory allocation when another task frees the memory
+            async def _klog_the_cache():
+                async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
+                    pass
+
+            large_alloc_task = asyncio.create_task(_klog_the_cache())
+
+            t_start = time.perf_counter()
+            await asyncio.sleep(0.05)  # wait for large alloc to enqueue
+            async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")):  # exceeds max timeout
+                pass  # this memory should allocate once the background task clears the queue
+            assert 0.2 < time.perf_counter() - t_start < 0.3, "memory should be allocated after background task clears"
+            with pytest.raises(AllocationFailed):
+                await large_alloc_task
+
+            # test that zero-timeout allocation fails instantaneously even if someone else is awaiting alloc
+            large_alloc_task = asyncio.create_task(_klog_the_cache())
+            t_start = time.perf_counter()
+            await asyncio.sleep(0.05)  # wait for large alloc to enqueue
+            with pytest.raises(AllocationFailed):
+                async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
+                    pass  # this memory should allocate once the background task clears the queue
+            assert time.perf_counter() - t_start < 0.1, "zero-timeout task should fail (or succeed) instantaneously"
+            with pytest.raises(AllocationFailed):
+                await large_alloc_task
+
+
+@pytest.mark.asyncio
+async def test_unlimited_timeout():
+    cache = MemoryCache(max_size_bytes=1024)
+    cache.runtime_pid += 1  # pretend we're another process
+    t_start = time.perf_counter()
+
+    async def _klog_the_cache():
+        async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
+            await asyncio.sleep(0.5)
+
+    alloc_task = asyncio.create_task(_klog_the_cache())
+    await asyncio.sleep(0.1)
+    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=float("inf")):
+        await alloc_task
+    assert 0.5 < time.perf_counter() - t_start < 0.6, "memory should be allocated after background task clears"
+
+
+@pytest.mark.asyncio
+async def test_cache_usage():
+    cache = MemoryCache(max_size_bytes=2048)
+    alloc_event, dealloc_a_event, dealloc_bcd_event, dealloc_e_event, dealloc_f_event = (mp.Event() for _ in range(5))
+    pipe_receiver, pipe_sender = mp.Pipe(duplex=False)
+    with pytest.raises(AssertionError):
+        async with cache.allocate_cache(_make_tensor_descriptor(123), timeout=1):
+            pass  # fails because cache must be allocated from another process
+
+    descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8))  # 768 bytes
+    descr_b = TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64))  # 8 bytes
+    descr_c = TensorDescriptor.from_tensor(torch.empty((33,), dtype=torch.bool))  # 33 bytes
+    descr_d = TensorDescriptor.from_tensor(torch.empty((0,), dtype=torch.int64))  # 0 bytes
+    descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16))  # 1536 bytes
+    descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8))  # 1792 bytes
+
+    async def _allocate_and_wait(dealloc_event, *descrs, timeout=None):
+        loop = asyncio.get_event_loop()
+        async with cache.allocate_cache(*descrs, timeout=timeout) as handles:
+            pipe_sender.send(handles)
+            await loop.run_in_executor(None, dealloc_event.wait)
+
+    async def _allocate_af():
+        alloc_event.wait()
+        allocate_a_task = asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a))
+        await allocate_a_task
+        allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f))  # klogs the cache
+        await allocate_f_task
+
+    alloc_process1 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_af()), daemon=True)
+    alloc_process1.start()
+
+    async def _allocate_bcde():
+        alloc_event.wait()
+        await asyncio.sleep(0.1)  # ensure that the other tensor is always allocated (and sent through pipe) first
+        allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d))
+        allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e))  # doesn't fit
+        await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)
+
+    alloc_process2 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
+    alloc_process2.start()
+    assert cache.current_size_bytes == 0
+    alloc_event.set()
+    (handle_a,) = pipe_receiver.recv()
+
+    handle_b, handle_c, handle_d = pipe_receiver.recv()
+
+    with cache.use_cache(handle_a) as (tensor_a,):
+        assert tensor_a.dtype == torch.uint8
+        tensor_a[2:5] = torch.tensor((42, 43, 44))
+
+    with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, tensor_d):
+        assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0
+        assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0
+        tensor_a += 1
+        tensor_b[...] = -1.337
+    assert cache.current_size_bytes == 809  # this checks a,b,c,d are allocated but b still awaits memory
+
+    dealloc_bcd_event.set()
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 768  # only tensor a should be allocated
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_a, handle_b):
+            pass  # one of handles (c) is deallocated
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_d):
+            pass  # handle_d is deallocated correctly, even though it is never used
+    with cache.use_cache(handle_a) as (tensor_a,):
+        assert tuple(tensor_a[2:5]) == (43, 44, 45)
+
+    dealloc_a_event.set()
+    (handle_e,) = pipe_receiver.recv()  # e can finally be allocated
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 1536  # tensor e should finally be able to allocate
+
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_a):
+            pass  # tensor a is no longer allocated
+    with cache.use_cache(handle_e) as (tensor_e,):
+        assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8)
+
+    dealloc_e_event.set()
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 1792  # only tensor f is still allocated
+    dealloc_f_event.set()
+
+    alloc_process1.join()
+    alloc_process2.join()
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 0
+    assert cache.current_size_bytes == 0
+    assert alloc_process1.exitcode == 0, "allocation process 1 failed or did not finish, see stderr for details"
+    assert alloc_process2.exitcode == 0, "allocation process 2 failed or did not finish, see stderr for details"

+ 20 - 0
tests/test_full_model.py

@@ -149,3 +149,23 @@ def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, n
     outputs = make_generate_calls(model, inputs, **options)
     ref_outputs = ref_model.generate(inputs, **options)
     assert torch.allclose(outputs, ref_outputs), f"Beam search results are not identical to HF"
+
+
+@pytest.mark.forked
+def test_input_ids(tokenizer, model, ref_model, max_new_tokens=4):
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")
+    assert inputs.keys() == {"input_ids", "attention_mask"}
+
+    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
+    ref_outputs = ref_model.generate(**inputs, max_new_tokens=max_new_tokens)
+    assert torch.allclose(outputs, ref_outputs), f"Outputs are not identical to HF"
+
+    with model.inference_session(max_length=inputs["input_ids"].shape[1] + max_new_tokens):
+        outputs = torch.cat(
+            [
+                model.generate(**inputs, max_new_tokens=2),
+                model.generate(None, max_new_tokens=max_new_tokens - 2),
+            ],
+            dim=1,
+        )
+    assert torch.allclose(outputs, ref_outputs), f"Multi-call outputs are not identical to HF"

+ 29 - 18
tests/test_priority_pool.py

@@ -1,4 +1,5 @@
 import multiprocessing as mp
+import platform
 import time
 
 import pytest
@@ -8,9 +9,30 @@ from hivemind.moe.server.runtime import Runtime
 from petals.server.task_pool import PrioritizedTaskPool
 
 
+def _submit_tasks(runtime_ready, pools, results_valid):
+    runtime_ready.wait()
+
+    futures = []
+    futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
+    futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
+    time.sleep(0.01)
+    futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
+    futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
+    futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
+    futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
+    futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
+    futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
+    futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
+    for i, f in enumerate(futures):
+        assert f.result()[0].item() == i**2
+    results_valid.set()
+
+
+@pytest.mark.skipif(platform.system() == "Darwin", reason="Flapping on macOS due to multiprocessing quirks")
 @pytest.mark.forked
 def test_priority_pools():
     outputs_queue = mp.SimpleQueue()
+    runtime_ready = mp.Event()
     results_valid = mp.Event()
 
     def dummy_pool_func(args, kwargs):
@@ -32,27 +54,14 @@ def test_priority_pools():
         PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
     )
 
+    # Simulate requests coming from ConnectionHandlers
+    proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid))
+    proc.start()
+
     runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
+    runtime.ready = runtime_ready
     runtime.start()
 
-    def process_tasks():
-        futures = []
-        futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
-        futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
-        time.sleep(0.01)
-        futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
-        futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
-        futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
-        futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
-        futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
-        futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
-        futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
-        for i, f in enumerate(futures):
-            assert f.result()[0].item() == i**2
-        results_valid.set()
-
-    proc = mp.Process(target=process_tasks)
-    proc.start()
     proc.join()
     assert results_valid.is_set()
 
@@ -70,3 +79,5 @@ def test_priority_pools():
     #                                            3 - task with priority 2 from pool A
     #                                               4 - task with priority 10 from pool A
     #                                                  7 - task with priority 11 from pool B
+
+    runtime.shutdown()

+ 1 - 1
tests/test_remote_sequential.py

@@ -126,6 +126,6 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
 
     (outputs_ref * output_proj).sum().backward()
     assert input_prompts_ref.grad is not None
-    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2)
+    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=3e-2)
     assert intermediate_prompts_ref.grad is not None
     assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)