import asyncio
import datetime
import multiprocessing as mp
from typing import Tuple, List, Optional

from kademlia.network import Server

from tesseract.client import RemoteExpert
from tesseract.utils import run_in_background, repeated, SharedFuture, PickleSerializer
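

# TesseractNetwork runs a Kademlia DHT node inside a dedicated background process.
# Callers never touch the event loop directly: each public method sends a
# ('_method_name', args, kwargs) tuple over a one-way pipe, and the DHT process
# executes the matching private method on its own asyncio loop (see run() below),
# returning results through a SharedFuture where needed.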
class TesseractNetwork(mp.Process):
    UID_DELIMITER = '.'  # splits expert uids over this delimiter
    HEARTBEAT_EXPIRATION = 120  # an expert is considered inactive iff it failed to post a timestamp for *this many seconds*
    make_key = "{}::{}".format
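    # make_key builds namespaced DHT keys, e.g. make_key('expert', 'ffn.0.1') -> 'expert::ffn.0.1'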

    def __init__(self, *initial_peers: Tuple[str, int], port=8081, start=False):
        super().__init__()
        self.port, self.initial_peers = port, initial_peers
        self._pipe, self.pipe = mp.Pipe(duplex=False)  # self._pipe is the DHT process's receiving end, self.pipe is the caller's sending end
        self.server = Server()
        if start:
            self.start()

    def run(self) -> None:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.server.listen(self.port))
        loop.run_until_complete(self.server.bootstrap(self.initial_peers))
        run_in_background(repeated(loop.run_forever))

        while True:  # serve requests from the parent process until terminated
            method, args, kwargs = self._pipe.recv()
            getattr(self, method)(*args, **kwargs)

    def get_experts(self, uids: List[str], heartbeat_expiration=HEARTBEAT_EXPIRATION) -> List[Optional[RemoteExpert]]:
        """ Find experts across the DHT using their uids; return a list of [RemoteExpert if found else None] """
        future, _future = SharedFuture.make_pair()
        self.pipe.send(('_get_experts', [], dict(uids=uids, heartbeat_expiration=heartbeat_expiration, future=_future)))
        return future.result()

    def _get_experts(self, uids: List[str], heartbeat_expiration: float, future: SharedFuture):
        loop = asyncio.get_event_loop()
        lookup_futures = [asyncio.run_coroutine_threadsafe(
            self.server.get(self.make_key('expert', uid)), loop) for uid in uids]
        current_time = datetime.datetime.now()

        experts = [None] * len(uids)
        for i, (uid, lookup) in enumerate(zip(uids, lookup_futures)):
            if lookup.result() is not None:
                (host, port), timestamp = PickleSerializer.loads(lookup.result())
                if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
                    experts[i] = RemoteExpert(uid=uid, host=host, port=port)
        future.set_result(experts)
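
    # DHT key layout: each expert uid is stored under 'expert::<uid>' with value
    # ((addr, port), timestamp); every dot-separated prefix of the uid is stored
    # under 'prefix::<prefix>' with a bare timestamp, which lets first_k_active()
    # detect live experts without enumerating them.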

    def declare_experts(self, uids: List[str], addr, port, wait_timeout=0):
        """
        Make experts available to the DHT; update timestamps if they are already available
        :param uids: a list of expert uids to update
        :param addr: hostname that can be used to call these experts
        :param port: port that can be used to call these experts
        :param wait_timeout: if wait_timeout > 0, waits for the procedure to finish for at most this many seconds
        """
        done_event = mp.Event() if wait_timeout else None
        self.pipe.send(('_declare_experts', [], dict(uids=uids, addr=addr, port=port, done_event=done_event)))
        if done_event is not None:
            done_event.wait(wait_timeout)

    def _declare_experts(self, uids: List[str], addr: str, port: int, done_event: Optional[mp.Event]):
        loop = asyncio.get_event_loop()
        timestamp = datetime.datetime.now()
        expert_metadata = PickleSerializer.dumps(((addr, port), timestamp))
        prefix_metadata = PickleSerializer.dumps(timestamp)
        store_ops, unique_prefixes = [], set()

        for uid in uids:
            store_ops.append(asyncio.run_coroutine_threadsafe(
                self.server.set(self.make_key('expert', uid), expert_metadata), loop))
            uid_parts = uid.split(self.UID_DELIMITER)
            unique_prefixes.update([self.UID_DELIMITER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])

        for prefix in unique_prefixes:
            store_ops.append(asyncio.run_coroutine_threadsafe(
                self.server.set(self.make_key('prefix', prefix), prefix_metadata), loop))

        if done_event is not None:
            for op in store_ops:
                op.result()  # block until each DHT store has finished before signalling completion
            done_event.set()
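
    # Example: declaring 'ffn.0.3' stores the expert record itself and refreshes
    # the prefix records 'ffn', 'ffn.0' and 'ffn.0.3', so a search over prefixes
    # can prune inactive branches without looking up individual experts.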

    def first_k_active(self, prefixes: List[str], k: int, heartbeat_expiration=HEARTBEAT_EXPIRATION, max_prefetch=None):
        """
        Find the first k prefixes with active experts; may return fewer if there aren't enough; used for DMoE beam search
        :param prefixes: a list of uid prefixes ordered from highest to lowest priority
        :param k: return at most *this many* active prefixes
        :param heartbeat_expiration: consider an expert active if its last heartbeat was sent at most this many seconds ago
        :param max_prefetch: pre-dispatch up to *this many* asynchronous prefix lookups; defaults to k
        :returns: a list of at most k prefixes that have at least one active expert each
        """
        future, _future = SharedFuture.make_pair()
        self.pipe.send(('_first_k_active', [], dict(prefixes=prefixes, k=k, heartbeat_expiration=heartbeat_expiration,
                                                    max_prefetch=max_prefetch or k, future=_future)))
        return future.result()
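
    # Lookup pipelining: the first max_prefetch prefix lookups are dispatched up
    # front, and each loop iteration below dispatches one more, so network
    # requests overlap with result processing instead of running one at a time.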
    def _first_k_active(self, prefixes: List[str], k, heartbeat_expiration, max_prefetch, future: SharedFuture):
        loop = asyncio.get_event_loop()
        lookup_prefetch = [asyncio.run_coroutine_threadsafe(
            self.server.get(self.make_key('prefix', prefix)), loop) for prefix in prefixes[:max_prefetch]]
        current_time = datetime.datetime.now()

        active_prefixes = []
        for i, prefix in enumerate(prefixes):
            lookup = lookup_prefetch[i]
            if lookup.result() is not None:
                timestamp = PickleSerializer.loads(lookup.result())
                if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
                    active_prefixes.append(prefix)
                    if len(active_prefixes) >= k:
                        future.set_result(active_prefixes)
                        return

            # pre-dispatch the next request in line
            if len(lookup_prefetch) < len(prefixes):
                lookup_prefetch.append(asyncio.run_coroutine_threadsafe(
                    self.server.get(self.make_key('prefix', prefixes[len(lookup_prefetch)])), loop))

        # could not find enough active prefixes; return what we have
        future.set_result(active_prefixes)
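

if __name__ == '__main__':
    import time

    # Minimal usage sketch; the hosts, ports and expert uids below are
    # illustrative. kademlia nodes need at least one known neighbor to store
    # values, so we start two nodes and bootstrap the second one off the first.
    first = TesseractNetwork(port=8081, start=True)
    time.sleep(1)  # give the first node a moment to start listening
    second = TesseractNetwork(('127.0.0.1', 8081), port=8082, start=True)

    second.declare_experts(['ffn.0.0', 'ffn.0.1'], addr='127.0.0.1', port=9000)
    time.sleep(1)  # declare_experts returns immediately unless wait_timeout is set

    print(second.get_experts(['ffn.0.0', 'ffn.9.9']))      # e.g. [RemoteExpert(...), None]
    print(second.first_k_active(['ffn.0', 'ffn.9'], k=1))  # e.g. ['ffn.0']
    first.terminate()
    second.terminate()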