123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217 |
- import asyncio
- import multiprocessing as mp
- import subprocess
- from hivemind.p2p.p2p_daemon_bindings.datastructures import ID
- import numpy as np
- import pytest
- from hivemind.p2p import P2P
- from hivemind.proto import dht_pb2
- RUNNING = 'running'
- NOT_RUNNING = 'not running'
- CHECK_PID_CMD = '''
- if ps -p {0} > /dev/null;
- then
- echo "{1}"
- else
- echo "{2}"
- fi
- '''
- def is_process_running(pid: int) -> bool:
- cmd = CHECK_PID_CMD.format(pid, RUNNING, NOT_RUNNING)
- return subprocess.check_output(cmd, shell=True).decode('utf-8').strip() == RUNNING
- @pytest.mark.asyncio
- async def test_daemon_killed_on_del():
- p2p_daemon = await P2P.create()
- child_pid = p2p_daemon._child.pid
- assert is_process_running(child_pid)
- p2p_daemon.__del__()
- assert not is_process_running(child_pid)
- def handle_square(x):
- return x ** 2
- def handle_add(args):
- result = args[0]
- for i in range(1, len(args)):
- result = result + args[i]
- return result
- @pytest.mark.parametrize(
- 'should_cancel', [True, False]
- )
- @pytest.mark.asyncio
- async def test_call_unary_handler(should_cancel, handle_name="handle"):
- handler_cancelled = False
- async def ping_handler(request, context):
- try:
- await asyncio.sleep(2)
- except asyncio.CancelledError:
- nonlocal handler_cancelled
- handler_cancelled = True
- return dht_pb2.PingResponse(
- peer=dht_pb2.NodeInfo(
- node_id=context.ours_id.encode(), rpc_port=context.ours_port),
- sender_endpoint=context.handle_name, available=True)
- server = await P2P.create()
- server_pid = server._child.pid
- await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
- dht_pb2.PingResponse)
- assert is_process_running(server_pid)
- client = await P2P.create()
- client_pid = client._child.pid
- assert is_process_running(client_pid)
- ping_request = dht_pb2.PingRequest(
- peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
- validate=True)
- expected_response = dht_pb2.PingResponse(
- peer=dht_pb2.NodeInfo(node_id=server.id.encode(), rpc_port=server._host_port),
- sender_endpoint=handle_name, available=True)
- await asyncio.sleep(1)
- libp2p_server_id = ID.from_base58(server.id)
- stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,))
- await P2P.send_raw_data(ping_request.SerializeToString(), writer)
- if should_cancel:
- writer.close()
- await asyncio.sleep(1)
- assert handler_cancelled
- else:
- result = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)
- assert result == expected_response
- assert not handler_cancelled
- await server.stop_listening()
- server.__del__()
- assert not is_process_running(server_pid)
- client.__del__()
- assert not is_process_running(client_pid)
- @pytest.mark.parametrize(
- "test_input,handle",
- [
- pytest.param(10, handle_square, id="square_integer"),
- pytest.param((1, 2), handle_add, id="add_integers"),
- pytest.param(([1, 2, 3], [12, 13]), handle_add, id="add_lists"),
- pytest.param(2, lambda x: x ** 3, id="lambda")
- ]
- )
- @pytest.mark.asyncio
- async def test_call_peer_single_process(test_input, handle, handler_name="handle"):
- server = await P2P.create()
- server_pid = server._child.pid
- await server.add_stream_handler(handler_name, handle)
- assert is_process_running(server_pid)
- client = await P2P.create()
- client_pid = client._child.pid
- assert is_process_running(client_pid)
- await asyncio.sleep(1)
- result = await client.call_peer_handler(server.id, handler_name, test_input)
- assert result == handle(test_input)
- await server.stop_listening()
- server.__del__()
- assert not is_process_running(server_pid)
- client.__del__()
- assert not is_process_running(client_pid)
- async def run_server(handler_name, server_side, client_side, response_received):
- server = await P2P.create()
- server_pid = server._child.pid
- await server.add_stream_handler(handler_name, handle_square)
- assert is_process_running(server_pid)
- server_side.send(server.id)
- while response_received.value == 0:
- await asyncio.sleep(0.5)
- await server.stop_listening()
- server.__del__()
- assert not is_process_running(server_pid)
- def server_target(handler_name, server_side, client_side, response_received):
- asyncio.run(run_server(handler_name, server_side, client_side, response_received))
- @pytest.mark.asyncio
- async def test_call_peer_different_processes():
- handler_name = "square"
- test_input = np.random.randn(2, 3)
- server_side, client_side = mp.Pipe()
- response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
- response_received.value = 0
- proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
- proc.start()
- client = await P2P.create()
- client_pid = client._child.pid
- assert is_process_running(client_pid)
- await asyncio.sleep(1)
- peer_id = client_side.recv()
- result = await client.call_peer_handler(peer_id, handler_name, test_input)
- assert np.allclose(result, handle_square(test_input))
- response_received.value = 1
- client.__del__()
- assert not is_process_running(client_pid)
- proc.join()
- @pytest.mark.parametrize(
- "test_input,handle",
- [
- pytest.param(np.random.randn(2, 3), handle_square, id="square"),
- pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, id="add"),
- ]
- )
- @pytest.mark.asyncio
- async def test_call_peer_numpy(test_input, handle, handler_name="handle"):
- server = await P2P.create()
- await server.add_stream_handler(handler_name, handle)
- client = await P2P.create()
- await asyncio.sleep(1)
- result = await client.call_peer_handler(server.id, handler_name, test_input)
- assert np.allclose(result, handle(test_input))
- @pytest.mark.asyncio
- async def test_call_peer_error(handler_name="handle"):
- server = await P2P.create()
- await server.add_stream_handler(handler_name, handle_add)
- client = await P2P.create()
- await asyncio.sleep(1)
- result = await client.call_peer_handler(server.id, handler_name,
- [np.zeros((2, 3)), np.zeros((3, 2))])
- assert type(result) == ValueError
|