test_priority_pool.py

import multiprocessing as mp
import time

import pytest
import torch
from hivemind.moe.server.runtime import Runtime

from petals.server.task_pool import PrioritizedTaskPool


@pytest.mark.forked
def test_priority_pools():
    """Check that PrioritizedTaskPool executes tasks from two pools in priority order."""
    outputs_queue = mp.SimpleQueue()
    results_valid = mp.Event()

    def dummy_pool_func(args, kwargs):
        (x,) = args  # TODO: modify the PriorityPool code such that dummy_pool_func can accept x directly
        time.sleep(0.1)  # simulate a slow computation so that pending tasks accumulate in the queues
        y = x**2
        outputs_queue.put((x, y))  # record the processing order for the final check
        return (y,)

    class DummyBackend:
        def __init__(self, pools):
            self.pools = pools

        def get_pools(self):
            return self.pools

    pools = (
        PrioritizedTaskPool(dummy_pool_func, name="A", max_batch_size=1),
        PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
    )

    # prefetch_batches=0: the runtime fetches the next task only after finishing the current one,
    # which keeps the processing order deterministic for this test
    runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
    runtime.start()

    def process_tasks():
        futures = []
        futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
        futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
        time.sleep(0.01)  # let the runtime pick up the first task before the remaining ones are submitted
        futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
        futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
        futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
        futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
        futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
        futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
        futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))

        for i, f in enumerate(futures):
            assert f.result()[0].item() == i**2

        results_valid.set()

    proc = mp.Process(target=process_tasks)
    proc.start()
    proc.join()
    assert results_valid.is_set()

    ordered_outputs = []
    while not outputs_queue.empty():
        ordered_outputs.append(outputs_queue.get()[0].item())

    assert ordered_outputs == [0, 5, 1, 2, 6, 8, 3, 4, 7]
    # 0 - the first task is fetched by the runtime immediately, before everything else is submitted
    # 5 - the highest-priority task overall (priority 0)
    # 1 - first of several tasks sharing the lowest remaining priority value (1)
    # 2 - second earliest priority-1 task, fetched from pool B
    # 6 - third earliest priority-1 task, fetched from pool A again
    # 8 - last priority-1 task, from pool B
    # 3 - the priority-2 task from pool A
    # 4 - the priority-10 task from pool A
    # 7 - the priority-11 task from pool B
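
The expected order in the final assertion follows directly from the comments above: the first submitted task is dispatched before the rest arrive, and everything else runs in ascending priority value with ties broken by submission order. A standalone sketch of that derivation (not part of the test file; the (task index, priority) pairs are copied from the submit_task calls above, and stable tie-breaking by submission order is assumed):

    # Standalone sketch: reproduce the expected processing order from the submitted
    # (task_index, priority) pairs. Assumes ties are broken by submission order and
    # that task 0 is dispatched before the later tasks are queued.
    submitted = [(0, 1), (1, 1), (2, 1), (3, 2), (4, 10), (5, 0), (6, 1), (7, 11), (8, 1)]

    first, *rest = submitted  # task 0 is picked up immediately (max_batch_size=1, no prefetching)
    # sorted() is stable, so tasks with equal priority keep their submission order
    expected = [first[0]] + [index for index, _priority in sorted(rest, key=lambda task: task[1])]

    assert expected == [0, 5, 1, 2, 6, 8, 3, 4, 7]  # matches the assertion in the test

With max_batch_size=1 and prefetch_batches=0, only one task runs at a time and nothing is pre-loaded, which is what makes the processing order deterministic enough to assert on.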