import asyncio
import datetime
import multiprocessing as mp
import warnings
from typing import Tuple, List, Optional

from kademlia.network import Server

from tesseract.client import RemoteExpert
from tesseract.utils import run_forever, SharedFuture, PickleSerializer


class TesseractNetwork(mp.Process):
    UID_DELIMETER = '.'  # splits expert uids over this delimiter
    HEARTBEAT_EXPIRATION = 120  # expert is inactive iff it fails to post a timestamp for *this many seconds*
    make_key = "{}::{}".format  # e.g. make_key('expert', 'ffn.0.3') -> 'expert::ffn.0.3'

    def __init__(self, *initial_peers: Tuple[str, int], port=8081, start=False, daemon=True):
        super().__init__()
        self.port, self.initial_peers = port, initial_peers
        self._pipe, self.pipe = mp.Pipe(duplex=False)  # receiving and sending ends for cross-process calls
        self.ready = mp.Event()
        self.server = Server()
        self.daemon = daemon
        if start:
            self.run_in_background(await_ready=True)

    def run(self) -> None:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.server.listen(self.port))
        loop.run_until_complete(self.server.bootstrap(self.initial_peers))
        run_forever(loop.run_forever)  # run the event loop in the background so we can serve pipe commands below
        self.ready.set()

        while True:
            try:
                method, args, kwargs = self._pipe.recv()
            except EOFError:
                break  # the parent process closed the pipe; stop serving requests
            getattr(self, method)(*args, **kwargs)

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts TesseractNetwork in a background process. If await_ready, this method will wait until
        the background network is ready to process incoming requests, or for at most :timeout: seconds.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"TesseractNetwork didn't notify .ready in {timeout} seconds")

    def shutdown(self) -> None:
        """ Shuts down the network process """
        if self.is_alive():
            self.kill()
        else:
            warnings.warn("Network shutdown has no effect: the network process is not alive")
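
    # Lifecycle sketch (illustrative, not part of the API: assumes another DHT peer is
    # already listening on localhost:8081; the address and ports are placeholders):
    #   network = TesseractNetwork(('localhost', 8081), port=8090, start=True)
    #   ...  # declare and/or look up experts here
    #   network.shutdown()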

    def get_experts(self, uids: List[str], heartbeat_expiration=HEARTBEAT_EXPIRATION) -> List[Optional[RemoteExpert]]:
        """ Find experts across the DHT using their uids; return a list of [RemoteExpert if found else None] """
        future, _future = SharedFuture.make_pair()
        self.pipe.send(('_get_experts', [], dict(uids=uids, heartbeat_expiration=heartbeat_expiration, future=_future)))
        return future.result()
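
    # Example (hypothetical uids, using the `network` instance from the sketch above;
    # an entry is None if the uid is missing from the DHT or its heartbeat is stale):
    #   experts = network.get_experts(['ffn.0.3', 'ffn.2.1'])
    #   found = [expert for expert in experts if expert is not None]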

    def _get_experts(self, uids: List[str], heartbeat_expiration: float, future: SharedFuture):
        loop = asyncio.get_event_loop()
        # dispatch all DHT lookups at once, then await them one by one
        lookup_futures = [asyncio.run_coroutine_threadsafe(
            self.server.get(self.make_key('expert', uid)), loop) for uid in uids]
        current_time = datetime.datetime.now()

        experts = [None] * len(uids)
        for i, (uid, lookup) in enumerate(zip(uids, lookup_futures)):
            if lookup.result() is not None:
                (host, port), timestamp = PickleSerializer.loads(lookup.result())
                if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
                    experts[i] = RemoteExpert(uid=uid, host=host, port=port)
        future.set_result(experts)

    def declare_experts(self, uids: List[str], addr, port, wait_timeout=0):
        """
        Make experts available to the DHT; update timestamps if they are already declared
        :param uids: a list of expert uids to update
        :param addr: hostname that can be used to call these experts
        :param port: port that can be used to call these experts
        :param wait_timeout: if wait_timeout > 0, waits for the procedure to finish for at most this many seconds
        """
        done_event = mp.Event() if wait_timeout else None
        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), addr=addr, port=port, done_event=done_event)))
        if done_event is not None:
            done_event.wait(wait_timeout)
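
    # Example (placeholder address and port; a server would typically re-declare its
    # experts periodically so their heartbeats stay within HEARTBEAT_EXPIRATION):
    #   network.declare_experts(['ffn.0.3', 'ffn.2.1'], addr='127.0.0.1', port=6000, wait_timeout=5)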

    def _declare_experts(self, uids: List[str], addr: str, port: int, done_event: Optional[mp.Event]):
        loop = asyncio.get_event_loop()
        timestamp = datetime.datetime.now()
        expert_metadata = PickleSerializer.dumps(((addr, port), timestamp))
        prefix_metadata = PickleSerializer.dumps(timestamp)

        unique_prefixes = set()
        for uid in uids:
            asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('expert', uid), expert_metadata), loop)
            uid_parts = uid.split(self.UID_DELIMETER)
            unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])

        # also store a timestamp for every uid prefix, so that first_k_active can find them
        for prefix in unique_prefixes:
            asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('prefix', prefix), prefix_metadata), loop)

        if done_event is not None:
            done_event.set()

    def first_k_active(self, prefixes: List[str], k: int, heartbeat_expiration=HEARTBEAT_EXPIRATION, max_prefetch=None):
        """
        Find the first k prefixes that have at least one active expert; may return fewer if there aren't enough;
        used for DMoE beam search
        :param prefixes: a list of uid prefixes ordered from highest to lowest priority
        :param k: return at most *this many* active prefixes
        :param heartbeat_expiration: consider an expert active if its last heartbeat was sent at most this many seconds ago
        :param max_prefetch: pre-dispatch up to *this many* asynchronous lookups; defaults to pre-dispatching k of them
        :returns: a list of at most :k: prefixes that have at least one active expert each
        """
        future, _future = SharedFuture.make_pair()
        self.pipe.send(('_first_k_active', [], dict(prefixes=prefixes, k=k, heartbeat_expiration=heartbeat_expiration,
                                                    max_prefetch=max_prefetch or k, future=_future)))
        return future.result()
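
    # Example (illustrative prefixes, ordered from highest to lowest priority; with k=2
    # this returns the first two prefixes that currently have a fresh heartbeat):
    #   active = network.first_k_active(['ffn.0', 'ffn.1', 'ffn.2', 'ffn.3'], k=2)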

    def _first_k_active(self, prefixes: List[str], k, heartbeat_expiration, max_prefetch, future: SharedFuture):
        loop = asyncio.get_event_loop()
        lookup_prefetch = [asyncio.run_coroutine_threadsafe(
            self.server.get(self.make_key('prefix', prefix)), loop) for prefix in prefixes[:max_prefetch]]
        current_time = datetime.datetime.now()

        active_prefixes = []
        for i, prefix in enumerate(prefixes):
            lookup = lookup_prefetch[i]
            if lookup.result() is not None:
                timestamp = PickleSerializer.loads(lookup.result())
                if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
                    active_prefixes.append(prefix)
                    if len(active_prefixes) >= k:
                        future.set_result(active_prefixes)
                        return

            # pre-dispatch the next request in line
            if len(lookup_prefetch) < len(prefixes):
                lookup_prefetch.append(
                    asyncio.run_coroutine_threadsafe(self.server.get(
                        self.make_key('prefix', prefixes[len(lookup_prefetch)])), loop))

        # could not find enough active prefixes; return what we can
        future.set_result(active_prefixes)
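

# Minimal end-to-end sketch, assuming ports 8081/8082 are free on localhost;
# the uids, addresses, and ports below are illustrative, not part of the API.
if __name__ == '__main__':
    bootstrap = TesseractNetwork(port=8081, start=True)  # first DHT node, no initial peers
    node = TesseractNetwork(('localhost', 8081), port=8082, start=True)  # joins via the first node

    node.declare_experts(['ffn.0.0'], addr='localhost', port=6000, wait_timeout=5)
    alive, missing = bootstrap.get_experts(['ffn.0.0', 'ffn.9.9'])
    assert alive is not None and missing is None  # 'ffn.9.9' was never declared

    node.shutdown()
    bootstrap.shutdown()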