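"""Tests for hivemind.moe.server.connection_handler.ConnectionHandler.

Each test starts a ConnectionHandler on top of a local DHT, connects to it from a client-mode DHT peer,
and exercises the rpc_info, rpc_forward and rpc_backward handlers (unary and streaming), including error
propagation via P2PHandlerError. DummyModuleBackend and DummyPool at the bottom of the file stand in for
the real runtime: each "expert" simply multiplies its input by a constant k.
"""
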
from __future__ import annotations

import asyncio
import math
from typing import Any, Dict

import pytest
import torch

from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
from hivemind.dht import DHT
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.moe.server.task_pool import TaskPool
from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PHandlerError
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
from hivemind.utils.serializer import MSGPackSerializer
from hivemind.utils.streaming import split_for_streaming
from hivemind.utils.tensor_descr import BatchTensorDescriptor
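
# Note: the @pytest.mark.forked and @pytest.mark.asyncio marks used below assume that the pytest-forked
# and pytest-asyncio plugins are installed; each forked test runs in a separate process.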


@pytest.fixture
async def client_stub():
    handler_dht = DHT(start=True)
    module_backends = {"expert1": DummyModuleBackend("expert1", k=1), "expert2": DummyModuleBackend("expert2", k=2)}
    handler = ConnectionHandler(handler_dht, module_backends, start=True)

    client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
    yield client_stub

    client_dht.shutdown()
    handler.shutdown()
    handler_dht.shutdown()
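
# The client_stub fixture yields a ConnectionHandler RPC stub bound to the handler's peer id;
# the tests below talk to the handler exclusively through this stub.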


@pytest.mark.forked
@pytest.mark.asyncio
async def test_connection_handler_info(client_stub):
    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
    assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert1")

    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert2"))
    assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert2")

    with pytest.raises(P2PHandlerError):
        await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert999"))


@pytest.mark.forked
@pytest.mark.asyncio
async def test_connection_handler_forward(client_stub):
    inputs = torch.randn(1, 2)
    inputs_long = torch.randn(2**21, 2)

    # forward unary
    response = await client_stub.rpc_forward(
        runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)])
    )
    outputs = deserialize_torch_tensor(response.tensors[0])
    assert len(response.tensors) == 1
    assert torch.allclose(outputs, inputs * 1)

    response = await client_stub.rpc_forward(
        runtime_pb2.ExpertRequest(uid="expert2", tensors=[serialize_torch_tensor(inputs)])
    )
    outputs = deserialize_torch_tensor(response.tensors[0])
    assert len(response.tensors) == 1
    assert torch.allclose(outputs, inputs * 2)

    # forward streaming
    split = (
        p for t in [serialize_torch_tensor(inputs_long)] for p in split_for_streaming(t, chunk_size_bytes=2**16)
    )
    output_generator = await client_stub.rpc_forward_stream(
        amap_in_executor(
            lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
            iter_as_aiter(split),
        ),
    )
    outputs_list = [part async for part in output_generator]
    # float32 inputs take 4 bytes per element, so the streamed response arrives in
    # ceil(total_bytes / DEFAULT_MAX_MSG_SIZE) messages
    assert len(outputs_list) == math.ceil(inputs_long.numel() * 4 / DEFAULT_MAX_MSG_SIZE)

    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)))
    assert len(results) == 1
    assert torch.allclose(results[0], inputs_long * 2)

    # forward errors
    with pytest.raises(P2PHandlerError):
        # no such expert: fails with P2PHandlerError KeyError('expert3')
        await client_stub.rpc_forward(
            runtime_pb2.ExpertRequest(uid="expert3", tensors=[serialize_torch_tensor(inputs)])
        )

    with pytest.raises(P2PHandlerError):
        # bad input shape: P2PHandlerError("AssertionError") raised by DummyPool.submit_task
        await client_stub.rpc_forward(
            runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(torch.arange(5))])
        )


@pytest.mark.forked
@pytest.mark.asyncio
async def test_connection_handler_backward(client_stub):
    inputs = torch.randn(1, 2)
    inputs_long = torch.randn(2**21, 2)

    # backward unary
    response = await client_stub.rpc_backward(
        runtime_pb2.ExpertRequest(
            uid="expert2", tensors=[serialize_torch_tensor(inputs * -1), serialize_torch_tensor(inputs)]
        )
    )
    outputs = deserialize_torch_tensor(response.tensors[0])
    assert len(response.tensors) == 1
    assert torch.allclose(outputs, inputs * -2)

    # backward streaming
    split = (
        p
        for t in [serialize_torch_tensor(inputs_long * 3), serialize_torch_tensor(inputs_long * 0)]
        for p in split_for_streaming(t, chunk_size_bytes=2**16)
    )
    output_generator = await client_stub.rpc_backward_stream(
        amap_in_executor(
            lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert1", tensors=[tensor_part]),
            iter_as_aiter(split),
        ),
    )
    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, output_generator))
    assert len(results) == 1
    assert torch.allclose(results[0], inputs_long * 3)

    # backward errors
    with pytest.raises(P2PHandlerError):
        # bad input schema: fails with P2PHandlerError IndexError('tuple index out of range')
        await client_stub.rpc_backward(runtime_pb2.ExpertRequest(uid="expert2", tensors=[]))

    with pytest.raises(P2PHandlerError):
        # backward fails: empty stream
        output_generator = await client_stub.rpc_backward_stream(
            amap_in_executor(
                lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
                iter_as_aiter([]),
            ),
        )
        results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, output_generator))
        assert len(results) == 1
        assert torch.allclose(results[0], inputs_long * 3)

    # check that handler did not crash after failed request
    await client_stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)]))


@pytest.mark.forked
@pytest.mark.asyncio
async def test_connection_handler_shutdown():
    # Here, all handlers share the same hivemind.DHT and hivemind.P2P instances
    handler_dht = DHT(start=True)
    module_backends = {"expert1": DummyModuleBackend("expert1", k=1), "expert2": DummyModuleBackend("expert2", k=2)}

    for _ in range(3):
        handler = ConnectionHandler(handler_dht, module_backends, balanced=False, start=True)
        # The line above would raise an exception if the previous handlers were not removed from hivemind.P2P
        handler.shutdown()

    handler_dht.shutdown()
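

# Test doubles: DummyModuleBackend skips the real module and optimizer setup and routes forward/backward
# calls through DummyPool, which multiplies its first input by k (expert1: k=1, expert2: k=2).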
class DummyPool(TaskPool):
    def __init__(self, k: float):
        self.k = k

    async def submit_task(self, *inputs: torch.Tensor):
        await asyncio.sleep(0.01)
        assert inputs[0].shape[-1] == 2
        return [inputs[0] * self.k]


class DummyModuleBackend(ModuleBackend):
    def __init__(self, name: str, k: float):
        self.name = name
        self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
        self.grad_inputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
        self.forward_pool = DummyPool(k)
        self.backward_pool = DummyPool(k)

    def get_info(self) -> Dict[str, Any]:
        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
        return dict(name=self.name)