# test_connection_handler.py
  1. from __future__ import annotations
  2. import asyncio
  3. from typing import Any, Dict
  4. import pytest
  5. import torch
  6. import hivemind
  7. from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
  8. from hivemind.moe.server.connection_handler import ConnectionHandler
  9. from hivemind.moe.server.expert_backend import ExpertBackend
  10. from hivemind.moe.server.task_pool import TaskPool
  11. from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
  12. from hivemind.proto import runtime_pb2
  13. from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
  14. from hivemind.utils.serializer import MSGPackSerializer
  15. from hivemind.utils.streaming import combine_and_deserialize_from_streaming, split_for_streaming
  16. from hivemind.utils.tensor_descr import BatchTensorDescriptor
  17. @pytest.mark.forked
  18. @pytest.mark.asyncio
  19. async def test_connection_handler():
  20. handler = ConnectionHandler(
  21. hivemind.DHT(start=True),
  22. dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
  23. )
  24. handler.start()
  25. client_dht = hivemind.DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
  26. client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
  27. inputs = torch.randn(1, 2)
  28. inputs_long = torch.randn(2**21, 2)
  29. # forward unary
  30. response = await client_stub.rpc_forward(
  31. runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)])
  32. )
  33. outputs = deserialize_torch_tensor(response.tensors[0])
  34. assert len(response.tensors) == 1
  35. assert torch.allclose(outputs, inputs * 1)
  36. response = await client_stub.rpc_forward(
  37. runtime_pb2.ExpertRequest(uid="expert2", tensors=[serialize_torch_tensor(inputs)])
  38. )
  39. outputs = deserialize_torch_tensor(response.tensors[0])
  40. assert len(response.tensors) == 1
  41. assert torch.allclose(outputs, inputs * 2)
  42. # forward streaming
  43. split = (
  44. p for t in [serialize_torch_tensor(inputs_long)] for p in split_for_streaming(t, chunk_size_bytes=2**16)
  45. )
  46. output_generator = await client_stub.rpc_forward_stream(
  47. amap_in_executor(
  48. lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
  49. iter_as_aiter(split),
  50. ),
  51. )
  52. outputs_list = [part async for part in output_generator]
  53. del output_generator
  54. assert len(outputs_list) == 8 # message size divided by DEFAULT_MAX_MSG_SIZE
  55. results = await combine_and_deserialize_from_streaming(
  56. amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)), deserialize_torch_tensor
  57. )
  58. assert len(results) == 1
  59. assert torch.allclose(results[0], inputs_long * 2)
  60. # forward errors
  61. with pytest.raises(P2PHandlerError):
  62. # no such expert: fails with P2PHandlerError KeyError('expert3')
  63. await client_stub.rpc_forward(
  64. runtime_pb2.ExpertRequest(uid="expert3", tensors=[serialize_torch_tensor(inputs)])
  65. )
  66. with pytest.raises(P2PHandlerError):
  67. # bad input shape: P2PHandlerError("AssertionError") raised by DummyPool.submit_task
  68. await client_stub.rpc_forward(
  69. runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(torch.arange(5))])
  70. )
  71. # backward unary
  72. response = await client_stub.rpc_backward(
  73. runtime_pb2.ExpertRequest(
  74. uid="expert2", tensors=[serialize_torch_tensor(inputs * -1), serialize_torch_tensor(inputs)]
  75. )
  76. )
  77. outputs = deserialize_torch_tensor(response.tensors[0])
  78. assert len(response.tensors) == 1
  79. assert torch.allclose(outputs, inputs * -2)
  80. # backward streaming
  81. split = (
  82. p
  83. for t in [serialize_torch_tensor(inputs_long * 3), serialize_torch_tensor(inputs_long * 0)]
  84. for p in split_for_streaming(t, chunk_size_bytes=2**16)
  85. )
  86. output_generator = await client_stub.rpc_backward_stream(
  87. amap_in_executor(
  88. lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert1", tensors=[tensor_part]),
  89. iter_as_aiter(split),
  90. ),
  91. )
  92. results = await combine_and_deserialize_from_streaming(
  93. amap_in_executor(lambda r: r.tensors, output_generator), deserialize_torch_tensor
  94. )
  95. assert len(results) == 1
  96. assert torch.allclose(results[0], inputs_long * 3)
  97. # backward errors
  98. with pytest.raises(P2PHandlerError):
  99. # bad input schema: fails with P2PHandlerError IndexError('tuple index out of range')
  100. await client_stub.rpc_backward(runtime_pb2.ExpertRequest(uid="expert2", tensors=[]))
  101. with pytest.raises(P2PHandlerError):
  102. # backward fails: empty stream
  103. output_generator = await client_stub.rpc_backward_stream(
  104. amap_in_executor(
  105. lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
  106. iter_as_aiter([]),
  107. ),
  108. )
  109. results = await combine_and_deserialize_from_streaming(
  110. amap_in_executor(lambda r: r.tensors, output_generator), deserialize_torch_tensor
  111. )
  112. assert len(results) == 1
  113. assert torch.allclose(results[0], inputs_long * 3)
  114. # check that handler did not crash after failed request
  115. await client_stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)]))
  116. # info
  117. response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
  118. assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert1")
  119. response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert2"))
  120. assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert2")
  121. with pytest.raises(P2PHandlerError):
  122. await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert999"))
  123. handler.terminate()
  124. handler.join()
  125. class DummyPool(TaskPool):
  126. def __init__(self, k: float):
  127. self.k = k
  128. async def submit_task(self, *inputs: torch.Tensor):
  129. await asyncio.sleep(0.01)
  130. print(type(inputs), inputs[0].shape)
  131. assert inputs[0].shape[-1] == 2
  132. return [inputs[0] * self.k]
  133. class DummyExpertBackend(ExpertBackend):
  134. def __init__(self, name: str, k: float):
  135. self.name = name
  136. self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
  137. self.grad_inputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
  138. self.forward_pool = DummyPool(k)
  139. self.backward_pool = DummyPool(k)
  140. def get_info(self) -> Dict[str, Any]:
  141. """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
  142. return dict(name=self.name)