test_connection_handler.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. from __future__ import annotations
  2. import asyncio
  3. import math
  4. from typing import Any, Dict
  5. import pytest
  6. import torch
  7. from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
  8. from hivemind.dht import DHT
  9. from hivemind.moe.server.connection_handler import ConnectionHandler
  10. from hivemind.moe.server.module_backend import ModuleBackend
  11. from hivemind.moe.server.task_pool import TaskPool
  12. from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PHandlerError
  13. from hivemind.proto import runtime_pb2
  14. from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
  15. from hivemind.utils.serializer import MSGPackSerializer
  16. from hivemind.utils.streaming import split_for_streaming
  17. from hivemind.utils.tensor_descr import BatchTensorDescriptor
  18. @pytest.mark.forked
  19. @pytest.mark.asyncio
  20. async def test_connection_handler_info():
  21. handler = ConnectionHandler(
  22. DHT(start=True),
  23. dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
  24. )
  25. handler.start()
  26. client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
  27. client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
  28. # info
  29. response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
  30. assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert1")
  31. response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert2"))
  32. assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert2")
  33. with pytest.raises(P2PHandlerError):
  34. await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert999"))
@pytest.mark.forked
@pytest.mark.asyncio
async def test_connection_handler_forward():
    """Exercise rpc_forward in both unary and streaming form against a live ConnectionHandler,
    then verify that bad expert UIDs and bad input shapes surface as P2PHandlerError."""
    handler = ConnectionHandler(
        DHT(start=True),
        dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
    )
    handler.start()
    # client-mode DHT peer connected to the handler's node; the stub talks to it over p2p
    client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)

    inputs = torch.randn(1, 2)
    inputs_long = torch.randn(2**21, 2)  # large enough to require multiple streaming chunks

    # forward unary: DummyPool.submit_task multiplies the input by k (1 for expert1, 2 for expert2)
    response = await client_stub.rpc_forward(
        runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)])
    )
    outputs = deserialize_torch_tensor(response.tensors[0])
    assert len(response.tensors) == 1
    assert torch.allclose(outputs, inputs * 1)

    response = await client_stub.rpc_forward(
        runtime_pb2.ExpertRequest(uid="expert2", tensors=[serialize_torch_tensor(inputs)])
    )
    outputs = deserialize_torch_tensor(response.tensors[0])
    assert len(response.tensors) == 1
    assert torch.allclose(outputs, inputs * 2)

    # forward streaming: split the serialized tensor into 64 KiB parts, one ExpertRequest per part
    split = (
        p for t in [serialize_torch_tensor(inputs_long)] for p in split_for_streaming(t, chunk_size_bytes=2**16)
    )
    output_generator = await client_stub.rpc_forward_stream(
        amap_in_executor(
            lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
            iter_as_aiter(split),
        ),
    )
    outputs_list = [part async for part in output_generator]
    # the response is re-chunked at DEFAULT_MAX_MSG_SIZE; 4 bytes per element
    # (presumably float32 — TODO confirm serialization dtype)
    assert len(outputs_list) == math.ceil(inputs_long.numel() * 4 / DEFAULT_MAX_MSG_SIZE)
    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)))
    assert len(results) == 1
    assert torch.allclose(results[0], inputs_long * 2)

    # forward errors
    with pytest.raises(P2PHandlerError):
        # no such expert: fails with P2PHandlerError KeyError('expert3')
        await client_stub.rpc_forward(
            runtime_pb2.ExpertRequest(uid="expert3", tensors=[serialize_torch_tensor(inputs)])
        )
    with pytest.raises(P2PHandlerError):
        # bad input shape: P2PHandlerError("AssertionError") raised by DummyPool.submit_task
        await client_stub.rpc_forward(
            runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(torch.arange(5))])
        )
@pytest.mark.forked
@pytest.mark.asyncio
async def test_connection_handler_backward():
    """Exercise rpc_backward (unary and streaming), its error paths, and verify the handler
    keeps serving requests after a failed call before shutting it down."""
    handler = ConnectionHandler(
        DHT(start=True),
        dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
    )
    handler.start()
    # client-mode DHT peer connected to the handler's node; the stub talks to it over p2p
    client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)

    inputs = torch.randn(1, 2)
    inputs_long = torch.randn(2**21, 2)  # large enough to require multiple streaming chunks

    # backward unary: DummyPool returns its FIRST input tensor (inputs * -1) times k=2
    response = await client_stub.rpc_backward(
        runtime_pb2.ExpertRequest(
            uid="expert2", tensors=[serialize_torch_tensor(inputs * -1), serialize_torch_tensor(inputs)]
        )
    )
    outputs = deserialize_torch_tensor(response.tensors[0])
    assert len(response.tensors) == 1
    assert torch.allclose(outputs, inputs * -2)

    # backward streaming: two tensors split into 64 KiB parts; expert1 (k=1) echoes the first tensor
    split = (
        p
        for t in [serialize_torch_tensor(inputs_long * 3), serialize_torch_tensor(inputs_long * 0)]
        for p in split_for_streaming(t, chunk_size_bytes=2**16)
    )
    output_generator = await client_stub.rpc_backward_stream(
        amap_in_executor(
            lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert1", tensors=[tensor_part]),
            iter_as_aiter(split),
        ),
    )
    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, output_generator))
    assert len(results) == 1
    assert torch.allclose(results[0], inputs_long * 3)

    # backward errors
    with pytest.raises(P2PHandlerError):
        # bad input schema: fails with P2PHandlerError IndexError('tuple index out of range')
        await client_stub.rpc_backward(runtime_pb2.ExpertRequest(uid="expert2", tensors=[]))
    with pytest.raises(P2PHandlerError):
        # backward fails: empty stream
        output_generator = await client_stub.rpc_backward_stream(
            amap_in_executor(
                lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
                iter_as_aiter([]),
            ),
        )
        results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, output_generator))
        # NOTE(review): the error is expected during the awaits above, so the asserts below are
        # normally unreachable; if they did run, pytest.raises would fail the test anyway
        assert len(results) == 1
        assert torch.allclose(results[0], inputs_long * 3)

    # check that handler did not crash after failed request
    await client_stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)]))
    handler.terminate()
    handler.join()
  141. class DummyPool(TaskPool):
  142. def __init__(self, k: float):
  143. self.k = k
  144. async def submit_task(self, *inputs: torch.Tensor):
  145. await asyncio.sleep(0.01)
  146. assert inputs[0].shape[-1] == 2
  147. return [inputs[0] * self.k]
  148. class DummyModuleBackend(ModuleBackend):
  149. def __init__(self, name: str, k: float):
  150. self.name = name
  151. self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
  152. self.grad_inputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
  153. self.forward_pool = DummyPool(k)
  154. self.backward_pool = DummyPool(k)
  155. def get_info(self) -> Dict[str, Any]:
  156. """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
  157. return dict(name=self.name)