# tesseract network package (__init__.py)
import asyncio
import datetime
import multiprocessing as mp
import warnings
from typing import Tuple, List, Optional

from kademlia.network import Server

from tesseract.client import RemoteExpert
from tesseract.utils import run_in_background, repeated, SharedFuture, PickleSerializer
  9. class TesseractNetwork(mp.Process):
  10. UID_DELIMETER = '.' # splits expert uids over this delimeter
  11. HEARTBEAT_EXPIRATION = 120 # expert is inactive iff it fails to post timestamp for *this many seconds*
  12. make_key = "{}::{}".format
  13. def __init__(self, *initial_peers: Tuple[str, int], port=8081, start=False):
  14. super().__init__()
  15. self.port, self.initial_peers = port, initial_peers
  16. self._pipe, self.pipe = mp.Pipe(duplex=False)
  17. self.server = Server()
  18. if start:
  19. self.start()
  20. def run(self) -> None:
  21. loop = asyncio.new_event_loop()
  22. asyncio.set_event_loop(loop)
  23. loop.run_until_complete(self.server.listen(self.port))
  24. loop.run_until_complete(self.server.bootstrap(self.initial_peers))
  25. run_in_background(repeated(loop.run_forever))
  26. while True:
  27. method, args, kwargs = self._pipe.recv()
  28. getattr(self, method)(*args, **kwargs)
  29. def shutdown(self) -> None:
  30. """ Shuts down the network process """
  31. warnings.warn("TODO shutdown network gracefully")
  32. self.terminate()
  33. def get_experts(self, uids: List[str], heartbeat_expiration=HEARTBEAT_EXPIRATION) -> List[Optional[RemoteExpert]]:
  34. """ Find experts across DHT using their ids; Return a list of [RemoteExpert if found else None]"""
  35. future, _future = SharedFuture.make_pair()
  36. self.pipe.send(('_get_experts', [], dict(uids=uids, heartbeat_expiration=heartbeat_expiration, future=_future)))
  37. return future.result()
  38. def _get_experts(self, uids: List[str], heartbeat_expiration: float, future: SharedFuture):
  39. loop = asyncio.get_event_loop()
  40. lookup_futures = [asyncio.run_coroutine_threadsafe(
  41. self.server.get(self.make_key('expert', uid)), loop) for uid in uids]
  42. current_time = datetime.datetime.now()
  43. experts = [None] * len(uids)
  44. for i, (uid, lookup) in enumerate(zip(uids, lookup_futures)):
  45. if lookup.result() is not None:
  46. (host, port), timestamp = PickleSerializer.loads(lookup.result())
  47. if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
  48. experts[i] = RemoteExpert(uid=uid, host=host, port=port)
  49. future.set_result(experts)
  50. def declare_experts(self, uids: List[str], addr, port, wait_timeout=0):
  51. """
  52. Make experts available to DHT; update timestamps if already available
  53. :param uids: a list of expert ids to update
  54. :param addr: hostname that can be used to call this expert
  55. :param port: port that can be used to call this expert
  56. :param wait_timeout: if wait_timeout > 0, waits for the procedure to finish
  57. """
  58. done_event = mp.Event() if wait_timeout else None
  59. self.pipe.send(('_declare_experts', [], dict(uids=uids, addr=addr, port=port, done_event=done_event)))
  60. if done_event is not None:
  61. done_event.wait(wait_timeout)
  62. def _declare_experts(self, uids: List[str], addr: str, port: int, done_event: Optional[mp.Event]):
  63. loop = asyncio.get_event_loop()
  64. timestamp = datetime.datetime.now()
  65. expert_metadata = PickleSerializer.dumps(((addr, port), timestamp))
  66. prefix_metadata = PickleSerializer.dumps(timestamp)
  67. unique_prefixes = set()
  68. for uid in uids:
  69. asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('expert', uid), expert_metadata), loop)
  70. uid_parts = uid.split(self.UID_DELIMETER)
  71. unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])
  72. for prefix in unique_prefixes:
  73. asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('prefix', prefix), prefix_metadata), loop)
  74. if done_event is not None:
  75. done_event.set()
  76. def first_k_active(self, prefixes: List[str], k: int, heartbeat_expiration=HEARTBEAT_EXPIRATION, max_prefetch=None):
  77. """
  78. Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search
  79. :param prefixes: a list of uid prefixes ordered from highest to lowest priority
  80. :param k: return at most *this many* active prefixes
  81. :param heartbeat_expiration: consider expert active if his last heartbeat was sent at most this many seconds ago
  82. :param max_prefetch: pre-dispatch up to *this many* asynchronous expert requests, defaults to pre-dispatch = k
  83. :returns: a list of at most :k: prefixes that have at least one active expert each;
  84. """
  85. future, _future = SharedFuture.make_pair()
  86. self.pipe.send(('_first_k_active', [], dict(prefixes=prefixes, k=k, heartbeat_expiration=heartbeat_expiration,
  87. max_prefetch=max_prefetch or k, future=_future)))
  88. return future.result()
  89. def _first_k_active(self, prefixes: List[str], k, heartbeat_expiration, max_prefetch, future: SharedFuture):
  90. loop = asyncio.get_event_loop()
  91. lookup_prefetch = [asyncio.run_coroutine_threadsafe(
  92. self.server.get(self.make_key('prefix', prefix)), loop) for prefix in prefixes[:max_prefetch]]
  93. current_time = datetime.datetime.now()
  94. active_prefixes = []
  95. for i, prefix in enumerate(prefixes):
  96. lookup = lookup_prefetch[i]
  97. if lookup.result() is not None:
  98. timestamp = PickleSerializer.loads(lookup.result())
  99. if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
  100. active_prefixes.append(prefix)
  101. if len(active_prefixes) >= k:
  102. future.set_result(active_prefixes)
  103. return
  104. # pre-dispatch the next request in line
  105. if len(lookup_prefetch) < len(prefixes):
  106. lookup_prefetch.append(
  107. asyncio.run_coroutine_threadsafe(self.server.get(
  108. self.make_key('prefix', prefixes[len(lookup_prefetch)])), loop)
  109. )
  110. # could not find enough active prefixes; return what we can
  111. future.set_result(active_prefixes)