test_connection_handler.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. from __future__ import annotations
  2. import asyncio
  3. from typing import Any, Dict
  4. import pytest
  5. import torch
  6. import logging
  7. from hivemind.dht import DHT
  8. from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
  9. from hivemind.moe.server.connection_handler import ConnectionHandler
  10. from hivemind.moe.server.expert_backend import ExpertBackend
  11. from hivemind.moe.server.task_pool import TaskPool
  12. from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
  13. from hivemind.proto import runtime_pb2
  14. from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
  15. from hivemind.utils.serializer import MSGPackSerializer
  16. from hivemind.utils.streaming import combine_and_deserialize_from_streaming, split_for_streaming
  17. from hivemind.utils.tensor_descr import BatchTensorDescriptor
  18. from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
  19. LONG_INPUT_SIZE = 2**21
  20. class DummyPool(TaskPool):
  21. def __init__(self, k: float):
  22. self.k = k
  23. async def submit_task(self, *inputs: torch.Tensor):
  24. await asyncio.sleep(0.01)
  25. if inputs[0].shape[-1] != 2:
  26. raise ValueError("wrong input shape")
  27. return [inputs[0] * self.k]
  28. class DummyExpertBackend(ExpertBackend):
  29. def __init__(self, name: str, k: float):
  30. self.name = name
  31. self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
  32. self.grad_inputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
  33. self.forward_pool = DummyPool(k)
  34. self.backward_pool = DummyPool(k)
  35. def get_info(self) -> Dict[str, Any]:
  36. """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
  37. return dict(name=self.name)
  38. @pytest.fixture()
  39. async def stub():
  40. server_dht = DHT(start=True)
  41. experts = {
  42. "expert1": DummyExpertBackend("expert1", k=1),
  43. "expert2": DummyExpertBackend("expert2", k=2),
  44. }
  45. handler = ConnectionHandler(server_dht, experts)
  46. handler.start()
  47. client_dht = DHT(start=True, client_mode=True, initial_peers=server_dht.get_visible_maddrs())
  48. p2p = await client_dht.replicate_p2p()
  49. client_stub = ConnectionHandler.get_stub(p2p, server_dht.peer_id)
  50. yield client_stub
  51. handler.terminate()
  52. handler.join()
  53. @pytest.fixture
  54. def small_input():
  55. return torch.randn(1, 2)
  56. @pytest.fixture
  57. def long_input():
  58. input = torch.randn(LONG_INPUT_SIZE, 2)
  59. n_chunks = (input.nelement() * input.element_size() + DEFAULT_MAX_MSG_SIZE - 1) // DEFAULT_MAX_MSG_SIZE
  60. return input, n_chunks
  61. @pytest.mark.forked
  62. @pytest.mark.asyncio
  63. async def test_forward_unary(stub, small_input):
  64. response = await stub.rpc_forward(
  65. runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(small_input)])
  66. )
  67. outputs = deserialize_torch_tensor(response.tensors[0])
  68. assert len(response.tensors) == 1
  69. assert torch.allclose(outputs, small_input * 1)
  70. response = await stub.rpc_forward(
  71. runtime_pb2.ExpertRequest(uid="expert2", tensors=[serialize_torch_tensor(small_input)])
  72. )
  73. outputs = deserialize_torch_tensor(response.tensors[0])
  74. assert len(response.tensors) == 1
  75. assert torch.allclose(outputs, small_input * 2)
  76. @pytest.mark.forked
  77. @pytest.mark.asyncio
  78. async def test_forward_streaming(stub, long_input):
  79. input, n_chunks = long_input
  80. split = (
  81. p for t in [serialize_torch_tensor(input)] for p in split_for_streaming(t, chunk_size_bytes=DEFAULT_MAX_MSG_SIZE)
  82. )
  83. output_generator = await stub.rpc_forward_stream(
  84. amap_in_executor(
  85. lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
  86. iter_as_aiter(split),
  87. ),
  88. )
  89. outputs_list = [part async for part in output_generator]
  90. del output_generator
  91. assert len(outputs_list) == n_chunks
  92. results = await combine_and_deserialize_from_streaming(
  93. amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)), deserialize_torch_tensor
  94. )
  95. assert len(results) == 1
  96. assert torch.allclose(results[0], input * 2)
  97. @pytest.mark.forked
  98. @pytest.mark.asyncio
  99. async def test_forward_errors(stub, small_input):
  100. # no such expert: fails with P2PHandlerError KeyError('expert3')
  101. with pytest.raises(P2PHandlerError):
  102. await stub.rpc_forward(
  103. runtime_pb2.ExpertRequest(uid="expert3", tensors=[serialize_torch_tensor(small_input)])
  104. )
  105. # bad input shape: P2PHandlerError("AssertionError") raised by DummyPool.submit_task
  106. with pytest.raises(P2PHandlerError):
  107. await stub.rpc_forward(
  108. runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(torch.arange(5))])
  109. )
  110. @pytest.mark.forked
  111. @pytest.mark.asyncio
  112. async def test_backward_unary(stub, small_input):
  113. response = await stub.rpc_backward(
  114. runtime_pb2.ExpertRequest(
  115. uid="expert2", tensors=[serialize_torch_tensor(small_input * -1), serialize_torch_tensor(small_input)]
  116. )
  117. )
  118. outputs = deserialize_torch_tensor(response.tensors[0])
  119. assert len(response.tensors) == 1
  120. assert torch.allclose(outputs, small_input * -2)
  121. @pytest.mark.forked
  122. @pytest.mark.asyncio
  123. async def test_backward_streaming(stub, long_input):
  124. input, _ = long_input
  125. split = (
  126. p
  127. for t in [serialize_torch_tensor(input * 3), serialize_torch_tensor(input * 0)]
  128. for p in split_for_streaming(t, chunk_size_bytes=DEFAULT_MAX_MSG_SIZE)
  129. )
  130. output_generator = await stub.rpc_backward_stream(
  131. amap_in_executor(
  132. lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert1", tensors=[tensor_part]),
  133. iter_as_aiter(split),
  134. ),
  135. )
  136. results = await combine_and_deserialize_from_streaming(
  137. amap_in_executor(lambda r: r.tensors, output_generator), deserialize_torch_tensor
  138. )
  139. assert len(results) == 1
  140. assert torch.allclose(results[0], input * 3)
  141. @pytest.mark.forked
  142. @pytest.mark.asyncio
  143. async def test_backward_errors(stub, small_input, long_input):
  144. long, _ = long_input
  145. # bad input schema: fails with P2PHandlerError IndexError('tuple index out of range')
  146. with pytest.raises(P2PHandlerError):
  147. await stub.rpc_backward(runtime_pb2.ExpertRequest(uid="expert2", tensors=[]))
  148. # backward fails: empty stream
  149. with pytest.raises(P2PHandlerError):
  150. output_generator = await stub.rpc_backward_stream(
  151. amap_in_executor(
  152. lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
  153. iter_as_aiter([]),
  154. ),
  155. )
  156. results = await combine_and_deserialize_from_streaming(
  157. amap_in_executor(lambda r: r.tensors, output_generator), deserialize_torch_tensor
  158. )
  159. assert len(results) == 1
  160. assert torch.allclose(results[0], long * 3)
  161. # check that handler did not crash after failed request
  162. await stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(small_input)]))
  163. @pytest.mark.forked
  164. @pytest.mark.asyncio
  165. async def test_info(stub):
  166. response = await stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
  167. assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert1")
  168. response = await stub.rpc_info(runtime_pb2.ExpertUID(uid="expert2"))
  169. assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert2")
  170. with pytest.raises(P2PHandlerError):
  171. await stub.rpc_info(runtime_pb2.ExpertUID(uid="expert999"))