Pavel Samygin 3 years ago
parent
commit
bbd6a9b941

+ 2 - 2
benchmarks/benchmark_throughput.py

@@ -8,7 +8,7 @@ import torch
 
 from hivemind.dht import DHT
 from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo
-from hivemind.moe.client.remote_expert_worker import _RemoteExpertWorker
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.server import ExpertBackend, Server
 from hivemind.moe.server.layers import name_to_block
 from hivemind.p2p import P2P, PeerInfo
@@ -47,7 +47,7 @@ def client_process(
     torch.set_num_threads(1)
     can_start.wait()
 
-    p2p = _RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
+    p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
     peer_info = PeerInfo(server_peer_id, server_maddrs)
     experts = [
         RemoteExpert(expert_info=RemoteExpertInfo(uid=f"expert.{i}", peer_info=peer_info), p2p=p2p)

+ 4 - 4
docs/user/moe.md

@@ -1,7 +1,7 @@
 # Mixture-of-Experts
 
 This tutorial covers the basics of Decentralized Mixture-of-Experts (DMoE).
-From the infrastructure standpoint, DMoE consists of two parts: experts hosted on peer devices, and client-side utilities to access those experts.
+From the infrastructure standpoint, DMoE consists of two parts: experts hosted on peer devices, and client-side modules to access those experts.
 
 ## Host experts with a server
 
@@ -101,9 +101,9 @@ hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 10 --expert_patt
 
 </details>
 
-By default, the server will only accept connections fr om your local network. 
-To enable training over the internet (or some other network), you should set `--host_maddrs` and `--announce_maddrs`.
-These option also allow you to select ipv4 / ipv6 network protocols and tcp / quic transport protocols.
+By default, the server will only accept connections from your local network.
+To enable training over the Internet (or some other network), you should set `--host_maddrs` and `--announce_maddrs`.
+These options also allow you to select IPv4/IPv6 network protocols and TCP/QUIC transport protocols.
 You can find more details in the [DHT tutorial](https://learning-at-home.readthedocs.io/en/latest/user/dht.html).
 
 ## Train the experts
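
A short sketch of these options from the Python side, assuming the `DHT` constructor accepts the same multiaddr arguments as the CLI flags (the addresses below are placeholders):

```python
from hivemind.dht import DHT

# Assumed equivalents of --host_maddrs / --announce_maddrs; addresses are placeholders.
dht = DHT(
    # listen on all interfaces over both TCP and QUIC
    host_maddrs=["/ip4/0.0.0.0/tcp/31337", "/ip4/0.0.0.0/udp/31337/quic"],
    # the public address other peers should use to reach this server
    announce_maddrs=["/ip4/203.0.113.5/tcp/31337"],
    start=True,
)
```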

+ 1 - 1
hivemind/moe/__init__.py

@@ -1,4 +1,4 @@
-from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client import (
+    RemoteExpert,
+    RemoteExpertWorker,
+    RemoteMixtureOfExperts,
+    RemoteSwitchMixtureOfExperts,
+    batch_create_remote_experts,
+    create_remote_experts,
+)
 from hivemind.moe.server import (
     ExpertBackend,
     Server,

+ 1 - 0
hivemind/moe/client/__init__.py

@@ -1,3 +1,4 @@
 from hivemind.moe.client.expert import RemoteExpert, batch_create_remote_experts, create_remote_experts
 from hivemind.moe.client.moe import RemoteMixtureOfExperts
 from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
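
Taken together with the previous hunk, these re-exports make the renamed worker and the helper functions part of the public API. A sketch of the import surface this enables:

```python
# All of these are now importable from the top-level moe package
# (RemoteExpertWorker was previously the private _RemoteExpertWorker):
from hivemind.moe import (
    RemoteExpert,
    RemoteExpertWorker,
    RemoteMixtureOfExperts,
    RemoteSwitchMixtureOfExperts,
    batch_create_remote_experts,
    create_remote_experts,
)
```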

+ 7 - 7
hivemind/moe/client/expert.py

@@ -11,7 +11,7 @@ from torch.autograd.function import once_differentiable
 from hivemind import moe
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
-from hivemind.moe.client.remote_expert_worker import _RemoteExpertWorker
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
@@ -89,7 +89,7 @@ class RemoteExpert(nn.Module):
     @property
     def info(self):
         if self._rpc_info is None:
-            outputs = _RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
+            outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
             self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
         return self._rpc_info
 
@@ -116,9 +116,9 @@ def create_remote_experts(
             p2p = await dht.replicate_p2p()
             return _create_remote_experts(await infos_future, p2p)
 
-        return _RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
+        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
 
-    p2p = _RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
     return _create_remote_experts(infos, p2p)
 
 
@@ -133,7 +133,7 @@ def batch_create_remote_experts(
             p2p = await dht.replicate_p2p()
             return [_create_remote_experts(i, p2p) for i in await infos_future]
 
-        return _RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
+        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
 
     return [create_remote_experts(exps, dht) for exps in infos]
 
@@ -224,7 +224,7 @@ class _RemoteModuleCall(torch.autograd.Function):
             serialize_torch_tensor(tensor, proto.compression)
             for tensor, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         )
-        deserialized_outputs = _RemoteExpertWorker.run_coroutine(expert_forward(uid, inputs, serialized_tensors, stub))
+        deserialized_outputs = RemoteExpertWorker.run_coroutine(expert_forward(uid, inputs, serialized_tensors, stub))
 
         return tuple(deserialized_outputs)
 
@@ -238,7 +238,7 @@ class _RemoteModuleCall(torch.autograd.Function):
             serialize_torch_tensor(tensor, proto.compression)
             for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
         )
-        deserialized_grad_inputs = _RemoteExpertWorker.run_coroutine(
+        deserialized_grad_inputs = RemoteExpertWorker.run_coroutine(
             expert_backward(ctx.uid, inputs_and_grad_outputs, serialized_tensors, ctx.stub)
         )
 

+ 3 - 3
hivemind/moe/client/moe.py

@@ -13,7 +13,7 @@ from hivemind.compression import serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub, expert_backward, expert_forward
-from hivemind.moe.client.remote_expert_worker import _RemoteExpertWorker
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_map, nested_pack
@@ -232,7 +232,7 @@ class _RemoteCallMany(torch.autograd.Function):
                     serialize_torch_tensor(tensor, proto.compression)
                     for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
                 )
-                new_task = _RemoteExpertWorker.run_coroutine(
+                new_task = RemoteExpertWorker.run_coroutine(
                     expert_forward(expert.uid, flat_inputs_per_sample[i], serialized_tensors, stub),
                     return_future=True,
                 )
@@ -327,7 +327,7 @@ class _RemoteCallMany(torch.autograd.Function):
                 serialize_torch_tensor(tensor, proto.compression)
                 for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
             )
-            new_task = _RemoteExpertWorker.run_coroutine(
+            new_task = RemoteExpertWorker.run_coroutine(
                 expert_backward(expert.uid, inputs_and_grad_outputs, serialized_tensors, stub), return_future=True
             )
             pending_tasks[new_task] = (i, j)
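
The `return_future=True` calls above let `_RemoteCallMany` fan out one RPC per (sample, expert) pair without blocking the dispatch loop. A toy sketch of the same pattern built only on the worker API shown in this diff; `fake_rpc` stands in for `expert_forward`/`expert_backward`, and the returned futures are assumed to expose `.result()`:

```python
import asyncio

from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker

async def fake_rpc(i: int) -> int:  # stand-in for an expert_forward / expert_backward call
    await asyncio.sleep(0.01)
    return i * 2

# Schedule all calls at once, keeping a task -> index map as _RemoteCallMany does:
pending_tasks = {RemoteExpertWorker.run_coroutine(fake_rpc(i), return_future=True): i for i in range(8)}
results = {i: task.result() for task, i in pending_tasks.items()}  # gather as they complete
assert results[3] == 6
```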

+ 1 - 1
hivemind/moe/client/remote_expert_worker.py

@@ -7,7 +7,7 @@ from typing import Awaitable, Optional
 from hivemind.utils import switch_to_uvloop
 
 
-class _RemoteExpertWorker:
+class RemoteExpertWorker:
     """Local thread for managing async tasks related to RemoteExpert"""
 
     _task_queue: Queue = Queue()
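
The class docstring above summarizes the whole design: one background thread owns an event loop (switched to uvloop) and consumes coroutines from `_task_queue`, returning results through futures. A condensed illustration of that pattern, not the actual implementation:

```python
import asyncio
from concurrent.futures import Future
from queue import Queue
from threading import Thread

class WorkerSketch:
    """Illustrative only: one thread, one event loop; coroutines in, futures out."""

    _task_queue: Queue = Queue()
    _thread: Thread = None

    @classmethod
    def _run(cls):
        loop = asyncio.new_event_loop()  # the real class calls switch_to_uvloop() here

        async def receive_tasks():
            while True:
                # block in a thread-pool executor so the loop stays responsive
                coro, future = await loop.run_in_executor(None, cls._task_queue.get)
                asyncio.create_task(cls._fulfill(coro, future))

        loop.run_until_complete(receive_tasks())

    @staticmethod
    async def _fulfill(coro, future: Future):
        try:
            future.set_result(await coro)
        except BaseException as e:
            future.set_exception(e)

    @classmethod
    def run_coroutine(cls, coro, return_future: bool = False):
        if cls._thread is None:  # start the worker thread lazily
            cls._thread = Thread(target=cls._run, daemon=True, name="WorkerSketch")
            cls._thread.start()
        future: Future = Future()
        cls._task_queue.put((coro, future))
        return future if return_future else future.result()
```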

+ 6 - 2
hivemind/moe/server/connection_handler.py

@@ -77,6 +77,10 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
 
         tensors_stream = amap_in_executor(_unpack, requests)
         inputs = await combine_and_deserialize_from_streaming(tensors_stream, deserialize_torch_tensor)
+
+        if expert_uid is None:
+            raise ValueError("empty stream")
+
         return expert_uid, inputs
 
     async def _process_inputs(
@@ -105,7 +109,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         output_split = [
             part
             for tensor in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
-            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE // 2)
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
 
         async for part in as_aiter(*output_split):
@@ -128,7 +132,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         output_split = [
             part
             for tensor in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
-            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE // 2)
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
 
         async for part in as_aiter(*output_split):
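
Doubling the chunk size here halves the number of streamed parts per tensor; the rewritten tests below compute the expected count by ceiling-dividing the payload size by `DEFAULT_MAX_MSG_SIZE`. A small sketch of that arithmetic (serialization overhead is ignored, so the real count can differ by one):

```python
import torch

from hivemind.compression import serialize_torch_tensor
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.utils.streaming import split_for_streaming

tensor = torch.randn(2**21, 2)  # same shape as the tests' long input
payload_bytes = tensor.nelement() * tensor.element_size()
n_chunks = (payload_bytes + DEFAULT_MAX_MSG_SIZE - 1) // DEFAULT_MAX_MSG_SIZE  # ceiling division

parts = list(split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes=DEFAULT_MAX_MSG_SIZE))
print(len(parts), n_chunks)  # approximately equal; proto overhead can add a part
```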

+ 117 - 77
tests/test_connection_handler.py

@@ -6,7 +6,9 @@ from typing import Any, Dict
 import pytest
 import torch
 
-import hivemind
+from hivemind.dht import DHT
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.expert_backend import ExpertBackend
@@ -17,43 +19,92 @@ from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
 from hivemind.utils.serializer import MSGPackSerializer
 from hivemind.utils.streaming import combine_and_deserialize_from_streaming, split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 
+LONG_INPUT_SIZE = 2**21
 
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_connection_handler():
-    handler = ConnectionHandler(
-        hivemind.DHT(start=True),
-        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
-    )
+
+class DummyPool(TaskPool):
+    def __init__(self, k: float):
+        self.k = k
+
+    async def submit_task(self, *inputs: torch.Tensor):
+        await asyncio.sleep(0.01)
+        if inputs[0].shape[-1] != 2:
+            raise ValueError("wrong input shape")
+        return [inputs[0] * self.k]
+
+
+class DummyExpertBackend(ExpertBackend):
+    def __init__(self, name: str, k: float):
+        self.name = name
+        self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
+        self.grad_inputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
+        self.forward_pool = DummyPool(k)
+        self.backward_pool = DummyPool(k)
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(name=self.name)
+
+
+@pytest.fixture
+async def stub():
+    server_dht = DHT(start=True)
+    experts = {
+        "expert1": DummyExpertBackend("expert1", k=1),
+        "expert2": DummyExpertBackend("expert2", k=2),
+    }
+    handler = ConnectionHandler(server_dht, experts)
     handler.start()
 
-    client_dht = hivemind.DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
-    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
+    client_dht = DHT(start=True, client_mode=True, initial_peers=server_dht.get_visible_maddrs())
+    p2p = await client_dht.replicate_p2p()
+    client_stub = ConnectionHandler.get_stub(p2p, server_dht.peer_id)
+    yield client_stub
+
+    handler.terminate()
+    handler.join()
 
-    inputs = torch.randn(1, 2)
-    inputs_long = torch.randn(2**21, 2)
 
-    # forward unary
-    response = await client_stub.rpc_forward(
-        runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)])
+@pytest.fixture
+def small_input():
+    return torch.randn(1, 2)
+
+
+@pytest.fixture
+def long_input():
+    input = torch.randn(LONG_INPUT_SIZE, 2)
+    n_chunks = (input.nelement() * input.element_size() + DEFAULT_MAX_MSG_SIZE - 1) // DEFAULT_MAX_MSG_SIZE  # ceil(payload bytes / max message size)
+    return input, n_chunks
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_forward_unary(stub, small_input):
+    response = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(small_input)])
     )
     outputs = deserialize_torch_tensor(response.tensors[0])
     assert len(response.tensors) == 1
-    assert torch.allclose(outputs, inputs * 1)
+    assert torch.allclose(outputs, small_input * 1)
 
-    response = await client_stub.rpc_forward(
-        runtime_pb2.ExpertRequest(uid="expert2", tensors=[serialize_torch_tensor(inputs)])
+    response = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid="expert2", tensors=[serialize_torch_tensor(small_input)])
     )
     outputs = deserialize_torch_tensor(response.tensors[0])
     assert len(response.tensors) == 1
-    assert torch.allclose(outputs, inputs * 2)
+    assert torch.allclose(outputs, small_input * 2)
+
 
-    # forward streaming
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_forward_streaming(stub, long_input):
+    input, n_chunks = long_input
     split = (
-        p for t in [serialize_torch_tensor(inputs_long)] for p in split_for_streaming(t, chunk_size_bytes=2**16)
+        p for t in [serialize_torch_tensor(input)] for p in split_for_streaming(t, chunk_size_bytes=DEFAULT_MAX_MSG_SIZE)
     )
-    output_generator = await client_stub.rpc_forward_stream(
+    output_generator = await stub.rpc_forward_stream(
         amap_in_executor(
             lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
             iter_as_aiter(split),
@@ -61,44 +112,53 @@ async def test_connection_handler():
     )
     outputs_list = [part async for part in output_generator]
     del output_generator
-    assert len(outputs_list) == 8  # message size divided by DEFAULT_MAX_MSG_SIZE
+    assert len(outputs_list) == n_chunks
 
     results = await combine_and_deserialize_from_streaming(
         amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)), deserialize_torch_tensor
     )
     assert len(results) == 1
-    assert torch.allclose(results[0], inputs_long * 2)
+    assert torch.allclose(results[0], input * 2)
 
-    # forward errors
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_forward_errors(stub, small_input):
+    # no such expert: fails with P2PHandlerError KeyError('expert3')
     with pytest.raises(P2PHandlerError):
-        # no such expert: fails with P2PHandlerError KeyError('expert3')
-        await client_stub.rpc_forward(
-            runtime_pb2.ExpertRequest(uid="expert3", tensors=[serialize_torch_tensor(inputs)])
+        await stub.rpc_forward(
+            runtime_pb2.ExpertRequest(uid="expert3", tensors=[serialize_torch_tensor(small_input)])
         )
 
+    # bad input shape: P2PHandlerError("ValueError('wrong input shape')") raised by DummyPool.submit_task
     with pytest.raises(P2PHandlerError):
-        # bad input shape: P2PHandlerError("AssertionError") raised by DummyPool.submit_task
-        await client_stub.rpc_forward(
+        await stub.rpc_forward(
             runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(torch.arange(5))])
         )
 
-    # backward unary
-    response = await client_stub.rpc_backward(
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_backward_unary(stub, small_input):
+    response = await stub.rpc_backward(
         runtime_pb2.ExpertRequest(
-            uid="expert2", tensors=[serialize_torch_tensor(inputs * -1), serialize_torch_tensor(inputs)]
+            uid="expert2", tensors=[serialize_torch_tensor(small_input * -1), serialize_torch_tensor(small_input)]
         )
     )
     outputs = deserialize_torch_tensor(response.tensors[0])
     assert len(response.tensors) == 1
-    assert torch.allclose(outputs, inputs * -2)
+    assert torch.allclose(outputs, small_input * -2)
 
-    # backward streaming
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_backward_streaming(stub, long_input):
+    input, _ = long_input
     split = (
         p
-        for t in [serialize_torch_tensor(inputs_long * 3), serialize_torch_tensor(inputs_long * 0)]
-        for p in split_for_streaming(t, chunk_size_bytes=2**16)
+        for t in [serialize_torch_tensor(input * 3), serialize_torch_tensor(input * 0)]
+        for p in split_for_streaming(t, chunk_size_bytes=DEFAULT_MAX_MSG_SIZE)
     )
-    output_generator = await client_stub.rpc_backward_stream(
+    output_generator = await stub.rpc_backward_stream(
         amap_in_executor(
             lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert1", tensors=[tensor_part]),
             iter_as_aiter(split),
@@ -108,16 +168,20 @@ async def test_connection_handler():
         amap_in_executor(lambda r: r.tensors, output_generator), deserialize_torch_tensor
     )
     assert len(results) == 1
-    assert torch.allclose(results[0], inputs_long * 3)
+    assert torch.allclose(results[0], input * 3)
 
-    # backward errors
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_backward_errors(stub, small_input, long_input):
+    long, _ = long_input
+    # bad input schema: fails with P2PHandlerError IndexError('tuple index out of range')
     with pytest.raises(P2PHandlerError):
-        # bad input schema: fails with P2PHandlerError IndexError('tuple index out of range')
-        await client_stub.rpc_backward(runtime_pb2.ExpertRequest(uid="expert2", tensors=[]))
+        await stub.rpc_backward(runtime_pb2.ExpertRequest(uid="expert2", tensors=[]))
 
+    # backward fails: empty stream
     with pytest.raises(P2PHandlerError):
-        # backward fails: empty stream
-        output_generator = await client_stub.rpc_backward_stream(
+        output_generator = await stub.rpc_backward_stream(
             amap_in_executor(
                 lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
                 iter_as_aiter([]),
@@ -127,44 +191,20 @@ async def test_connection_handler():
             amap_in_executor(lambda r: r.tensors, output_generator), deserialize_torch_tensor
         )
         assert len(results) == 1
-        assert torch.allclose(results[0], inputs_long * 3)
+        assert torch.allclose(results[0], long * 3)
 
     # check that handler did not crash after failed request
-    await client_stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)]))
+    await stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(small_input)]))
+
 
-    # info
-    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_info(stub):
+    response = await stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
     assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert1")
 
-    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert2"))
+    response = await stub.rpc_info(runtime_pb2.ExpertUID(uid="expert2"))
     assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert2")
 
     with pytest.raises(P2PHandlerError):
-        await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert999"))
-
-    handler.terminate()
-    handler.join()
-
-
-class DummyPool(TaskPool):
-    def __init__(self, k: float):
-        self.k = k
-
-    async def submit_task(self, *inputs: torch.Tensor):
-        await asyncio.sleep(0.01)
-        print(type(inputs), inputs[0].shape)
-        assert inputs[0].shape[-1] == 2
-        return [inputs[0] * self.k]
-
-
-class DummyExpertBackend(ExpertBackend):
-    def __init__(self, name: str, k: float):
-        self.name = name
-        self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
-        self.grad_inputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
-        self.forward_pool = DummyPool(k)
-        self.backward_pool = DummyPool(k)
-
-    def get_info(self) -> Dict[str, Any]:
-        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
-        return dict(name=self.name)
+        await stub.rpc_info(runtime_pb2.ExpertUID(uid="expert999"))

+ 3 - 3
tests/test_moe.py

@@ -272,10 +272,10 @@ def test_client_anomaly_detection():
     server.start()
     try:
         server.ready.wait()
-        dht_experts = DHT(initial_peers=dht.get_visible_maddrs(), start=True)
+        client_side_dht = DHT(initial_peers=dht.get_visible_maddrs(), start=True)
 
         dmoe = RemoteMixtureOfExperts(
-            in_features=16, grid_size=(3,), dht=dht_experts, k_best=3, uid_prefix="expert.", detect_anomalies=True
+            in_features=16, grid_size=(3,), dht=client_side_dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
         )
 
         input = torch.randn(1, 16)
@@ -292,7 +292,7 @@ def test_client_anomaly_detection():
             inf_loss.backward()
 
         dmoe = RemoteMixtureOfExperts(
-            in_features=16, grid_size=(4,), dht=dht_experts, k_best=4, uid_prefix="expert.", detect_anomalies=True
+            in_features=16, grid_size=(4,), dht=client_side_dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
         )
         output = dmoe(input)
         assert output.isfinite().all()

+ 1 - 1
tests/test_p2p_daemon_bindings.py

@@ -572,7 +572,7 @@ async def test_client_stream_handler_success(p2pcs):
 
     # add in balanced mode: handler should be placed in round robin queue
     # and become the next to be called
-    await p2pcs[1].stream_handler(another_proto, handler_third, True)
+    await p2pcs[1].stream_handler(another_proto, handler_third, balanced=True)
     assert another_proto in p2pcs[1].control.handlers
     # ensure the handler is overridden
     assert handler_third == p2pcs[1].control.handlers[another_proto]