shared_future.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. import multiprocessing as mp
  2. import multiprocessing.connection
  3. from concurrent.futures import Future, CancelledError
  4. from warnings import warn
  5. class SharedFuture(Future):
  6. """ Multiprocessing version of concurrent.futures.Future, interacts between two processes via Pipe """
  7. STATES = 'pending', 'running', 'cancelled', 'finished', 'exception'
  8. STATE_PENDING, STATE_RUNNING, STATE_CANCELLED, STATE_FINISHED, STATE_EXCEPTION = STATES
  9. def __init__(self, connection: mp.connection.Connection):
  10. """ manually create MPFuture. Please use MPFuture.make_pair instead """
  11. self.connection = connection
  12. self.state = self.STATE_PENDING
  13. self._result = None
  14. self._exception = None
  15. @classmethod
  16. def make_pair(cls):
  17. """ Create a pair of linked futures to be used in two processes """
  18. connection1, connection2 = mp.Pipe()
  19. return cls(connection1), cls(connection2)
  20. def _recv(self, timeout):
  21. if self.state in (self.STATE_PENDING, self.STATE_RUNNING):
  22. if not self.connection.poll(timeout):
  23. raise TimeoutError()
  24. try:
  25. status, payload = self.connection.recv()
  26. except BrokenPipeError as e:
  27. status, payload = self.STATE_EXCEPTION, e
  28. assert status in self.STATES
  29. self.state = status
  30. if status == self.STATE_FINISHED:
  31. self._result = payload
  32. elif status == self.STATE_EXCEPTION:
  33. self._exception = payload
  34. elif status in (self.STATE_RUNNING, self.STATE_CANCELLED):
  35. pass # only update self.state
  36. else:
  37. raise ValueError("Result status should not be self.STATE_PENDING")
  38. def set_result(self, result):
  39. try:
  40. self.state, self._result = self.STATE_FINISHED, result
  41. self.connection.send((self.STATE_FINISHED, result))
  42. return True
  43. except BrokenPipeError:
  44. return False
  45. def set_exception(self, exception: BaseException):
  46. try:
  47. self.state, self._exception = self.STATE_EXCEPTION, exception
  48. self.connection.send((self.STATE_EXCEPTION, exception))
  49. return True
  50. except BrokenPipeError:
  51. return False
  52. def set_running_or_notify_cancel(self):
  53. return True
  54. def cancel(self):
  55. raise NotImplementedError()
  56. def result(self, timeout=None):
  57. self._recv(timeout)
  58. if self.state == self.STATE_FINISHED:
  59. return self._result
  60. elif self.state == self.STATE_EXCEPTION:
  61. raise self._exception
  62. else:
  63. assert self.state == self.STATE_CANCELLED
  64. raise CancelledError()
  65. def exception(self, timeout=None):
  66. self._recv(timeout)
  67. return self._exception
  68. def done(self):
  69. return self.state in (self.STATE_FINISHED, self.STATE_EXCEPTION, self.STATE_CANCELLED)
  70. def running(self):
  71. return self.state == self.STATE_RUNNING
  72. def cancelled(self):
  73. warn("cancelled not implemented")
  74. return False
  75. def add_done_callback(self, callback):
  76. raise NotImplementedError()
  77. def __repr__(self):
  78. try:
  79. self._recv(timeout=0)
  80. except TimeoutError:
  81. pass
  82. if self.state == self.STATE_FINISHED:
  83. return "<MPFuture at 0x{:x} state=finished returned {}>".format(id(self), type(self._result))
  84. elif self.state == self.STATE_EXCEPTION:
  85. return "<MPFuture at 0x{:x} state=finished raised {}>".format(id(self), type(self._exception))
  86. else:
  87. return "<MPFuture at 0x{:x} state={}>".format(id(self), self.state)