
Merge branch 'main' into forward_kwargs

justheuristic 1 year ago
parent
commit ce89b649b5

+ 18 - 20
.github/workflows/run-tests.yaml

@@ -7,20 +7,21 @@ on:
 
 jobs:
   run-tests:
-    runs-on: ubuntu-latest
     strategy:
       matrix:
         include:
-          - { model: 'bigscience/bloom-560m', python-version: '3.8' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.9' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.10' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.11' }
-          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.8' }
-          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.11' }
+          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
       fail-fast: false
+    runs-on: ${{ matrix.os }}-latest
     timeout-minutes: 15
     steps:
       - name: Increase swap space
+        if: ${{ matrix.os == 'ubuntu' }}
         uses: pierotofy/set-swap-space@master
         with:
           swap-size-gb: 10
@@ -47,12 +48,7 @@ jobs:
           export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
           export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
 
-          # [Step 1] Watch free RAM (lack of RAM is a common issue in CI)
-
-          bash -c 'while true; do free -h && sleep 30s; done' &
-          RAM_WATCH_PID=$!
-
-          # [Step 2] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
+          # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
 
           python -m petals.cli.run_dht \
             --identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &
@@ -61,7 +57,7 @@ jobs:
           export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
           # ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs
 
-          sleep 5  # wait for DHT init
+          until [ -s bootstrap.log ]; do sleep 5; done  # wait for DHT init
 
           python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
             --mean_balance_check_period 10 \
@@ -95,11 +91,15 @@ jobs:
           sleep 30  # wait for servers to eval throughput, download layers, and rebalance
           kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID  # ensure all peers survived init
 
-          # [Step 3] Run PyTest
+          # [Step 2] Run PyTest
+
+          # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
+          export no_proxy=*
+          export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
 
           pytest tests --durations=0 --durations-min=1.0 -v
 
-          # [Step 4] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
+          # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
 
           python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
             --seq_len 3
@@ -110,9 +110,7 @@ jobs:
           python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
             --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm
 
-          # [Step 5] Clean up
-
-          kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID  # ensure all peers survived tests
+          # [Step 4] Clean up
 
-          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID $RAM_WATCH_PID
+          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
           echo "Done!"

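The `until [ -s bootstrap.log ]` loop above replaces the old fixed `sleep 5`: the job now proceeds only once the bootstrap peer has actually written something to its log. A minimal Python sketch of the same readiness check (the file name, poll interval, and timeout are illustrative, not part of the workflow):

```python
import time
from pathlib import Path

def wait_for_nonempty_file(path: str, poll_interval: float = 5.0, timeout: float = 300.0) -> None:
    """Block until `path` exists and is non-empty, like `until [ -s path ]; do sleep 5; done`."""
    deadline = time.monotonic() + timeout
    while not (Path(path).exists() and Path(path).stat().st_size > 0):
        if time.monotonic() > deadline:
            raise TimeoutError(f"{path} is still empty after {timeout} seconds")
        time.sleep(poll_interval)

wait_for_nonempty_file("bootstrap.log")  # hypothetical usage; the CI job does this step in bash
```
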
+ 43 - 53
README.md

@@ -8,20 +8,20 @@
     <br>
 </p>
 
-Generate text with distributed **LLaMA 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:
+Generate text with distributed **Llama 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:
 
 ```python
 from transformers import AutoTokenizer
 from petals import AutoDistributedModelForCausalLM
 
-model_name = "stabilityai/StableBeluga2"
-# You can also use "meta-llama/Llama-2-70b-hf", "meta-llama/Llama-2-70b-chat-hf",
-# repos with LLaMA-65B, "bigscience/bloom", or "bigscience/bloomz"
+# Choose any model available at https://health.petals.dev
+model_name = "petals-team/StableBeluga2"
 
+# Connect to a distributed network hosting model layers
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
-# Embeddings & prompts are on your device, transformer blocks are distributed across the Internet
 
+# Run the model as if it were on your computer
 inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"]
 outputs = model.generate(inputs, max_new_tokens=5)
 print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
@@ -31,73 +31,58 @@ print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
     🚀 &nbsp;<b><a href="https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing">Try now in Colab</a></b>
 </p>
 
-🦙 **Want to run LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
-
-📋 **Terms of use.** Make sure you follow the model license (see [LLaMA 2](https://bit.ly/llama2-license), [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2/blob/main/LICENSE.txt), [LLaMA](https://bit.ly/llama-license), and [BLOOM](https://bit.ly/bloom-license)).
+🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
 
 🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
 
 💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
 
-### Connect your GPU and increase Petals capacity
+## Connect your GPU and increase Petals capacity
 
-Petals is a community-run system &mdash; we rely on people sharing their GPUs. You can check out available servers on our [swarm monitor](https://health.petals.dev) and connect your GPU to help serving one of the models!
+Petals is a community-run system &mdash; we rely on people sharing their GPUs. You can check out [available models](https://health.petals.dev) and help serving one of them! As an example, here is how to host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your GPU:
 
-🐍 **Linux + Anaconda.** Run these commands:
+🐧 **Linux + Anaconda.** Run these commands for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
 
 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
 pip install git+https://github.com/bigscience-workshop/petals
-python -m petals.cli.run_server stabilityai/StableBeluga2
+python -m petals.cli.run_server petals-team/StableBeluga2
 ```
 
-🪟 **Windows + WSL.** Follow the guide on our [Wiki](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows).
+🪟 **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki.
 
-🐋 **Any OS + Docker.** Run our [Docker](https://www.docker.com) image:
+🐋 **Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
 
 ```bash
-sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
-    python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2
+sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \
+    learningathome/petals:main \
+    python -m petals.cli.run_server --port 31330 petals-team/StableBeluga2
 ```
 
-These commands will host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, repos with LLaMA-65B, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures.
-
-🦙 **Want to host LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then use this command for `petals.cli.run_server`:
+🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands:
 
 ```bash
-python -m petals.cli.run_server meta-llama/Llama-2-70b-chat-hf --token YOUR_TOKEN_HERE
+brew install python
+python3 -m pip install git+https://github.com/bigscience-workshop/petals
+python3 -m petals.cli.run_server petals-team/StableBeluga2
 ```
 
-💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
-
-🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
-
-🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`.
-
-### Check out tutorials, examples, and more
-
-Basic tutorials:
-
-- Getting started: [tutorial](https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing)
-- Prompt-tune LLaMA-65B for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)
-- Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)
+<p align="center">
+    📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (how to use multiple GPUs, start the server on boot, etc.)
+</p>
 
-Useful tools and advanced guides:
+💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
 
-- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/petals-infra/chat.petals.dev)
-- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/petals-infra/health.petals.dev)
-- Launch your own swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
-- Run a custom foundation model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)
+🦙 **Want to host Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then add `--token YOUR_TOKEN_HERE` to the `python -m petals.cli.run_server` command.
 
-Learning more:
+🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
 
-- Frequently asked questions: [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions)
-- In-depth system description: [paper](https://arxiv.org/abs/2209.01188)
+🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`.
 
 ## How does it work?
 
-- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning.
-- Single-batch inference runs at **up to 6 steps/sec** for **LLaMA 2** (70B) and &approx; 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
+- Petals runs large language models like [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning.
+- Single-batch inference runs at **up to 6 steps/sec** for **Llama 2** (70B) and &approx; 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
 - Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.
 
 <p align="center">
@@ -105,23 +90,28 @@ Learning more:
 </p>
 
 <p align="center">
-    📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions">See FAQ</a></b>
-    &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
     📜 &nbsp;<b><a href="https://arxiv.org/pdf/2209.01188.pdf">Read paper</a></b>
+    &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
+    📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions">See FAQ</a></b>
 </p>
 
-## Installation
+## 📚 Tutorials, examples, and more
 
-Here's how to install Petals with [Anaconda](https://www.anaconda.com/products/distribution) on Linux:
+Basic tutorials:
 
-```bash
-conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
-pip install git+https://github.com/bigscience-workshop/petals
-```
+- Getting started: [tutorial](https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing)
+- Prompt-tune Llama-65B for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)
+- Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)
+
+Useful tools:
+
+- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/petals-infra/chat.petals.dev)
+- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/petals-infra/health.petals.dev)
 
-If you don't use Anaconda, you can install PyTorch in [any other way](https://pytorch.org/get-started/locally/). If you want to run models with 8-bit weights, please install PyTorch with CUDA 11.x or newer for compatility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).
+Advanced guides:
 
-See the instructions for macOS and Windows, the full requirements, and troubleshooting advice in our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-client).
+- Launch a private swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
+- Run a custom model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)
 
 ## Benchmarks
 

+ 4 - 3
setup.cfg

@@ -18,6 +18,7 @@ classifiers =
     Programming Language :: Python :: 3.8
     Programming Language :: Python :: 3.9
     Programming Language :: Python :: 3.10
+    Programming Language :: Python :: 3.11
     Topic :: Scientific/Engineering
     Topic :: Scientific/Engineering :: Mathematics
     Topic :: Scientific/Engineering :: Artificial Intelligence
@@ -36,14 +37,14 @@ install_requires =
     accelerate>=0.22.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers>=4.31.0,<5.0.0  # if you change this, please also change version assert in petals/__init__.py
+    transformers>=4.32.0,<5.0.0  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
-    hivemind==1.1.9
+    hivemind @ git+https://github.com/learning-at-home/hivemind
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2
-    cpufeature>=0.2.0
+    cpufeature>=0.2.0; platform_machine == "x86_64"
     packaging>=20.9
     sentencepiece>=0.1.99
     peft>=0.5.0

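The `cpufeature` dependency above now carries a PEP 508 environment marker, so it is installed only on x86_64 machines and skipped elsewhere (for example, on Apple Silicon). A small sketch of how such a marker is evaluated, using the `packaging` library that is already listed as a dependency:

```python
from packaging.markers import Marker

marker = Marker('platform_machine == "x86_64"')
print(marker.evaluate())  # True on an x86_64 interpreter, False e.g. on arm64 / Apple Silicon
```
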
+ 9 - 3
src/petals/__init__.py

@@ -1,7 +1,13 @@
 import os
+import platform
 
 os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
 
+if platform.system() == "Darwin":
+    # Necessary for forks to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
+    os.environ.setdefault("no_proxy", "*")
+    os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES")
+
 import hivemind
 import transformers
 from packaging import version
@@ -11,13 +17,13 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "2.0.1.post2"
+__version__ = "2.1.0"
 
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.31.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.31.0,<5.0.0"
+        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0"
 
 
 def _override_bfloat16_mode_default():

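The version gate above can still be bypassed the same way as before: if `PETALS_IGNORE_DEPENDENCY_VERSION` is set to any non-empty value, the `transformers` assert is skipped. A hedged usage sketch (the variable must be set before the import):

```python
import os

# Skip the transformers version assert shown above (use at your own risk).
os.environ["PETALS_IGNORE_DEPENDENCY_VERSION"] = "1"

import petals  # noqa: E402  # petals/__init__.py reads the variable at import time
```
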
+ 17 - 9
src/petals/cli/run_server.py

@@ -1,8 +1,10 @@
 import argparse
+import logging
 
 import configargparse
+import torch
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.limits import increase_file_limit
+from hivemind.utils import limits
 from hivemind.utils.logging import get_logger
 from humanfriendly import parse_size
 
@@ -96,9 +98,9 @@ def main():
     parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--alloc_timeout', type=float, default=1,
-                        help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
-                             'before rejecting the request')
+    parser.add_argument('--max_alloc_timeout', type=float, default=600,
+                        help="If the cache is full, the server will wait for memory to be freed up to this many seconds"
+                             " before rejecting the request")
     parser.add_argument('--revision', type=str, default=None,
                         help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
@@ -127,9 +129,9 @@ def main():
     group.add_argument('--new_swarm', action='store_true',
                        help='Start a new private swarm (i.e., do not connect to any initial peers)')
 
-    parser.add_argument('--increase_file_limit', action='store_true',
-                        help='On *nix, this will increase the max number of processes '
-                             'a server can spawn before hitting "Too many open files"; Use at your own risk.')
+    parser.add_argument('--increase_file_limit', type=int, default=4096,
+                        help='On *nix, increase the max number of files a server can open '
+                             'before hitting "Too many open files" (set to zero to keep the system limit)')
     parser.add_argument('--stats_report_interval', type=int, required=False,
                         help='Interval between two reports of batch processing performance statistics')
 
@@ -185,8 +187,10 @@ def main():
 
     args["startup_timeout"] = args.pop("daemon_startup_timeout")
 
-    if args.pop("increase_file_limit"):
-        increase_file_limit()
+    file_limit = args.pop("increase_file_limit")
+    if file_limit:
+        limits.logger.setLevel(logging.WARNING)
+        limits.increase_file_limit(file_limit, file_limit)
 
     compression_type = args.pop("compression").upper()
     compression = getattr(CompressionType, compression_type)
@@ -207,6 +211,10 @@ def main():
 
     validate_version()
 
+    if not torch.backends.openmp.is_available():
+        # Necessary to prevent the server from freezing after forks
+        torch.set_num_threads(1)
+
     server = Server(
         **args,
         host_maddrs=host_maddrs,

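`--increase_file_limit` is now an integer (default 4096) forwarded to hivemind's `limits.increase_file_limit(file_limit, file_limit)`. For reference, a minimal standard-library sketch of the same idea, raising the soft open-file limit up to a target value (an illustration, not the hivemind implementation):

```python
import resource

def raise_open_file_limit(target: int = 4096) -> None:
    """Raise the RLIMIT_NOFILE soft limit up to `target`, capped by the hard limit."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    new_soft = target if hard == resource.RLIM_INFINITY else min(target, hard)
    if new_soft > soft:
        resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))

raise_open_file_limit(4096)
```
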
+ 1 - 1
src/petals/client/inference_session.py

@@ -343,7 +343,7 @@ class InferenceSession:
         n_prev_spans = len(self._server_sessions)
         update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
         if attempt_no >= 1:
-            logger.info(
+            logger.debug(
                 f"Due to a server failure, remote attention caches "
                 f"from block {block_idx} to {update_end} will be regenerated"
             )

+ 3 - 3
src/petals/client/remote_generation.py

@@ -69,6 +69,8 @@ class RemoteGenerationMixin(_SkipTokensMixin):
         self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
     ):
         self._fix_generate_kwargs(kwargs)
+        if inputs is None:
+            inputs = kwargs.pop("input_ids", None)
 
         if session is not None:
             # If a session specified explicitly, use it
@@ -125,7 +127,7 @@ class RemoteGenerationMixin(_SkipTokensMixin):
         return result
 
     @staticmethod
-    def _fix_generate_kwargs(kwargs: dict) -> dict:
+    def _fix_generate_kwargs(kwargs: dict):
         # Suppress inappropriate "Both max_new_tokens and max_length" HF warning
         if "max_length" in kwargs and kwargs["max_length"] is None:
             del kwargs["max_length"]
@@ -135,8 +137,6 @@ class RemoteGenerationMixin(_SkipTokensMixin):
         if isinstance(do_sample, int):
             kwargs["do_sample"] = bool(do_sample)
 
-        return kwargs
-
     @staticmethod
     def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
         return dataclasses.replace(past_key_values, hypo_ids=beam_idx)

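The new `inputs is None` fallback means `generate()` accepts the prompt either positionally or as the usual Hugging Face `input_ids=` keyword. A tiny standalone illustration of the added logic (a hypothetical helper that mirrors the two new lines):

```python
def resolve_inputs(inputs=None, **kwargs):
    # Mirrors the fallback added to RemoteGenerationMixin.generate():
    # if the prompt arrived as input_ids=..., treat it as the positional `inputs`.
    if inputs is None:
        inputs = kwargs.pop("input_ids", None)
    return inputs, kwargs

assert resolve_inputs(input_ids=[1, 2, 3]) == ([1, 2, 3], {})
assert resolve_inputs([1, 2, 3]) == ([1, 2, 3], {})
```
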
+ 13 - 0
src/petals/data_structures.py

@@ -20,6 +20,19 @@ class ServerState(Enum):
 RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
 
 
+@pydantic.dataclasses.dataclass
+class ModelInfo:
+    num_blocks: int
+    repository: Optional[str] = None
+
+    def to_dict(self) -> dict:
+        return dataclasses.asdict(self)
+
+    @classmethod
+    def from_dict(cls, source: dict):
+        return cls(**source)
+
+
 @pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState

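`ModelInfo` is a small pydantic dataclass, so it round-trips through plain dicts via the helpers above. A usage sketch (the field values are made up for illustration):

```python
from petals.data_structures import ModelInfo

info = ModelInfo(num_blocks=80, repository="petals-team/StableBeluga2")
payload = info.to_dict()  # {'num_blocks': 80, 'repository': 'petals-team/StableBeluga2'}
assert ModelInfo.from_dict(payload) == info
```
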
+ 1 - 0
src/petals/models/bloom/config.py

@@ -30,5 +30,6 @@ class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfi
         if loading_from_repo and dht_prefix is None:
             # We need "-petals" for backward compatibility with Petals < 1.2.0
             dht_prefix = str(model_name_or_path) + "-petals"
+            dht_prefix = dht_prefix.replace(".", "-")
             logger.info(f"Using DHT prefix: {dht_prefix}")
         return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)

+ 1 - 0
src/petals/models/llama/config.py

@@ -35,6 +35,7 @@ class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfi
         if loading_from_repo and dht_prefix is None:
             dht_prefix = str(model_name_or_path)
             dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
+            dht_prefix = dht_prefix.replace(".", "-")
             if not dht_prefix.endswith("-hf"):
                 dht_prefix += "-hf"
             logger.info(f"Using DHT prefix: {dht_prefix}")

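Both the BLOOM and Llama configs above now strip dots from the DHT prefix derived from the repository name. A worked example of the Llama-side logic (the repo name is hypothetical):

```python
def make_llama_dht_prefix(model_name_or_path: str) -> str:
    # Same steps as in DistributedLlamaConfig.from_pretrained() above.
    dht_prefix = str(model_name_or_path)
    dht_prefix = dht_prefix.split("/")[-1]  # use only the repo name
    dht_prefix = dht_prefix.replace(".", "-")
    if not dht_prefix.endswith("-hf"):
        dht_prefix += "-hf"
    return dht_prefix

print(make_llama_dht_prefix("TinyLlama/TinyLlama-1.1B-Chat-v1.0"))  # TinyLlama-1-1B-Chat-v1-0-hf
```
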
+ 3 - 3
src/petals/server/backend.py

@@ -16,7 +16,7 @@ from transformers import PretrainedConfig
 from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
-from petals.utils.misc import is_dummy
+from petals.utils.misc import get_size_in_bytes, is_dummy
 
 logger = get_logger(__name__)
 
@@ -72,7 +72,7 @@ class TransformerBackend(ModuleBackend):
         )
 
         self.dtype = backend_dtype
-        self.dtype_bytes = torch.finfo(self.dtype).bits // 8
+        self.dtype_bytes = get_size_in_bytes(self.dtype)
         self.shard_num_heads = []
         for shard in self.module.module_shards:
             for submodule in shard.modules():
@@ -92,7 +92,7 @@ class TransformerBackend(ModuleBackend):
 
         self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
         for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
-            self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
+            self.cache_bytes_per_token[descr.device] += descr.numel() * get_size_in_bytes(descr.dtype)
 
     def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
         """Create tensor descriptors for attention cache tensors used during inference_step"""

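Replacing `torch.finfo(dtype).bits // 8` with `get_size_in_bytes(dtype)` matters because `torch.finfo` only accepts floating-point dtypes. The helper's body is not shown in this commit; a plausible sketch, assuming it simply falls back to `torch.iinfo` for integer dtypes:

```python
import torch

def get_size_in_bytes(dtype: torch.dtype) -> int:
    # Hypothetical reconstruction of petals.utils.misc.get_size_in_bytes (not shown in this diff):
    # torch.finfo() raises for integer dtypes, so use torch.iinfo() for those instead.
    if dtype.is_floating_point:
        return torch.finfo(dtype).bits // 8
    return torch.iinfo(dtype).bits // 8

assert get_size_in_bytes(torch.float16) == 2
assert get_size_in_bytes(torch.int8) == 1
```
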
+ 3 - 2
src/petals/server/block_utils.py

@@ -5,6 +5,7 @@ from accelerate import init_empty_weights
 from transformers import PretrainedConfig
 
 from petals.utils.convert_block import QuantType
+from petals.utils.misc import get_size_in_bytes
 
 
 def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
@@ -37,7 +38,7 @@ def get_block_size(
     if location == "memory":
         if quant_type == QuantType.NONE:
             dtype = resolve_block_dtype(config, dtype)
-            bytes_per_value = torch.finfo(dtype).bits // 8
+            bytes_per_value = get_size_in_bytes(dtype)
         elif quant_type == QuantType.INT8:
             bytes_per_value = 1
         elif quant_type == QuantType.NF4:
@@ -46,6 +47,6 @@ def get_block_size(
             raise ValueError(f"Unsupported quant_type={quant_type}")
     elif location == "disk":
         dtype = resolve_block_dtype(config, "auto")
-        bytes_per_value = torch.finfo(dtype).bits // 8
+        bytes_per_value = get_size_in_bytes(dtype)
 
     return round(n_params * bytes_per_value * (1 + eps))

+ 11 - 3
src/petals/server/handler.py

@@ -150,6 +150,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 max_length = metadata.get("max_length")
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
+                alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
                 args_structure = metadata.get("args_structure")
                 if not requested_uids:
                     raise ValueError("User must specify at least one block for inference, but got none")
@@ -166,7 +167,9 @@ class TransformerConnectionHandler(ConnectionHandler):
 
                 batch_size = request.tensors[0].size[0] if request.tensors else 1
 
-                async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
+                async with self._allocate_cache(
+                    requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
+                ) as cache_handles:
                     background_tasks = set()
                     async for output_tensors, can_push in iterate_rpc_inference(
                         requested_uids=requested_uids,
@@ -535,14 +538,19 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     @contextlib.asynccontextmanager
     async def _allocate_cache(
-        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
+        self,
+        backends: Sequence[TransformerBackend],
+        *,
+        batch_size: int,
+        max_length: int,
+        timeout: Optional[float],
     ) -> Sequence[Sequence[Handle]]:
         """
         Allocate memory cache for all transformer blocks, return cache handle
         :returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend
         """
         descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends]
-        async with backends[0].memory_cache.allocate_cache(*chain(*descriptors)) as handles:
+        async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:
             yield nested_pack(handles, descriptors)
 
     def _log_request(

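On the server side, the allocation timeout now arrives through the request metadata and defaults to 0.0, which (per the `MemoryCache` changes below) means the request fails immediately if the cache is full rather than waiting. A small sketch of that parsing step with made-up metadata values:

```python
metadata = {"max_length": 2048, "points": 0}  # example request metadata (illustrative values only)
alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
assert alloc_timeout == 0.0  # no alloc_timeout sent -> do not wait for cache memory to be freed
```
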
+ 71 - 21
src/petals/server/memory_cache.py

@@ -12,12 +12,13 @@ import os
 import time
 from typing import AsyncContextManager, Dict, Optional, Sequence
 
-import hivemind
+import async_timeout
 import torch
-from hivemind.utils import TensorDescriptor, get_logger
+from hivemind.utils import TensorDescriptor, enter_asynchronously, get_logger
 
 from petals.data_structures import Handle
 from petals.utils.asyncio import shield_and_wait
+from petals.utils.misc import get_size_in_bytes
 
 logger = get_logger(__name__)
 
@@ -25,11 +26,12 @@ logger = get_logger(__name__)
 class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
 
-    def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
+    def __init__(self, max_size_bytes: Optional[int], max_alloc_timeout: Optional[float] = None):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
-        self.alloc_timeout = alloc_timeout
+        self.max_alloc_timeout = max_alloc_timeout
         self._lock_metadata = mp.Lock()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=True)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
         self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
         self.runtime_pid = os.getpid()
@@ -46,6 +48,14 @@ class MemoryCache:
     def current_size_bytes(self, value: int):
         self._current_size.value = value
 
+    @property
+    def enqueued_size_bytes(self) -> int:
+        return self._enqueued_size.value
+
+    @enqueued_size_bytes.setter
+    def enqueued_size_bytes(self, value: int):
+        self._enqueued_size.value = value
+
     @property
     def bytes_left(self) -> int:
         return self.max_size_bytes - self.current_size_bytes
@@ -59,11 +69,14 @@ class MemoryCache:
         self._handle_counter.value = value
 
     @contextlib.asynccontextmanager
-    async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
+    async def allocate_cache(
+        self, *descriptors: TensorDescriptor, timeout: float
+    ) -> AsyncContextManager[Sequence[Handle]]:
         """
         Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
 
         :param descriptors: one or more tensors tensor of this size, dtype, etc
+        :param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit
 
         :note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
           if not, it will count maximum tensor allocation across devices for the purposes of size limit
@@ -73,6 +86,8 @@ class MemoryCache:
         """
         assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
         assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
+        if self.max_alloc_timeout is not None:
+            timeout = min(timeout, self.max_alloc_timeout)
         max_alloc_size = self.get_allocation_size(*descriptors)
 
         gib = 1024**3
@@ -83,10 +98,10 @@ class MemoryCache:
             f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
         )
 
-        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
+        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))
         try:
             handles = await shield_and_wait(alloc_task)
-            logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
+            logger.info(f"rpc_inference.alloc_done(size={max_alloc_size / gib:.2f} GiB)")
             yield handles
         finally:
             self._free(max_alloc_size, alloc_task)
@@ -96,28 +111,62 @@ class MemoryCache:
         """Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
         alloc_size_by_device = {}
         for descr in descriptors:
-            tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+            tensor_size = descr.numel() * get_size_in_bytes(descr.dtype)
             alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
         return max(alloc_size_by_device.values())
 
-    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
+    async def _schedule_alloc(
+        self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]
+    ) -> Sequence[Handle]:
         """
         This method should be called inside asyncio.shield() because:
             - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
         """
+        try:
+            async with self._wait_for_free_memory(alloc_size, timeout):
+                with self._lock_metadata:
+                    handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
+                    self.current_size_bytes += alloc_size
+                    self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
+                    self._pipe_send.send((handles, descriptors))
+                    return handles
+        except TimeoutError:
+            raise AllocationFailed(f"Could not allocate {alloc_size} (timeout={timeout})")
 
+    @contextlib.asynccontextmanager
+    async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float]):
+        start_time = time.perf_counter()
         loop = asyncio.get_event_loop()
-        async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
-            if self.current_size_bytes + alloc_size > self.max_size_bytes:
-                await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
-            with self._lock_metadata:
-                handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
-                self.current_size_bytes += alloc_size
-                self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
-                self._pipe_send.send((handles, descriptors))
-                return handles
-
-    def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None:
+
+        with self._enqueued_size.get_lock():
+            self._enqueued_size.value += alloc_size
+        allocated = False
+        try:
+            context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack()
+            # contextlib.AsyncExitStack() is used as a null context here
+            async with context_manager:
+                if timeout == 0 and self.current_size_bytes + self.enqueued_size_bytes > self.max_size_bytes:
+                    raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
+                async with enter_asynchronously(self._lock_acquire_memory):
+                    if self.current_size_bytes + alloc_size > self.max_size_bytes:
+                        if timeout == 0:
+                            raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
+                        elapsed_time = time.perf_counter() - start_time
+                        remaining_timeout = max(0.0, timeout - elapsed_time) if timeout is not None else None
+                        await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout)
+
+                allocated = True
+                with self._enqueued_size.get_lock():
+                    self._enqueued_size.value -= alloc_size
+                yield
+        except asyncio.TimeoutError:
+            raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds")
+        finally:
+            if not allocated:
+                with self._enqueued_size.get_lock():
+                    self._enqueued_size.value -= alloc_size
+
+    def _free(self, alloc_size: int, alloc_task: asyncio.Task):
         if alloc_task.exception() is not None:
             return
         handles = alloc_task.result()
@@ -133,9 +182,10 @@ class MemoryCache:
             raise AllocationFailed(
                 f"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes"
             )
+        timeout = timeout if timeout != float("inf") else None
         deadline = None if timeout is None else time.perf_counter() + timeout
         while self.current_size_bytes + allocated_size > self.max_size_bytes:
-            remaining_time = deadline - time.perf_counter() if timeout is not None else None
+            remaining_time = None if timeout is None else deadline - time.perf_counter()
             if not self._memory_freed_event.wait(remaining_time):
                 raise AllocationFailed(
                     f"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds"

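With these changes, `allocate_cache()` takes a per-request `timeout`, clamps it to `max_alloc_timeout`, and tracks enqueued allocations so that a full cache can fail fast. A hedged usage sketch from a connection handler's point of view (the descriptor shape and dtype are illustrative):

```python
# Sketch only: assumes a MemoryCache instance created by the server and hivemind's TensorDescriptor.
import torch
from hivemind.utils import TensorDescriptor

descr = TensorDescriptor.from_tensor(torch.empty(1, 2048, 14, 128, dtype=torch.float16))

async def reserve(cache):  # cache: petals.server.memory_cache.MemoryCache
    # Wait at most 10 s for room in the attention cache; timeout=0 fails immediately if the cache is full.
    async with cache.allocate_cache(descr, timeout=10.0) as (handle,):
        ...  # the handle stays valid inside this block; the memory is released on exit
```
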
+ 1 - 1
src/petals/server/reachability.py

@@ -140,7 +140,7 @@ class ReachabilityProtocol(ServicerBase):
                 protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)
 
                 ready.set_result(True)
-                logger.info("Reachability service started")
+                logger.debug("Reachability service started")
 
                 async with protocol.serve(common_p2p):
                     await protocol._stop.wait()

+ 54 - 16
src/petals/server/server.py

@@ -3,13 +3,16 @@ from __future__ import annotations
 import gc
 import math
 import multiprocessing as mp
+import os
 import random
 import threading
 import time
 from typing import Dict, List, Optional, Sequence, Union
 
 import hivemind
+import psutil
 import torch
+import torch.mps
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
@@ -19,7 +22,7 @@ from transformers import PretrainedConfig
 
 import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -31,6 +34,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.dht import declare_active_modules, get_remote_module_infos
 from petals.utils.dht import declare_active_modules, get_remote_module_infos
+from petals.utils.misc import get_size_in_bytes
 from petals.utils.ping import PingAggregator
 from petals.utils.ping import PingAggregator
 from petals.utils.random import sample_up_to
 from petals.utils.random import sample_up_to
 from petals.utils.version import get_compatible_model_repo
 from petals.utils.version import get_compatible_model_repo
@@ -59,12 +63,12 @@ class Server:
         min_batch_size: int = 1,
         min_batch_size: int = 1,
         max_batch_size: Optional[int] = None,
         max_batch_size: Optional[int] = None,
         max_chunk_size_bytes: int = 256 * 1024 * 1024,
         max_chunk_size_bytes: int = 256 * 1024 * 1024,
+        max_alloc_timeout: float = 600,
         attn_cache_tokens: Optional[int] = None,
         attn_cache_tokens: Optional[int] = None,
         torch_dtype: str = "auto",
         torch_dtype: str = "auto",
         revision: Optional[str] = None,
         revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
         max_disk_space: Optional[int] = None,
-        alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
         stats_report_interval: Optional[int] = None,
@@ -153,13 +157,25 @@ class Server:
         self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
         self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
 
 
         if device is None:
         if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.backends.mps.is_available():
+                device = "mps"
+            else:
+                device = "cpu"
         device = torch.device(device)
         device = torch.device(device)
         if device.type == "cuda" and device.index is None:
         if device.type == "cuda" and device.index is None:
             device = torch.device(device.type, index=0)
             device = torch.device(device.type, index=0)
         self.device = device
         self.device = device
 
 
         torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
         torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
+        if device.type == "cpu" and torch_dtype == torch.float16:
+            raise ValueError(
+                f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16"
+            )
+        if device.type == "mps" and torch_dtype == torch.bfloat16:
+            logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead")
+            torch_dtype = torch.float16
         self.torch_dtype = torch_dtype
         self.torch_dtype = torch_dtype
 
 
         if tensor_parallel_devices is None:
         if tensor_parallel_devices is None:
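
A standalone sketch of the device and dtype selection logic introduced above, assuming a PyTorch build that exposes torch.backends.mps; pick_device_and_dtype is an illustrative helper, not the Server constructor itself.

    import torch

    def pick_device_and_dtype(requested_dtype: torch.dtype):
        # Prefer CUDA, then Apple's MPS backend, then CPU
        if torch.cuda.is_available():
            device = torch.device("cuda", index=0)
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

        dtype = requested_dtype
        if device.type == "cpu" and dtype == torch.float16:
            raise ValueError("float16 is not supported on CPU; use float32 or bfloat16")
        if device.type == "mps" and dtype == torch.bfloat16:
            dtype = torch.float16  # bfloat16 is not supported on MPS
        return device, dtype
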
@@ -185,13 +201,14 @@ class Server:
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.inference_max_length = inference_max_length
         self.inference_max_length = inference_max_length
         self.max_chunk_size_bytes = max_chunk_size_bytes
         self.max_chunk_size_bytes = max_chunk_size_bytes
+        self.max_alloc_timeout = max_alloc_timeout
 
 
         # For attention cache in GPU or RAM
         # For attention cache in GPU or RAM
         if attn_cache_tokens is None:
         if attn_cache_tokens is None:
             attn_cache_tokens = 32768 if is_multiquery_attn else 8192
             attn_cache_tokens = 32768 if is_multiquery_attn else 8192
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         cache_values_per_block //= self.block_config.num_key_value_groups
         cache_values_per_block //= self.block_config.num_key_value_groups
-        self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
+        self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
 
 
         # For disk cache
         # For disk cache
         self.cache_dir = cache_dir
         self.cache_dir = cache_dir
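
To make the cache sizing above concrete: with illustrative values (hidden_size=4096, 8192 cache tokens, a single key-value group, float16), one block reserves 2 * 4096 * 8192 * 2 bytes = 128 MiB of attention cache. A quick check using the get_size_in_bytes helper added later in this diff:

    import torch

    from petals.utils.misc import get_size_in_bytes

    hidden_size, attn_cache_tokens, num_key_value_groups = 4096, 8192, 1  # illustrative values
    cache_values_per_block = 2 * hidden_size * attn_cache_tokens // num_key_value_groups
    cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(torch.float16)
    print(f"{cache_bytes_per_block / 1024**2:.0f} MiB per block")  # -> 128 MiB
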
@@ -217,8 +234,6 @@ class Server:
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
 
 
-        self.alloc_timeout = alloc_timeout
-
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
         if throughput in ["auto", "eval"]:
             throughput_info = get_server_throughput(
             throughput_info = get_server_throughput(
@@ -245,21 +260,26 @@ class Server:
             using_relay=reachable_via_relay,
             using_relay=reachable_via_relay,
             **throughput_info,
             **throughput_info,
         )
         )
+        self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers)
+        if not os.path.isdir(converted_model_name_or_path):
+            self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path
 
 
         self.balance_quality = balance_quality
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
         self.mean_block_selection_delay = mean_block_selection_delay
 
 
+        self.module_container = None
         self.stop = threading.Event()
         self.stop = threading.Event()
 
 
     def _choose_num_blocks(self) -> int:
     def _choose_num_blocks(self) -> int:
-        assert self.device.type == "cuda", (
+        assert self.device.type in ("cuda", "mps"), (
             "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
             "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
             "CPU-only servers in the public swarm are discouraged since they are much slower"
             "CPU-only servers in the public swarm are discouraged since they are much slower"
         )
         )
         num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
         num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
 
 
         if num_devices > 1:
         if num_devices > 1:
+            assert self.device.type == "cuda", f"Tensor parallelism is not supported on {self.device.type.upper()}"
             memory_per_device = tuple(
             memory_per_device = tuple(
                 torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
                 torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
             )
             )
@@ -270,8 +290,10 @@ class Server:
                     "Please launch individual servers on each GPU or set --num_blocks manually to "
                     "Please launch individual servers on each GPU or set --num_blocks manually to "
                     "override this exception."
                     "override this exception."
                 )
                 )
-        else:
+        elif self.device.type == "cuda":
             total_memory = torch.cuda.get_device_properties(self.device).total_memory
             total_memory = torch.cuda.get_device_properties(self.device).total_memory
+        else:
+            total_memory = psutil.virtual_memory().total
 
 
         gib = 1024**3
         gib = 1024**3
         # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
         # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
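
A sketch of the memory probe behind the branches above: CUDA devices report their own VRAM, while CPU and MPS servers fall back to host RAM via psutil. total_memory_bytes is an illustrative helper that mirrors the diff rather than an existing API.

    import psutil
    import torch

    def total_memory_bytes(device: torch.device) -> int:
        # VRAM for CUDA devices; MPS and CPU share host memory, so report system RAM
        if device.type == "cuda":
            return torch.cuda.get_device_properties(device).total_memory
        return psutil.virtual_memory().total
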
@@ -311,13 +333,14 @@ class Server:
                 converted_model_name_or_path=self.converted_model_name_or_path,
                 converted_model_name_or_path=self.converted_model_name_or_path,
                 block_config=self.block_config,
                 block_config=self.block_config,
                 attn_cache_bytes=self.attn_cache_bytes,
                 attn_cache_bytes=self.attn_cache_bytes,
-                alloc_timeout=self.alloc_timeout,
                 server_info=self.server_info,
                 server_info=self.server_info,
+                model_info=self.model_info,
                 block_indices=block_indices,
                 block_indices=block_indices,
                 num_handlers=self.num_handlers,
                 num_handlers=self.num_handlers,
                 min_batch_size=self.min_batch_size,
                 min_batch_size=self.min_batch_size,
                 max_batch_size=self.max_batch_size,
                 max_batch_size=self.max_batch_size,
                 max_chunk_size_bytes=self.max_chunk_size_bytes,
                 max_chunk_size_bytes=self.max_chunk_size_bytes,
+                max_alloc_timeout=self.max_alloc_timeout,
                 inference_max_length=self.inference_max_length,
                 inference_max_length=self.inference_max_length,
                 torch_dtype=self.torch_dtype,
                 torch_dtype=self.torch_dtype,
                 cache_dir=self.cache_dir,
                 cache_dir=self.cache_dir,
@@ -360,7 +383,7 @@ class Server:
             self._clean_memory_and_fds()
             self._clean_memory_and_fds()
 
 
     def _clean_memory_and_fds(self):
     def _clean_memory_and_fds(self):
-        del self.module_container
+        self.module_container = None
         gc.collect()  # In particular, this closes unused file descriptors
         gc.collect()  # In particular, this closes unused file descriptors
 
 
         if self.device.type == "cuda":
         if self.device.type == "cuda":
@@ -373,6 +396,8 @@ class Server:
                 f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
                 f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
                 f"{reserved_vram / gib:.1f} GiB reserved memory"
                 f"{reserved_vram / gib:.1f} GiB reserved memory"
             )
             )
+        elif self.device.type == "mps":
+            torch.mps.empty_cache()
 
 
     def _choose_blocks(self) -> List[int]:
     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
         if self.strict_block_indices is not None:
@@ -391,8 +416,10 @@ class Server:
         module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)
         module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)
         return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)
         return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)
 
 
-    def shutdown(self):
+    def shutdown(self, timeout: Optional[float] = 5):
         self.stop.set()
         self.stop.set()
+        if self.module_container is not None and self.module_container.is_alive():
+            self.module_container.join(timeout)
 
 
         if self.reachability_protocol is not None:
         if self.reachability_protocol is not None:
             self.reachability_protocol.shutdown()
             self.reachability_protocol.shutdown()
@@ -413,12 +440,13 @@ class ModuleContainer(threading.Thread):
         converted_model_name_or_path: str,
         converted_model_name_or_path: str,
         block_config: PretrainedConfig,
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
         attn_cache_bytes: int,
-        alloc_timeout: float,
         server_info: ServerInfo,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         block_indices: List[int],
         block_indices: List[int],
         min_batch_size: int,
         min_batch_size: int,
         max_batch_size: int,
         max_batch_size: int,
         max_chunk_size_bytes: int,
         max_chunk_size_bytes: int,
+        max_alloc_timeout: float,
         torch_dtype: torch.dtype,
         torch_dtype: torch.dtype,
         cache_dir: str,
         cache_dir: str,
         max_disk_space: int,
         max_disk_space: int,
@@ -434,13 +462,14 @@ class ModuleContainer(threading.Thread):
         **kwargs,
         **kwargs,
     ) -> ModuleContainer:
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
-        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
 
 
         server_info.state = ServerState.JOINING
         server_info.state = ServerState.JOINING
         dht_announcer = ModuleAnnouncerThread(
         dht_announcer = ModuleAnnouncerThread(
             module_uids,
             module_uids,
             dht,
             dht,
             server_info,
             server_info,
+            model_info,
             block_config=block_config,
             block_config=block_config,
             memory_cache=memory_cache,
             memory_cache=memory_cache,
             update_period=update_period,
             update_period=update_period,
@@ -649,6 +678,7 @@ class ModuleAnnouncerThread(threading.Thread):
         module_uids: List[str],
         module_uids: List[str],
         dht: DHT,
         dht: DHT,
         server_info: ServerInfo,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         *,
         *,
         block_config: PretrainedConfig,
         block_config: PretrainedConfig,
         memory_cache: MemoryCache,
         memory_cache: MemoryCache,
@@ -661,9 +691,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self.module_uids = module_uids
         self.module_uids = module_uids
         self.dht = dht
         self.dht = dht
         self.server_info = server_info
         self.server_info = server_info
+        self.model_info = model_info
         self.memory_cache = memory_cache
         self.memory_cache = memory_cache
 
 
-        self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
+        self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
         self.bytes_per_token //= block_config.num_key_value_groups
         self.bytes_per_token //= block_config.num_key_value_groups
 
 
         self.update_period = update_period
         self.update_period = update_period
@@ -671,10 +702,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self.trigger = threading.Event()
         self.trigger = threading.Event()
 
 
         self.max_pinged = max_pinged
         self.max_pinged = max_pinged
-        dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
+        self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
         block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
         block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
         start_block, end_block = min(block_indices), max(block_indices) + 1
         start_block, end_block = min(block_indices), max(block_indices) + 1
-        self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
         self.ping_aggregator = PingAggregator(self.dht)
         self.ping_aggregator = PingAggregator(self.dht)
 
 
     def run(self) -> None:
     def run(self) -> None:
@@ -698,6 +729,13 @@ class ModuleAnnouncerThread(threading.Thread):
             )
             )
             if self.server_info.state == ServerState.OFFLINE:
             if self.server_info.state == ServerState.OFFLINE:
                 break
                 break
+            if not self.dht_prefix.startswith("_"):  # Not private
+                self.dht.store(
+                    key="_petals.models",
+                    subkey=self.dht_prefix,
+                    value=self.model_info.to_dict(),
+                    expiration_time=get_dht_time() + self.expiration,
+                )
 
 
             delay = self.update_period - (time.perf_counter() - start_time)
             delay = self.update_period - (time.perf_counter() - start_time)
             if delay < 0:
             if delay < 0:
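
A standalone sketch of the model announcement added above, with the announcer loop stripped away. announce_model and its expiration argument are hypothetical; the dht.store() call and the "_petals.models" key mirror the diff.

    from hivemind import DHT, get_dht_time

    from petals.data_structures import ModelInfo

    def announce_model(dht: DHT, dht_prefix: str, model_info: ModelInfo, expiration: float = 300) -> None:
        # Prefixes starting with "_" are treated as private and are not announced
        if dht_prefix.startswith("_"):
            return
        dht.store(
            key="_petals.models",
            subkey=dht_prefix,
            value=model_info.to_dict(),
            expiration_time=get_dht_time() + expiration,
        )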

+ 14 - 28
src/petals/server/task_pool.py

@@ -32,7 +32,7 @@ class Task:
         return self.future._uid
         return self.future._uid
 
 
 
 
-class PrioritizedTaskPool(TaskPoolBase):
+class PrioritizedTaskPool(threading.Thread):
     """
     """
     Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
     Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
     returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
     returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
@@ -62,52 +62,41 @@ class PrioritizedTaskPool(TaskPoolBase):
         daemon=True,
         daemon=True,
         start=False,
         start=False,
     ):
     ):
-        super().__init__(process_func, daemon=daemon, name=name)
+        super().__init__(daemon=daemon, name=name)
+        self.process_func = process_func
+        # the lower the priority is, the more urgent it is to process this pool
+        self._priority = mp.Value(ctypes.c_double, 1.0)
+
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.device = device
         self.device = device
 
 
         self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
         self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
         self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
         self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
 
 
-        self._prioritizer_thread = threading.Thread(
-            name=self.name + "_prioritizer",
-            target=self._prioritize_tasks,
-            args=[self.submitted_tasks, self._ordered_tasks],
-            daemon=True,
-        )
         self._dispatched_tasks = {}
         self._dispatched_tasks = {}
         self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
         self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
         self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
         self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
         self.priority = float("inf"), float("inf")  # (first task priority, first task timestamp)
         self.priority = float("inf"), float("inf")  # (first task priority, first task timestamp)
 
 
-        self._stop = mp.Event()
         if start:
         if start:
             self.start()
             self.start()
 
 
-    @staticmethod
-    def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
+    def run(self):
         """Read tasks from incoming queue and put them into a local priority queue"""
         """Read tasks from incoming queue and put them into a local priority queue"""
         while True:
         while True:
-            task = submitted_tasks.get()
+            task = self.submitted_tasks.get()
             if task is None:
             if task is None:
                 logger.debug("Shutting down prioritizer thread")
                 logger.debug("Shutting down prioritizer thread")
                 break
                 break
 
 
-            ordered_tasks.put(task, block=True)
-
-    def start(self):
-        assert not self.is_alive() and not self._prioritizer_thread.is_alive()
-        self._prioritizer_thread.start()
-        super().start()
+            self._ordered_tasks.put(task, block=True)
 
 
-    def shutdown(self, timeout: float = 3):
-        self.submitted_tasks.put(None)  # Shuts down self._prioritizer_thread
-        self._stop.set()
+    def terminate(self):
+        """An alias for hivemind.Runtime that assumes that each TaskPool is a process"""
+        self.shutdown()
 
 
-        self.join(timeout)
-        if self.is_alive():
-            logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
-            self.terminate()
+    def shutdown(self):
+        self.submitted_tasks.put(None)  # Shuts down self.run()
 
 
     def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1, **kwargs: Any) -> MPFuture:
     def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1, **kwargs: Any) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
         """Add task to this pool's queue, return Future for its output"""
@@ -161,9 +150,6 @@ class PrioritizedTaskPool(TaskPoolBase):
         else:
         else:
             task.future.set_exception(exception)
             task.future.set_exception(exception)
 
 
-    def run(self, *args, **kwargs):
-        self._stop.wait()
-
     @property
     @property
     def empty(self):
     def empty(self):
         return not self.batch_receiver.poll()
         return not self.batch_receiver.poll()
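
A toy version of the refactoring above: the pool is now an ordinary thread whose run() drains a multiprocessing queue into a local PriorityQueue, and shutdown() just posts a None sentinel. Names here are illustrative; the real class also tracks priorities, batches, and futures.

    import multiprocessing as mp
    import threading
    from queue import PriorityQueue

    class MiniPrioritizer(threading.Thread):
        def __init__(self):
            super().__init__(daemon=True)
            self.submitted_tasks = mp.SimpleQueue()  # filled by ConnectionHandler processes
            self._ordered_tasks = PriorityQueue()    # consumed by the Runtime

        def run(self):
            while True:
                task = self.submitted_tasks.get()
                if task is None:  # sentinel posted by shutdown()
                    break
                self._ordered_tasks.put(task, block=True)

        def shutdown(self):
            self.submitted_tasks.put(None)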

+ 12 - 6
src/petals/server/throughput.py

@@ -9,6 +9,7 @@ from pathlib import Path
 from typing import Dict, Optional, Sequence, Union
 from typing import Dict, Optional, Sequence, Union
 
 
 import torch
 import torch
+import torch.mps
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 from transformers import PretrainedConfig
 
 
@@ -207,14 +208,12 @@ def measure_compute_rps(
         elapsed = 0
         elapsed = 0
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
         _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
         _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
-        if device.type == "cuda":
-            torch.cuda.synchronize(device)
+        synchronize(device)
 
 
         start_time = time.perf_counter()
         start_time = time.perf_counter()
-        for step in range(n_steps):
+        for _ in range(n_steps):
             _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
             _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
-        if device.type == "cuda":
-            torch.cuda.synchronize(device)
+        synchronize(device)
         elapsed = time.perf_counter() - start_time
         elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
         device_rps = n_steps * n_tokens / elapsed
 
 
@@ -230,8 +229,15 @@ def measure_compute_rps(
     return device_rps
     return device_rps
 
 
 
 
+def synchronize(device: torch.device):
+    if device.type == "cuda":
+        torch.cuda.synchronize(device)
+    elif device.type == "mps":
+        torch.mps.synchronize()
+
+
 def get_device_name(device: torch.device) -> str:
 def get_device_name(device: torch.device) -> str:
-    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU"
+    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else device.type.upper()
 
 
 
 
 def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
 def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
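
A small timing harness built around the same synchronize() idea, to show why the helper matters: without a device sync, the clock may stop before queued CUDA or MPS kernels have finished. time_block and its arguments are illustrative, and torch.mps is assumed to be importable on the MPS branch, as in the diff itself.

    import time

    import torch

    def synchronize(device: torch.device) -> None:
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        elif device.type == "mps":
            torch.mps.synchronize()

    def time_block(fn, device: torch.device, n_steps: int = 10) -> float:
        """Average wall-clock seconds per call of `fn`, with device syncs around the loop."""
        fn()  # warm-up call, excluded from the measurement
        synchronize(device)
        start = time.perf_counter()
        for _ in range(n_steps):
            fn()
        synchronize(device)
        return (time.perf_counter() - start) / n_steps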

+ 10 - 0
src/petals/utils/misc.py

@@ -9,6 +9,16 @@ def is_dummy(tensor: torch.Tensor) -> bool:
     return tensor.numel() == 0
     return tensor.numel() == 0
 
 
 
 
+SPECIAL_DTYPE_SIZES = {torch.bool: 1, torch.qint8: 1, torch.qint32: 4}
+
+
+def get_size_in_bytes(dtype: torch.dtype) -> int:
+    if dtype in SPECIAL_DTYPE_SIZES:
+        return SPECIAL_DTYPE_SIZES[dtype]
+    get_info = torch.finfo if dtype.is_floating_point else torch.iinfo
+    return (get_info(dtype).bits * (1 + dtype.is_complex)) // 8
+
+
 def docstring_from(source):
 def docstring_from(source):
     def add_docstring(dest):
     def add_docstring(dest):
         dest.__doc__ = source.__doc__
         dest.__doc__ = source.__doc__
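
A few spot checks of the new helper (safe to run anywhere torch is installed). torch.bool is special-cased because torch.finfo/torch.iinfo do not accept it:

    import torch

    from petals.utils.misc import get_size_in_bytes

    assert get_size_in_bytes(torch.float32) == 4
    assert get_size_in_bytes(torch.bfloat16) == 2
    assert get_size_in_bytes(torch.int64) == 8
    assert get_size_in_bytes(torch.bool) == 1  # taken from SPECIAL_DTYPE_SIZES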

+ 2 - 1
src/petals/utils/peft.py

@@ -20,6 +20,7 @@ from transformers.utils import get_file_from_repo
 from petals.server.block_utils import resolve_block_dtype
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_block import QuantType
 from petals.utils.convert_block import QuantType
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.misc import get_size_in_bytes
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
@@ -285,5 +286,5 @@ def estimate_adapter_memory_per_block(
                 block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
                 block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
             )
             )
         adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
         adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
-    bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
+    bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype))
     return adapter_parameters * bytes_per_parameter
     return adapter_parameters * bytes_per_parameter
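
As a back-of-the-envelope check of the formula above: a hypothetical adapter adding 10M parameters per block, stored in float16, costs roughly 19 MiB of extra memory per block.

    import torch

    from petals.utils.misc import get_size_in_bytes

    adapter_parameters = 10_000_000  # illustrative LoRA parameter count per block
    bytes_per_parameter = get_size_in_bytes(torch.float16)
    print(f"{adapter_parameters * bytes_per_parameter / 1024**2:.1f} MiB per block")  # ~19.1 MiB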

+ 184 - 0
tests/test_cache.py

@@ -0,0 +1,184 @@
+import asyncio
+import multiprocessing as mp
+import random
+import time
+from typing import Optional
+
+import pytest
+import pytest_asyncio  # make sure the module exists; otherwise the test will be skipped
+import torch
+from hivemind import TensorDescriptor
+
+from petals.server.memory_cache import AllocationFailed, MemoryCache
+from petals.utils.misc import get_size_in_bytes
+
+
+def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None):
+    if dtype is None:
+        dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool))
+    elem_size_bytes = get_size_in_bytes(dtype)
+    descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype))
+    return descr
+
+
+@pytest.mark.asyncio
+async def test_cache_timeout():
+    cache = MemoryCache(max_size_bytes=1024, max_alloc_timeout=0.5)
+    cache.runtime_pid += 1  # pretend we're another process
+    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0):
+        pass
+
+    async with cache.allocate_cache(_make_tensor_descriptor(100), timeout=999):
+        async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
+            async with cache.allocate_cache(_make_tensor_descriptor(128), _make_tensor_descriptor(32), timeout=1):
+                t_start = time.perf_counter()
+                with pytest.raises(AllocationFailed):
+                    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0.1):
+                        pass
+                assert 0.1 < time.perf_counter() - t_start < 0.2, "wait time exceeds alloc timeout"
+                async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")):
+                    pass
+
+                t_start = time.perf_counter()
+                with pytest.raises(AllocationFailed):
+                    async with cache.allocate_cache(_make_tensor_descriptor(384), timeout=1.0):  # exceeds max timeout
+                        pass
+                assert 0.5 < time.perf_counter() - t_start < 0.6, "wait time exceeds max alloc timeout"
+
+            # test memory allocation when another task frees the memory
+            async def _klog_the_cache():
+                async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
+                    pass
+
+            large_alloc_task = asyncio.create_task(_klog_the_cache())
+
+            t_start = time.perf_counter()
+            await asyncio.sleep(0.05)  # wait for large alloc to enqueue
+            async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")):  # exceeds max timeout
+                pass  # this memory should allocate once the background task clears the queue
+            assert 0.2 < time.perf_counter() - t_start < 0.3, "memory should be allocated after background task clears"
+            with pytest.raises(AllocationFailed):
+                await large_alloc_task
+
+            # test that zero-timeout allocation fails instantaneously even if someone else is awaiting alloc
+            large_alloc_task = asyncio.create_task(_klog_the_cache())
+            t_start = time.perf_counter()
+            await asyncio.sleep(0.05)  # wait for large alloc to enqueue
+            with pytest.raises(AllocationFailed):
+                async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
+                    pass  # this memory should allocate once the background task clears the queue
+            assert time.perf_counter() - t_start < 0.1, "zero-timeout task should fail (or succeed) instantaneously"
+            with pytest.raises(AllocationFailed):
+                await large_alloc_task
+
+
+@pytest.mark.asyncio
+async def test_unlimited_timeout():
+    cache = MemoryCache(max_size_bytes=1024)
+    cache.runtime_pid += 1  # pretend we're another process
+    t_start = time.perf_counter()
+
+    async def _klog_the_cache():
+        async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
+            await asyncio.sleep(0.5)
+
+    alloc_task = asyncio.create_task(_klog_the_cache())
+    await asyncio.sleep(0.1)
+    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=float("inf")):
+        await alloc_task
+    assert 0.5 < time.perf_counter() - t_start < 0.6, "memory should be allocated after background task clears"
+
+
+@pytest.mark.asyncio
+async def test_cache_usage():
+    cache = MemoryCache(max_size_bytes=2048)
+    alloc_event, dealloc_a_event, dealloc_bcd_event, dealloc_e_event, dealloc_f_event = (mp.Event() for _ in range(5))
+    pipe_receiver, pipe_sender = mp.Pipe(duplex=False)
+    with pytest.raises(AssertionError):
+        async with cache.allocate_cache(_make_tensor_descriptor(123), timeout=1):
+            pass  # fails because cache must be allocated from another process
+
+    descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8))  # 768 bytes
+    descr_b = TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64))  # 8 bytes
+    descr_c = TensorDescriptor.from_tensor(torch.empty((33,), dtype=torch.bool))  # 33 bytes
+    descr_d = TensorDescriptor.from_tensor(torch.empty((0,), dtype=torch.int64))  # 0 bytes
+    descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16))  # 1536 bytes
+    descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8))  # 1792 bytes
+
+    async def _allocate_and_wait(dealloc_event, *descrs, timeout=None):
+        loop = asyncio.get_event_loop()
+        async with cache.allocate_cache(*descrs, timeout=timeout) as handles:
+            pipe_sender.send(handles)
+            await loop.run_in_executor(None, dealloc_event.wait)
+
+    async def _allocate_af():
+        alloc_event.wait()
+        allocate_a_task = asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a))
+        await allocate_a_task
+        allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f))  # klogs the cache
+        await allocate_f_task
+
+    alloc_process1 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_af()), daemon=True)
+    alloc_process1.start()
+
+    async def _allocate_bcde():
+        alloc_event.wait()
+        await asyncio.sleep(0.1)  # ensure that the other tensor is always allocated (and sent through pipe) first
+        allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d))
+        allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e))  # doesn't fit
+        await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)
+
+    alloc_process2 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
+    alloc_process2.start()
+    assert cache.current_size_bytes == 0
+    alloc_event.set()
+    (handle_a,) = pipe_receiver.recv()
+
+    handle_b, handle_c, handle_d = pipe_receiver.recv()
+
+    with cache.use_cache(handle_a) as (tensor_a,):
+        assert tensor_a.dtype == torch.uint8
+        tensor_a[2:5] = torch.tensor((42, 43, 44))
+
+    with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, tensor_d):
+        assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0
+        assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0
+        tensor_a += 1
+        tensor_b[...] = -1.337
+    assert cache.current_size_bytes == 809  # a (768) + b (8) + c (33) + d (0) are allocated; e still awaits memory
+
+    dealloc_bcd_event.set()
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 768  # only tensor a should be allocated
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_a, handle_b):
+            pass  # one of handles (c) is deallocated
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_d):
+            pass  # handle_d is deallocated correctly, even though it is never used
+    with cache.use_cache(handle_a) as (tensor_a,):
+        assert tuple(tensor_a[2:5]) == (43, 44, 45)
+
+    dealloc_a_event.set()
+    (handle_e,) = pipe_receiver.recv()  # e can finally be allocated
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 1536  # tensor e should finally be able to allocate
+
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_a):
+            pass  # tensor a is no longer allocated
+    with cache.use_cache(handle_e) as (tensor_e,):
+        assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8)
+
+    dealloc_e_event.set()
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 1792  # only tensor f is still allocated
+    dealloc_f_event.set()
+
+    alloc_process1.join()
+    alloc_process2.join()
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 0
+    assert cache.current_size_bytes == 0
+    assert alloc_process1.exitcode == 0, "allocation process 1 failed or did not finish, see stderr for details"
+    assert alloc_process2.exitcode == 0, "allocation process 2 failed or did not finish, see stderr for details"
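
For reference, the cross-process protocol these tests exercise boils down to the sketch below. handler_side, runtime_side, and run_in_runtime are hypothetical names; in a real server the handles travel from a ConnectionHandler process to the Runtime process.

    import torch
    from hivemind import TensorDescriptor

    from petals.server.memory_cache import MemoryCache

    cache = MemoryCache(max_size_bytes=2048, max_alloc_timeout=10)
    descr = TensorDescriptor.from_tensor(torch.empty(256, dtype=torch.uint8))

    async def handler_side(run_in_runtime):
        # Reserve space (may wait until memory is freed or the timeout expires),
        # then let the runtime use the handles while the allocation is held
        async with cache.allocate_cache(descr, timeout=1.0) as handles:
            run_in_runtime(handles)

    def runtime_side(handles):
        # Materialize the reserved tensors by handle (runs in the Runtime process)
        with cache.use_cache(*handles) as tensors:
            tensors[0].zero_()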

+ 20 - 0
tests/test_full_model.py

@@ -149,3 +149,23 @@ def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, n
     outputs = make_generate_calls(model, inputs, **options)
     outputs = make_generate_calls(model, inputs, **options)
     ref_outputs = ref_model.generate(inputs, **options)
     ref_outputs = ref_model.generate(inputs, **options)
     assert torch.allclose(outputs, ref_outputs), f"Beam search results are not identical to HF"
     assert torch.allclose(outputs, ref_outputs), f"Beam search results are not identical to HF"
+
+
+@pytest.mark.forked
+def test_input_ids(tokenizer, model, ref_model, max_new_tokens=4):
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")
+    assert inputs.keys() == {"input_ids", "attention_mask"}
+
+    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
+    ref_outputs = ref_model.generate(**inputs, max_new_tokens=max_new_tokens)
+    assert torch.allclose(outputs, ref_outputs), f"Outputs are not identical to HF"
+
+    with model.inference_session(max_length=inputs["input_ids"].shape[1] + max_new_tokens):
+        outputs = torch.cat(
+            [
+                model.generate(**inputs, max_new_tokens=2),
+                model.generate(None, max_new_tokens=max_new_tokens - 2),
+            ],
+            dim=1,
+        )
+    assert torch.allclose(outputs, ref_outputs), f"Multi-call outputs are not identical to HF"

+ 29 - 18
tests/test_priority_pool.py

@@ -1,4 +1,5 @@
 import multiprocessing as mp
 import multiprocessing as mp
+import platform
 import time
 import time
 
 
 import pytest
 import pytest
@@ -8,9 +9,30 @@ from hivemind.moe.server.runtime import Runtime
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_pool import PrioritizedTaskPool
 
 
 
 
+def _submit_tasks(runtime_ready, pools, results_valid):
+    runtime_ready.wait()
+
+    futures = []
+    futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
+    futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
+    time.sleep(0.01)
+    futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
+    futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
+    futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
+    futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
+    futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
+    futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
+    futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
+    for i, f in enumerate(futures):
+        assert f.result()[0].item() == i**2
+    results_valid.set()
+
+
+@pytest.mark.skipif(platform.system() == "Darwin", reason="Flapping on macOS due to multiprocessing quirks")
 @pytest.mark.forked
 @pytest.mark.forked
 def test_priority_pools():
 def test_priority_pools():
     outputs_queue = mp.SimpleQueue()
     outputs_queue = mp.SimpleQueue()
+    runtime_ready = mp.Event()
     results_valid = mp.Event()
     results_valid = mp.Event()
 
 
     def dummy_pool_func(args, kwargs):
     def dummy_pool_func(args, kwargs):
@@ -32,27 +54,14 @@ def test_priority_pools():
         PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
         PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
     )
     )
 
 
+    # Simulate requests coming from ConnectionHandlers
+    proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid))
+    proc.start()
+
     runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
     runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
+    runtime.ready = runtime_ready
     runtime.start()
     runtime.start()
 
 
-    def process_tasks():
-        futures = []
-        futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
-        futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
-        time.sleep(0.01)
-        futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
-        futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
-        futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
-        futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
-        futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
-        futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
-        futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
-        for i, f in enumerate(futures):
-            assert f.result()[0].item() == i**2
-        results_valid.set()
-
-    proc = mp.Process(target=process_tasks)
-    proc.start()
     proc.join()
     proc.join()
     assert results_valid.is_set()
     assert results_valid.is_set()
 
 
@@ -70,3 +79,5 @@ def test_priority_pools():
     #                                            3 - task with priority 2 from pool A
     #                                            3 - task with priority 2 from pool A
     #                                               4 - task with priority 10 from pool A
     #                                               4 - task with priority 10 from pool A
     #                                                  7 - task with priority 11 from pool B
     #                                                  7 - task with priority 11 from pool B
+
+    runtime.shutdown()

+ 1 - 1
tests/test_remote_sequential.py

@@ -126,6 +126,6 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
 
 
     (outputs_ref * output_proj).sum().backward()
     (outputs_ref * output_proj).sum().backward()
     assert input_prompts_ref.grad is not None
     assert input_prompts_ref.grad is not None
-    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2)
+    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=3e-2)
     assert intermediate_prompts_ref.grad is not None
     assert intermediate_prompts_ref.grad is not None
     assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)
     assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)