# test_priority_pool.py
import multiprocessing as mp
import platform
import time

import pytest
import torch

from petals.server.server import RuntimeWithDeduplicatedPools
from petals.server.task_pool import PrioritizedTaskPool
  8. def _submit_tasks(runtime_ready, pools, results_valid):
  9. runtime_ready.wait()
  10. futures = []
  11. futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
  12. futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
  13. time.sleep(0.01)
  14. futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
  15. futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
  16. futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
  17. futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
  18. futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
  19. futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
  20. futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
  21. for i, f in enumerate(futures):
  22. assert f.result()[0].item() == i**2
  23. results_valid.set()
@pytest.mark.skipif(platform.system() == "Darwin", reason="Flapping on macOS due to multiprocessing quirks")
@pytest.mark.forked
def test_priority_pools():
    """Check that the runtime drains two PrioritizedTaskPools in strict priority order.

    A forked child process submits tasks (see ``_submit_tasks``) while this
    process runs the runtime; the order in which ``dummy_pool_func`` processed
    inputs is recorded via ``outputs_queue`` and compared against the expected
    priority-driven schedule.
    """
    outputs_queue = mp.SimpleQueue()
    runtime_ready = mp.Event()  # set by the runtime once it starts consuming tasks
    results_valid = mp.Event()  # set by the child process after all asserts pass

    def dummy_pool_func(x):
        # Sleep long enough that every pending submission is queued before the
        # next batch is chosen, making the processing order deterministic.
        time.sleep(0.1)
        y = x**2
        outputs_queue.put((x, y))
        return (y,)

    class DummyBackend:
        # Minimal stand-in exposing only the get_pools() interface the runtime uses.
        def __init__(self, pools):
            self.pools = pools

        def get_pools(self):
            return self.pools

    pools = (
        PrioritizedTaskPool(dummy_pool_func, name="A", max_batch_size=1),
        PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
    )

    # Simulate requests coming from ConnectionHandlers
    # NOTE(review): ForkProcess is used explicitly (hence the macOS skip above) —
    # presumably the pools/events must be inherited, not pickled; confirm.
    proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid))
    proc.start()

    runtime = RuntimeWithDeduplicatedPools(
        {str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0
    )
    # Replace the runtime's ready event so the submitter unblocks once it starts.
    runtime.ready = runtime_ready
    runtime.start()

    proc.join()
    assert results_valid.is_set()

    ordered_outputs = []
    while not outputs_queue.empty():
        ordered_outputs.append(outputs_queue.get()[0].item())

    assert ordered_outputs == [0, 5, 1, 2, 6, 8, 3, 4, 7]
    # 0 - first batch is loaded immediately, before everything else
    # 5 - highest priority task overall
    # 1 - first of several tasks with equal lowest priority (1)
    # 2 - second earliest task with priority 1, fetched from pool B
    # 6 - third earliest task with priority 1, fetched from pool A again
    # 8 - last priority-1 task, pool B
    # 3 - task with priority 2 from pool A
    # 4 - task with priority 10 from pool A
    # 7 - task with priority 11 from pool B
    runtime.shutdown()
  67. runtime.shutdown()