__init__.py

import asyncio
import datetime
import multiprocessing as mp
import warnings
from typing import Tuple, List, Optional

from kademlia.network import Server

from tesseract.client import RemoteExpert
from tesseract.utils import run_forever, SharedFuture, PickleSerializer


class TesseractNetwork(mp.Process):
    """
    A DHT node that runs in a background process and lets the parent process
    declare experts and look them up by uid or by uid prefix.
    """
    UID_DELIMETER = '.'  # splits expert uids over this delimiter
    HEARTBEAT_EXPIRATION = 120  # an expert is considered inactive if it fails to post a timestamp for this many seconds
    make_key = "{}::{}".format
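    # e.g. make_key('expert', 'ffn.1.2') == 'expert::ffn.1.2'; make_key('prefix', 'ffn.1') == 'prefix::ffn.1'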

    def __init__(self, *initial_peers: Tuple[str, int], port=8081, start=False):
        super().__init__()
        self.port, self.initial_peers = port, initial_peers
        self._pipe, self.pipe = mp.Pipe(duplex=False)
        self.server = Server()
        if start:
            self.start()
    def run(self) -> None:
        """ Serve the DHT forever: listen for peers and dispatch requests received over the pipe """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.server.listen(self.port))
        loop.run_until_complete(self.server.bootstrap(self.initial_peers))
        run_forever(loop.run_forever)  # keep the event loop running while we serve pipe requests below

        while True:
            method, args, kwargs = self._pipe.recv()
            getattr(self, method)(*args, **kwargs)
    def shutdown(self) -> None:
        """ Shuts down the network process """
        if self.is_alive():
            self.terminate()
        else:
            warnings.warn("Network shutdown has no effect: network process is already not alive")
    def get_experts(self, uids: List[str], heartbeat_expiration=HEARTBEAT_EXPIRATION) -> List[Optional[RemoteExpert]]:
        """ Find experts across the DHT by their uids; return a list of [RemoteExpert if found else None] """
        future, _future = SharedFuture.make_pair()
        self.pipe.send(('_get_experts', [], dict(uids=uids, heartbeat_expiration=heartbeat_expiration, future=_future)))
        return future.result()
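
    # Usage sketch with hypothetical uids; assumes the process was started and the experts were declared:
    #   expert, missing = network.get_experts(['ffn.1.2', 'not.an.expert'])
    #   # expert is a RemoteExpert if 'ffn.1.2' posted a recent heartbeat, otherwise None; missing is None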

    def _get_experts(self, uids: List[str], heartbeat_expiration: float, future: SharedFuture):
        loop = asyncio.get_event_loop()
        lookup_futures = [asyncio.run_coroutine_threadsafe(
            self.server.get(self.make_key('expert', uid)), loop) for uid in uids]
        current_time = datetime.datetime.now()

        experts: List[Optional[RemoteExpert]] = [None] * len(uids)
        for i, (uid, lookup) in enumerate(zip(uids, lookup_futures)):
            if lookup.result() is not None:
                (host, port), timestamp = PickleSerializer.loads(lookup.result())
                if (current_time - timestamp).total_seconds() <= heartbeat_expiration:  # skip stale heartbeats
                    experts[i] = RemoteExpert(uid=uid, host=host, port=port)
        future.set_result(experts)
    def declare_experts(self, uids: List[str], addr, port, wait_timeout=0):
        """
        Make experts available through the DHT; update timestamps if they are already available
        :param uids: a list of expert uids to update
        :param addr: hostname that can be used to call these experts
        :param port: port that can be used to call these experts
        :param wait_timeout: if wait_timeout > 0, wait for up to this many seconds for the procedure to finish
        """
        done_event = mp.Event() if wait_timeout else None
        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), addr=addr, port=port, done_event=done_event)))
        if done_event is not None:
            done_event.wait(wait_timeout)
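
    # Usage sketch with hypothetical values: announce two experts served at 127.0.0.1:6789 and
    # wait up to 5 seconds for the DHT store operations to be dispatched:
    #   network.declare_experts(['ffn.1.2', 'ffn.3.4'], addr='127.0.0.1', port=6789, wait_timeout=5)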

    def _declare_experts(self, uids: List[str], addr: str, port: int, done_event: Optional[mp.Event]):
        loop = asyncio.get_event_loop()
        timestamp = datetime.datetime.now()
        expert_metadata = PickleSerializer.dumps(((addr, port), timestamp))
        prefix_metadata = PickleSerializer.dumps(timestamp)

        unique_prefixes = set()

        for uid in uids:
            asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('expert', uid), expert_metadata), loop)
            uid_parts = uid.split(self.UID_DELIMETER)
            # store every prefix of each uid, e.g. 'ffn.1.2' -> {'ffn', 'ffn.1', 'ffn.1.2'}
            unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])

        for prefix in unique_prefixes:
            asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('prefix', prefix), prefix_metadata), loop)

        if done_event is not None:
            done_event.set()
    def first_k_active(self, prefixes: List[str], k: int, heartbeat_expiration=HEARTBEAT_EXPIRATION, max_prefetch=None):
        """
        Find the first k prefixes that have at least one active expert; may return fewer if there aren't enough;
        used for DMoE beam search
        :param prefixes: a list of uid prefixes ordered from highest to lowest priority
        :param k: return at most this many active prefixes
        :param heartbeat_expiration: consider an expert active if its last heartbeat was sent at most this many seconds ago
        :param max_prefetch: pre-dispatch up to this many asynchronous DHT lookups; defaults to k
        :returns: a list of at most k prefixes that each have at least one active expert
        """
        future, _future = SharedFuture.make_pair()
        self.pipe.send(('_first_k_active', [], dict(prefixes=prefixes, k=k, heartbeat_expiration=heartbeat_expiration,
                                                    max_prefetch=max_prefetch or k, future=_future)))
        return future.result()
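
    # Usage sketch with hypothetical prefixes: pick the top two first-level blocks that still have live experts:
    #   top_prefixes = network.first_k_active(['ffn.0', 'ffn.1', 'ffn.2', 'ffn.3'], k=2)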

    def _first_k_active(self, prefixes: List[str], k, heartbeat_expiration, max_prefetch, future: SharedFuture):
        loop = asyncio.get_event_loop()
        lookup_prefetch = [asyncio.run_coroutine_threadsafe(
            self.server.get(self.make_key('prefix', prefix)), loop) for prefix in prefixes[:max_prefetch]]
        current_time = datetime.datetime.now()

        active_prefixes = []

        for i, prefix in enumerate(prefixes):
            lookup = lookup_prefetch[i]

            if lookup.result() is not None:
                timestamp = PickleSerializer.loads(lookup.result())
                if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
                    active_prefixes.append(prefix)
                    if len(active_prefixes) >= k:
                        future.set_result(active_prefixes)
                        return

            # pre-dispatch the next request in line
            if len(lookup_prefetch) < len(prefixes):
                lookup_prefetch.append(asyncio.run_coroutine_threadsafe(
                    self.server.get(self.make_key('prefix', prefixes[len(lookup_prefetch)])), loop))

        # could not find enough active prefixes; return what we can
        future.set_result(active_prefixes)
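

if __name__ == '__main__':
    # Minimal single-node usage sketch; the hostname, ports and expert uids below are illustrative.
    network = TesseractNetwork(port=8081, start=True)
    network.declare_experts(['ffn.1.2', 'ffn.3.4'], addr='127.0.0.1', port=6789, wait_timeout=5)
    print(network.get_experts(['ffn.1.2', 'ffn.9.9']))  # [RemoteExpert(...), None]
    print(network.first_k_active(['ffn.1', 'ffn.9'], k=1))  # ['ffn.1']
    network.shutdown()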