Support macOS (#477)

This PR makes both clients and servers work on macOS. Specifically, it:

- Follows https://github.com/learning-at-home/hivemind/pull/586 to run a macOS-compatible `p2pd` binary (both x86-64 and ARM64 are supported)
- Fixes forking issues and tests on macOS, Python 3.10+
- Introduces basic support for serving model blocks on Apple M1/M2 GPUs (torch.mps)
- Increases the maximum number of open files by default (the default limit is not enough on Linux and is very small on macOS)
Alexander Borzunov, 2 years ago
parent commit 26ebbfe8f0

+ 18 - 20
.github/workflows/run-tests.yaml

@@ -7,20 +7,21 @@ on:
 
 jobs:
   run-tests:
-    runs-on: ubuntu-latest
     strategy:
       matrix:
         include:
-          - { model: 'bigscience/bloom-560m', python-version: '3.8' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.9' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.10' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.11' }
-          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.8' }
-          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.11' }
+          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
       fail-fast: false
+    runs-on: ${{ matrix.os }}-latest
     timeout-minutes: 15
     steps:
       - name: Increase swap space
+        if: ${{ matrix.os == 'ubuntu' }}
         uses: pierotofy/set-swap-space@master
         with:
           swap-size-gb: 10
@@ -47,12 +48,7 @@ jobs:
           export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
           export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
 
-          # [Step 1] Watch free RAM (lack of RAM is a common issue in CI)
-
-          bash -c 'while true; do free -h && sleep 30s; done' &
-          RAM_WATCH_PID=$!
-
-          # [Step 2] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
+          # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
 
           python -m petals.cli.run_dht \
             --identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &
@@ -61,7 +57,7 @@ jobs:
           export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
           # ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs
 
-          sleep 5  # wait for DHT init
+          until [ -s bootstrap.log ]; do sleep 5; done  # wait for DHT init
 
           python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
             --mean_balance_check_period 10 \
@@ -95,11 +91,15 @@ jobs:
           sleep 30  # wait for servers to eval throughput, download layers, and rebalance
           kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID  # ensure all peers survived init
 
-          # [Step 3] Run PyTest
+          # [Step 2] Run PyTest
+
+          # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
+          export no_proxy=*
+          export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
 
           pytest tests --durations=0 --durations-min=1.0 -v
 
-          # [Step 4] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
+          # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
 
           python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
             --seq_len 3
@@ -110,9 +110,7 @@ jobs:
           python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
             --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm
 
-          # [Step 5] Clean up
-
-          kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID  # ensure all peers survived tests
+          # [Step 4] Clean up
 
-          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID $RAM_WATCH_PID
+          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
           echo "Done!"

+ 12 - 4
README.md

@@ -51,7 +51,7 @@ python -m petals.cli.run_server petals-team/StableBeluga2
 
 🪟 **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki.
 
-🐋 **Any OS + Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
+🐋 **Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
 
 ```bash
 sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \
@@ -59,12 +59,20 @@ sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cach
     python -m petals.cli.run_server --port 31330 petals-team/StableBeluga2
 ```
 
+🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands:
+
+```bash
+brew install python
+python3 -m pip install git+https://github.com/bigscience-workshop/petals
+python3 -m petals.cli.run_server petals-team/StableBeluga2
+```
+
 <p align="center">
-    📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (using multiple GPUs, starting on boot, etc.)
-    &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
-    💬 &nbsp;<b><a href="https://discord.gg/X7DgtxgMhc">Ask for help in Discord</a></b>
+    📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (how to use multiple GPUs, start the server on boot, etc.)
 </p>
 
+💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
+
 🦙 **Want to host Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then add `--token YOUR_TOKEN_HERE` to the `python -m petals.cli.run_server` command.
 
 🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).

+ 2 - 1
setup.cfg

@@ -18,6 +18,7 @@ classifiers =
     Programming Language :: Python :: 3.8
     Programming Language :: Python :: 3.9
     Programming Language :: Python :: 3.10
+    Programming Language :: Python :: 3.11
     Topic :: Scientific/Engineering
     Topic :: Scientific/Engineering :: Mathematics
     Topic :: Scientific/Engineering :: Artificial Intelligence
@@ -39,7 +40,7 @@ install_requires =
     transformers>=4.32.0,<5.0.0  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
-    hivemind==1.1.9
+    hivemind @ git+https://github.com/learning-at-home/hivemind
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2

+ 6 - 0
src/petals/__init__.py

@@ -1,7 +1,13 @@
 import os
+import platform
 
 os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
 
+if platform.system() == "Darwin":
+    # Necessary for forks to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
+    os.environ.setdefault("no_proxy", "*")
+    os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES")
+
 import hivemind
 import transformers
 from packaging import version

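Note on the two environment variables above: they must be exported before anything that may fork initializes networking or Objective-C state, which is why they are set at the very top of the package. A minimal standalone sketch of the same workaround (not part of this diff; the `child` function is purely illustrative):

```python
# Illustrative only: apply the macOS fork-safety workarounds before forking.
import multiprocessing as mp
import os
import platform

if platform.system() == "Darwin":
    # Without these, a forked child may hang on proxy autodiscovery or be
    # killed by the Objective-C runtime's fork-safety check.
    os.environ.setdefault("no_proxy", "*")
    os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES")


def child():
    print(f"forked child {os.getpid()} is alive")


if __name__ == "__main__":
    proc = mp.get_context("fork").Process(target=child)
    proc.start()
    proc.join()
```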
+ 14 - 6
src/petals/cli/run_server.py

@@ -1,8 +1,10 @@
 import argparse
+import logging
 
 import configargparse
+import torch
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.limits import increase_file_limit
+from hivemind.utils import limits
 from hivemind.utils.logging import get_logger
 from humanfriendly import parse_size
 
@@ -127,9 +129,9 @@ def main():
     group.add_argument('--new_swarm', action='store_true',
                        help='Start a new private swarm (i.e., do not connect to any initial peers)')
 
-    parser.add_argument('--increase_file_limit', action='store_true',
-                        help='On *nix, this will increase the max number of processes '
-                             'a server can spawn before hitting "Too many open files"; Use at your own risk.')
+    parser.add_argument('--increase_file_limit', type=int, default=4096,
+                        help='On *nix, increase the max number of files a server can open '
+                             'before hitting "Too many open files" (set to zero to keep the system limit)')
     parser.add_argument('--stats_report_interval', type=int, required=False,
                         help='Interval between two reports of batch processing performance statistics')
 
@@ -185,8 +187,10 @@ def main():
 
     args["startup_timeout"] = args.pop("daemon_startup_timeout")
 
-    if args.pop("increase_file_limit"):
-        increase_file_limit()
+    file_limit = args.pop("increase_file_limit")
+    if file_limit:
+        limits.logger.setLevel(logging.WARNING)
+        limits.increase_file_limit(file_limit, file_limit)
 
     compression_type = args.pop("compression").upper()
     compression = getattr(CompressionType, compression_type)
@@ -207,6 +211,10 @@ def main():
 
     validate_version()
 
+    if not torch.backends.openmp.is_available():
+        # Necessary to prevent the server from freezing after forks
+        torch.set_num_threads(1)
+
     server = Server(
         **args,
         host_maddrs=host_maddrs,

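For context, raising the open-file limit boils down to adjusting the soft `RLIMIT_NOFILE` bound; `hivemind.utils.limits.increase_file_limit` presumably wraps similar `resource` calls. A hedged standard-library sketch (the helper name is invented for this example):

```python
# Sketch: raise the soft RLIMIT_NOFILE toward a target, never above the hard limit.
import resource


def raise_nofile_soft_limit(target: int) -> None:
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    new_soft = min(max(soft, target), hard)  # keep the current value if it is already higher
    resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))


raise_nofile_soft_limit(4096)  # mirrors the new --increase_file_limit default
print(resource.getrlimit(resource.RLIMIT_NOFILE))
```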
+ 1 - 1
src/petals/server/reachability.py

@@ -140,7 +140,7 @@ class ReachabilityProtocol(ServicerBase):
                 protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)
 
                 ready.set_result(True)
-                logger.info("Reachability service started")
+                logger.debug("Reachability service started")
 
                 async with protocol.serve(common_p2p):
                     await protocol._stop.wait()

+ 22 - 3
src/petals/server/server.py

@@ -9,7 +9,9 @@ import time
 from typing import Dict, List, Optional, Sequence, Union
 
 import hivemind
+import psutil
 import torch
+import torch.mps
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
@@ -154,13 +156,25 @@ class Server:
         self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
 
         if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.backends.mps.is_available():
+                device = "mps"
+            else:
+                device = "cpu"
         device = torch.device(device)
         if device.type == "cuda" and device.index is None:
             device = torch.device(device.type, index=0)
         self.device = device
 
         torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
+        if device.type == "cpu" and torch_dtype == torch.float16:
+            raise ValueError(
+                f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16"
+            )
+        if device.type == "mps" and torch_dtype == torch.bfloat16:
+            logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead")
+            torch_dtype = torch.float16
         self.torch_dtype = torch_dtype
 
         if tensor_parallel_devices is None:
@@ -253,13 +267,14 @@ class Server:
         self.stop = threading.Event()
 
     def _choose_num_blocks(self) -> int:
-        assert self.device.type == "cuda", (
+        assert self.device.type in ("cuda", "mps"), (
             "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
             "CPU-only servers in the public swarm are discouraged since they are much slower"
         )
         num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
 
         if num_devices > 1:
+            assert self.device.type == "cuda", f"Tensor parallelism is not supported on {self.device.type.upper()}"
             memory_per_device = tuple(
                 torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
             )
@@ -270,8 +285,10 @@ class Server:
                     "Please launch individual servers on each GPU or set --num_blocks manually to "
                     "override this exception."
                 )
-        else:
+        elif self.device.type == "cuda":
             total_memory = torch.cuda.get_device_properties(self.device).total_memory
+        else:
+            total_memory = psutil.virtual_memory().total
 
         gib = 1024**3
         # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
@@ -373,6 +390,8 @@ class Server:
                 f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
                 f"{reserved_vram / gib:.1f} GiB reserved memory"
             )
+        elif self.device.type == "mps":
+            torch.mps.empty_cache()
 
     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:

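The device and dtype fallback introduced above, condensed into a standalone sketch for illustration (the function `pick_device_and_dtype` is invented here; the real logic lives inside `Server.__init__`):

```python
# Sketch of the auto-selection order: CUDA > MPS > CPU, with dtype guards.
import torch


def pick_device_and_dtype(requested_dtype: torch.dtype):
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    dtype = requested_dtype
    if device.type == "cpu" and dtype == torch.float16:
        raise ValueError("float16 is not supported on CPU; use float32 or bfloat16")
    if device.type == "mps" and dtype == torch.bfloat16:
        dtype = torch.float16  # bfloat16 is not supported on MPS
    return device, dtype


print(pick_device_and_dtype(torch.float32))
```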
+ 12 - 6
src/petals/server/throughput.py

@@ -9,6 +9,7 @@ from pathlib import Path
 from typing import Dict, Optional, Sequence, Union
 
 import torch
+import torch.mps
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
@@ -207,14 +208,12 @@ def measure_compute_rps(
         elapsed = 0
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
         _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
-        if device.type == "cuda":
-            torch.cuda.synchronize(device)
+        synchronize(device)
 
         start_time = time.perf_counter()
-        for step in range(n_steps):
+        for _ in range(n_steps):
             _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
-        if device.type == "cuda":
-            torch.cuda.synchronize(device)
+        synchronize(device)
         elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
 
@@ -230,8 +229,15 @@ def measure_compute_rps(
     return device_rps
 
 
+def synchronize(device: torch.device):
+    if device.type == "cuda":
+        torch.cuda.synchronize(device)
+    elif device.type == "mps":
+        torch.mps.synchronize()
+
+
 def get_device_name(device: torch.device) -> str:
-    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU"
+    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else device.type.upper()
 
 
 def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:

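A hypothetical usage example of the `synchronize()` pattern above: CUDA and MPS queue kernels asynchronously, so the device must be drained before reading the wall clock, otherwise the measured throughput is inflated.

```python
# Illustration: time a matmul correctly on whichever device is present.
import time

import torch


def synchronize(device: torch.device):
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elif device.type == "mps":
        torch.mps.synchronize()


device = torch.device(
    "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)
x = torch.randn(1024, 1024, device=device)
synchronize(device)  # make sure setup work is not counted
start = time.perf_counter()
for _ in range(10):
    y = x @ x
synchronize(device)  # wait for queued kernels before stopping the clock
print(f"10 matmuls took {time.perf_counter() - start:.3f} s on {device.type}")
```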
+ 2 - 2
tests/test_cache.py

@@ -118,7 +118,7 @@ async def test_cache_usage():
         allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f))  # klogs the cache
         await allocate_f_task
 
-    alloc_process1 = mp.Process(target=lambda: asyncio.run(_allocate_af()), daemon=True)
+    alloc_process1 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_af()), daemon=True)
     alloc_process1.start()
 
     async def _allocate_bcde():
@@ -128,7 +128,7 @@ async def test_cache_usage():
         allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e))  # doesn't fit
         await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)
 
-    alloc_process2 = mp.Process(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
+    alloc_process2 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
     alloc_process2.start()
     assert cache.current_size_bytes == 0
     alloc_event.set()

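Background on the switch to `mp.context.ForkProcess` (a sketch, not from this diff): since Python 3.8 the default multiprocessing start method on macOS is `spawn`, which re-imports the test module instead of inheriting the parent's in-memory state, so these tests request the fork context explicitly.

```python
# Illustration of default start methods and an explicit fork context (POSIX only).
import multiprocessing as mp

print(mp.get_start_method())  # "spawn" on macOS and Windows, "fork" on Linux

# Two equivalent ways to obtain a fork-based Process class:
fork_ctx = mp.get_context("fork")
proc_a = fork_ctx.Process(target=print, args=("hello from a forked child",), daemon=True)
proc_b = mp.context.ForkProcess(target=print, args=("hello from another forked child",), daemon=True)

if __name__ == "__main__":
    for proc in (proc_a, proc_b):
        proc.start()
        proc.join()
```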
+ 29 - 18
tests/test_priority_pool.py

@@ -1,4 +1,5 @@
 import multiprocessing as mp
+import platform
 import time
 
 import pytest
@@ -8,9 +9,30 @@ from hivemind.moe.server.runtime import Runtime
 from petals.server.task_pool import PrioritizedTaskPool
 
 
+def _submit_tasks(runtime_ready, pools, results_valid):
+    runtime_ready.wait()
+
+    futures = []
+    futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
+    futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
+    time.sleep(0.01)
+    futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
+    futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
+    futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
+    futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
+    futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
+    futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
+    futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
+    for i, f in enumerate(futures):
+        assert f.result()[0].item() == i**2
+    results_valid.set()
+
+
+@pytest.mark.skipif(platform.system() == "Darwin", reason="Flapping on macOS due to multiprocessing quirks")
 @pytest.mark.forked
 def test_priority_pools():
     outputs_queue = mp.SimpleQueue()
+    runtime_ready = mp.Event()
     results_valid = mp.Event()
 
     def dummy_pool_func(x):
@@ -31,27 +53,14 @@ def test_priority_pools():
         PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
     )
 
+    # Simulate requests coming from ConnectionHandlers
+    proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid))
+    proc.start()
+
     runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
+    runtime.ready = runtime_ready
     runtime.start()
 
-    def process_tasks():
-        futures = []
-        futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
-        futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
-        time.sleep(0.01)
-        futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
-        futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
-        futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
-        futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
-        futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
-        futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
-        futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
-        for i, f in enumerate(futures):
-            assert f.result()[0].item() == i**2
-        results_valid.set()
-
-    proc = mp.Process(target=process_tasks)
-    proc.start()
     proc.join()
     assert results_valid.is_set()
 
@@ -69,3 +78,5 @@ def test_priority_pools():
     #                                            3 - task with priority 2 from pool A
     #                                               4 - task with priority 10 from pool A
     #                                                  7 - task with priority 11 from pool B
+
+    runtime.shutdown()
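
The restructured test relies on an event-based handshake: the submitting process waits for `runtime_ready` before queuing work, and the parent joins it before checking results. A minimal sketch of that handshake with plain multiprocessing primitives (all names here are illustrative):

```python
# Sketch: gate work in a forked child on a readiness event set by the parent.
import multiprocessing as mp


def worker(ready, done):
    ready.wait()  # block until the parent signals that the "runtime" is up
    # ... submit tasks and check results here ...
    done.set()


if __name__ == "__main__":
    ctx = mp.get_context("fork")
    ready, done = ctx.Event(), ctx.Event()
    proc = ctx.Process(target=worker, args=(ready, done))
    proc.start()
    ready.set()  # simulate the runtime becoming ready
    proc.join()
    assert done.is_set()
```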