__init__.py
import asyncio
import datetime
import multiprocessing as mp
import warnings
from typing import Tuple, List, Optional

from kademlia.network import Server

from tesseract.client import RemoteExpert
from tesseract.utils import run_forever, SharedFuture, PickleSerializer


class TesseractNetwork(mp.Process):
    UID_DELIMETER = '.'  # splits expert uids over this delimiter
    HEARTBEAT_EXPIRATION = 120  # an expert is considered inactive if it fails to post a heartbeat for *this many seconds*
    make_key = "{}::{}".format

    def __init__(self, *initial_peers: Tuple[str, int], port=8081, start=False, daemon=True):
        super().__init__()
        self.port, self.initial_peers = port, initial_peers
        self._pipe, self.pipe = mp.Pipe(duplex=False)
        self.ready = mp.Event()
        self.server = Server()
        self.daemon = daemon
        if start:
            self.run_in_background(await_ready=True)

    def run(self) -> None:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.server.listen(self.port))
        loop.run_until_complete(self.server.bootstrap(self.initial_peers))
        run_forever(loop.run_forever)  # keep the kademlia event loop running in a background thread
        self.ready.set()

        # serve requests from the parent process: each message is a (method_name, args, kwargs) tuple
        while True:
            method, args, kwargs = self._pipe.recv()
            getattr(self, method)(*args, **kwargs)
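
    # Each public method below is a thin proxy over the loop above: it sends a
    # ('_private_method', args, kwargs) tuple into self.pipe, and the network process
    # invokes the matching private method. A rough sketch of the round trip
    # (kwargs abbreviated for illustration):
    #
    #   future, _future = SharedFuture.make_pair()
    #   self.pipe.send(('_get_experts', [], dict(uids=uids, future=_future)))
    #   result = future.result()  # blocks until _get_experts calls future.set_result(...)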

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts TesseractNetwork in a background process. If await_ready, this method will wait until the background
        network is ready to process incoming requests, or for at most :timeout: seconds.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"TesseractNetwork didn't notify .ready in {timeout} seconds")

    def shutdown(self) -> None:
        """ Shuts down the network process """
        if self.is_alive():
            self.kill()
        else:
            warnings.warn("Network shutdown has no effect: the network process is not alive")

    def get_experts(self, uids: List[str], heartbeat_expiration=HEARTBEAT_EXPIRATION) -> List[Optional[RemoteExpert]]:
        """ Find experts across DHT using their ids; return a list of [RemoteExpert if found else None] """
        future, _future = SharedFuture.make_pair()
        self.pipe.send(('_get_experts', [], dict(uids=uids, heartbeat_expiration=heartbeat_expiration, future=_future)))
        return future.result()
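    # Hypothetical usage (the expert uids below are assumptions, not part of this module):
    #   network.get_experts(['ffn.0.3', 'ffn.1.7'])  # -> e.g. [RemoteExpert(...), None]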

    def _get_experts(self, uids: List[str], heartbeat_expiration: float, future: SharedFuture):
        loop = asyncio.get_event_loop()
        # dispatch all DHT lookups at once to the kademlia event loop running in the background thread
        lookup_futures = [asyncio.run_coroutine_threadsafe(
            self.server.get(self.make_key('expert', uid)), loop) for uid in uids]
        current_time = datetime.datetime.now()

        experts = [None] * len(uids)
        for i, (uid, lookup) in enumerate(zip(uids, lookup_futures)):
            if lookup.result() is not None:
                (host, port), timestamp = PickleSerializer.loads(lookup.result())
                # only count the expert as found if its heartbeat is recent enough
                if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
                    experts[i] = RemoteExpert(uid=uid, host=host, port=port)
        future.set_result(experts)

    def declare_experts(self, uids: List[str], addr, port, wait_timeout=0):
        """
        Make experts available to DHT; update timestamps if already available
        :param uids: a list of expert ids to update
        :param addr: hostname that can be used to call this expert
        :param port: port that can be used to call this expert
        :param wait_timeout: if wait_timeout > 0, waits for the procedure to finish
        """
        done_event = mp.Event() if wait_timeout else None
        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), addr=addr, port=port, done_event=done_event)))
        if done_event is not None:
            done_event.wait(wait_timeout)

    def _declare_experts(self, uids: List[str], addr: str, port: int, done_event: Optional[mp.Event]):
        loop = asyncio.get_event_loop()
        timestamp = datetime.datetime.now()
        expert_metadata = PickleSerializer.dumps(((addr, port), timestamp))
        prefix_metadata = PickleSerializer.dumps(timestamp)

        unique_prefixes = set()
        for uid in uids:
            asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('expert', uid), expert_metadata), loop)
            uid_parts = uid.split(self.UID_DELIMETER)
            unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])

        # refresh the heartbeat of every prefix of every declared uid, so that first_k_active can query prefixes
        for prefix in unique_prefixes:
            asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('prefix', prefix), prefix_metadata), loop)

        if done_event is not None:
            done_event.set()
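
    # For illustration (the uid format is an assumption): declaring 'ffn.1.3' stores fresh
    # metadata under the expert key 'expert::ffn.1.3' and refreshes the prefix keys
    # 'prefix::ffn', 'prefix::ffn.1', and 'prefix::ffn.1.3'.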

    def first_k_active(self, prefixes: List[str], k: int, heartbeat_expiration=HEARTBEAT_EXPIRATION, max_prefetch=None):
        """
        Find the first k prefixes with active experts; may return fewer if there aren't enough; used for DMoE beam search
        :param prefixes: a list of uid prefixes ordered from highest to lowest priority
        :param k: return at most *this many* active prefixes
        :param heartbeat_expiration: consider an expert active if its last heartbeat was sent at most this many seconds ago
        :param max_prefetch: pre-dispatch up to *this many* asynchronous expert requests; defaults to pre-dispatching k
        :returns: a list of at most :k: prefixes that have at least one active expert each
        """
        future, _future = SharedFuture.make_pair()
        self.pipe.send(('_first_k_active', [], dict(prefixes=prefixes, k=k, heartbeat_expiration=heartbeat_expiration,
                                                    max_prefetch=max_prefetch or k, future=_future)))
        return future.result()
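    # Hypothetical beam-search step (prefix names and output are assumptions for illustration):
    #   network.first_k_active(['ffn.2', 'ffn.0', 'ffn.7', 'ffn.1'], k=2)  # -> e.g. ['ffn.2', 'ffn.7']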

    def _first_k_active(self, prefixes: List[str], k, heartbeat_expiration, max_prefetch, future: SharedFuture):
        loop = asyncio.get_event_loop()
        lookup_prefetch = [asyncio.run_coroutine_threadsafe(
            self.server.get(self.make_key('prefix', prefix)), loop) for prefix in prefixes[:max_prefetch]]
        current_time = datetime.datetime.now()

        active_prefixes = []
        for i, prefix in enumerate(prefixes):
            lookup = lookup_prefetch[i]
            if lookup.result() is not None:
                timestamp = PickleSerializer.loads(lookup.result())
                if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
                    active_prefixes.append(prefix)
                    if len(active_prefixes) >= k:
                        future.set_result(active_prefixes)
                        return

            # pre-dispatch the next request in line
            if len(lookup_prefetch) < len(prefixes):
                lookup_prefetch.append(
                    asyncio.run_coroutine_threadsafe(self.server.get(
                        self.make_key('prefix', prefixes[len(lookup_prefetch)])), loop)
                )

        # could not find enough active prefixes; return what we can
        future.set_result(active_prefixes)
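

# A minimal end-to-end sketch, not part of the original module. It assumes a DHT peer
# reachable at localhost:8081 to bootstrap from and an expert server at localhost:6000;
# the expert uid 'ffn.0.0' is made up for illustration.
if __name__ == '__main__':
    network = TesseractNetwork(('localhost', 8081), port=8082, start=True)
    network.declare_experts(['ffn.0.0'], addr='localhost', port=6000, wait_timeout=5)
    [expert] = network.get_experts(['ffn.0.0'])
    print(expert)  # RemoteExpert if the heartbeat is fresh, None otherwise
    print(network.first_k_active(['ffn.0', 'ffn.1'], k=1))  # ['ffn.0'] if any ffn.0.* expert is alive
    network.shutdown()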