@@ -16,12 +16,11 @@ from __future__ import annotations
 import asyncio
 import ctypes
 import multiprocessing as mp
+import os
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-from typing import Iterable, List, Optional, Sequence, Union, Callable, Awaitable, TypeVar
+from typing import Iterable, Optional, Sequence, Union, Callable, Awaitable, TypeVar
 
-import hivemind
-from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import DHTValue, DHTKey, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
@@ -46,19 +45,22 @@ class DHT(mp.Process):
     :param max_workers: declare_experts and get_experts will use up to this many parallel workers
       (but no more than one per key)
     :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
+    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
     :param kwargs: any other params will be forwarded to DHTNode upon creation
     """
+    _node: DHTNode
 
     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
                  daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
-                 record_validators: Iterable[RecordValidatorBase] = (), **kwargs):
+                 record_validators: Iterable[RecordValidatorBase] = (), shutdown_timeout: float = 3, **kwargs):
         super().__init__()
         assert not isinstance(initial_peers, str), "please specify a list/tuple of initial peers (even if there's one)"
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
         self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
         self._record_validator = CompositeValidator(record_validators)
         self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
-        self._pipe, self.pipe = mp.Pipe(duplex=True)
+        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
+        self.shutdown_timeout = shutdown_timeout
         self.ready = mp.Event()
         self.daemon = daemon
         if start:
@@ -70,17 +72,20 @@ class DHT(mp.Process):
 
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
             async def _run():
-                node = await DHTNode.create(
+                self._node = await DHTNode.create(
                     initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
                     num_workers=self.max_workers or 1, record_validator=self._record_validator,
                     **self.kwargs)
-                if node.port is not None:
-                    self._port.value = node.port
+                if self._node.port is not None:
+                    self._port.value = self._node.port
                 self.ready.set()
 
                 while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
-                    asyncio.create_task(getattr(self, method)(node, *args, **kwargs))
+                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                    if method == '_shutdown':
+                        await task
+                        break
 
             coro = _run()
             loop.run_until_complete(coro)
@@ -97,10 +102,17 @@ class DHT(mp.Process):
     def shutdown(self) -> None:
         """ Shut down a running dht process """
         if self.is_alive():
-            self.terminate()
+            self._outer_pipe.send(('_shutdown', [], {}))
+            self.join(self.shutdown_timeout)
+            if self.is_alive():
+                logger.warning("DHT did not shut down within the grace period; terminating it the hard way.")
+                self.terminate()
         else:
             logger.warning("DHT shutdown has no effect: dht process is already not alive")
 
+    async def _shutdown(self):
+        await self._node.shutdown()
+
     @property
     def port(self) -> Optional[int]:
         return self._port.value if self._port.value != 0 else None
@@ -116,12 +128,12 @@ class DHT(mp.Process):
         :returns: (value, expiration time); if value was not found, returns None
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get', [], dict(key=key, latest=latest, future=_future, **kwargs)))
+        self._outer_pipe.send(('_get', [], dict(key=key, latest=latest, future=_future, **kwargs)))
         return future if return_future else future.result()
 
-    async def _get(self, node: DHTNode, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
+    async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
         try:
-            result = await node.get(key, latest=latest, **kwargs)
+            result = await self._node.get(key, latest=latest, **kwargs)
             if not future.done():
                 future.set_result(result)
         except BaseException as e:
@@ -142,14 +154,14 @@ class DHT(mp.Process):
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_store', [], dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey,
-                                           future=_future, **kwargs)))
+        self._outer_pipe.send(('_store', [], dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey,
+                                                  future=_future, **kwargs)))
         return future if return_future else future.result()
 
-    async def _store(self, node: DHTNode, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
+    async def _store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
                      subkey: Optional[Subkey], future: MPFuture, **kwargs):
         try:
-            result = await node.store(key, value, expiration_time, subkey=subkey, **kwargs)
+            result = await self._node.store(key, value, expiration_time, subkey=subkey, **kwargs)
             if not future.done():
                 future.set_result(result)
         except BaseException as e:
@@ -173,12 +185,12 @@ class DHT(mp.Process):
         :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_run_coroutine', [], dict(coro=coro, future=_future)))
+        self._outer_pipe.send(('_run_coroutine', [], dict(coro=coro, future=_future)))
         return future if return_future else future.result()
 
-    async def _run_coroutine(self, node: DHTNode, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]],
+    async def _run_coroutine(self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]],
                              future: MPFuture[ReturnType]):
-        main_task = asyncio.create_task(coro(self, node))
+        main_task = asyncio.create_task(coro(self, self._node))
         cancel_task = asyncio.create_task(await_cancelled(future))
         try:
             await asyncio.wait({main_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
@@ -215,21 +227,21 @@ class DHT(mp.Process):
         assert num_peers is None or peers == (), "please specify either a num_peers or the list of peers, not both"
         assert not isinstance(peers, str) and isinstance(peers, Sequence), "Please send a list / tuple of endpoints"
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=_future)))
+        self._outer_pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=_future)))
        return future.result()
 
-    async def _get_visible_address(self, node: DHTNode, num_peers: Optional[int], peers: Sequence[Endpoint],
+    async def _get_visible_address(self, num_peers: Optional[int], peers: Sequence[Endpoint],
                                    future: Optional[MPFuture]):
-        if not peers and (num_peers or not node.protocol.node_info.endpoint):
+        if not peers and (num_peers or not self._node.protocol.node_info.endpoint):
             # if we can't resolve the endpoint locally, ask one random peer
-            peers_and_endpoints = node.protocol.routing_table.get_nearest_neighbors(
-                DHTID.generate(), num_peers or 1, exclude=node.node_id)
+            peers_and_endpoints = self._node.protocol.routing_table.get_nearest_neighbors(
+                DHTID.generate(), num_peers or 1, exclude=self._node.node_id)
             peers = tuple(endpoint for node_id, endpoint in peers_and_endpoints)
 
         chosen_address = None
         if peers:
             possible_endpoints: Sequence[Optional[Endpoint]] = await asyncio.gather(*(
-                node.protocol.get_outgoing_request_endpoint(peer) for peer in peers))
+                self._node.protocol.get_outgoing_request_endpoint(peer) for peer in peers))
 
             for endpoint in possible_endpoints:
                 if endpoint is None:
@@ -244,8 +256,8 @@ class DHT(mp.Process):
         if chosen_address is None:
             logger.warning(f"None of the selected peers responded with an address ({peers})")
 
-        if node.protocol.node_info.endpoint:
-            address = strip_port(node.protocol.node_info.endpoint)
+        if self._node.protocol.node_info.endpoint:
+            address = strip_port(self._node.protocol.node_info.endpoint)
             if chosen_address is not None and address != chosen_address:
                 logger.warning(f"Node was manually given endpoint {address} , but other peers report {chosen_address}")
             chosen_address = chosen_address or address
@@ -255,3 +267,7 @@ class DHT(mp.Process):
         else:
             future.set_exception(ValueError(f"Can't get address: DHT node has no peers and no public endpoint."
                                             f" Please ensure the node is connected or specify peers=... manually."))
+
+    def __del__(self):
+        if self._parent_pid == os.getpid() and self.is_alive():
+            self.shutdown()
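
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the graceful shutdown introduced above is meant to be driven from the parent process. It assumes the class is importable as hivemind.dht.DHT and that the constructor defaults shown in this diff are unchanged; everything else mirrors the hunks above.

    from hivemind.dht import DHT

    dht = DHT(listen_on="0.0.0.0:*", initial_peers=(), start=True, shutdown_timeout=3)
    try:
        ...  # interact with the node via get / store / run_coroutine
    finally:
        # shutdown() now sends ('_shutdown', [], {}) through the outer pipe; the daemon process
        # awaits DHTNode.shutdown(), breaks out of its receive loop, and exits on its own.
        # Only if it fails to join within shutdown_timeout seconds does the parent fall back
        # to the old behaviour, terminate().
        dht.shutdown()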