瀏覽代碼

rework dusty client side, add client side tests

Pavel Samygin 3 年之前
父節點
當前提交
6c5be80cd8
共有 5 個文件被更改,包括 162 次插入65 次删除
  1. 2 0
      src/client/__init__.py
  2. 0 65
      src/client/dust_bank.py
  3. 71 0
      src/client/dusty_block.py
  4. 0 0
      tests/__init__.py
  5. 89 0
      tests/test_dust_payment.py

+ 2 - 0
src/client/__init__.py

@@ -1,3 +1,5 @@
+from src.client.dust_bank import DummyDustBank, DustBankBase
+from src.client.dusty_block import DustyRemoteBlock
 from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
 from src.client.remote_block import RemoteTransformerBlock
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel

+ 0 - 65
src/client/dust_bank.py

@@ -1,7 +1,5 @@
-import inspect
 from abc import ABC, abstractmethod
 from functools import wraps
-from typing import AsyncIterator, Callable, Optional
 
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
@@ -18,66 +16,3 @@ class DustBankBase(ABC):
 class DummyDustBank(DustBankBase):
     def get_dust(self, request: ExpertRequest, method_name: str) -> float:
         return 0.0
-
-
-def _unary_request_wrapper(rpc_call: Callable, rpc_name: str, bank: DustBankBase):
-    @wraps(rpc_call)
-    async def rpc(stub: StubBase, input: runtime_pb2.ExpertRequest, timeout: Optional[float] = None):
-        meta = MSGPackSerializer.loads(input.metadata) if input.metadata else {}
-        meta.update("__dust", bank.get_dust(input, rpc_name))
-        input.metadata = MSGPackSerializer.dumps(meta)
-        return await rpc_call(stub, input, timeout)
-
-    return rpc
-
-
-def _stream_request_wrapper(rpc_call: Callable, rpc_name: str, bank: DustBankBase):
-    @wraps(rpc_call)
-    async def rpc(stub: StubBase, input: AsyncIterator[runtime_pb2.ExpertRequest], timeout: Optional[float] = None):
-        is_meta_set = False
-
-        def _metadata_setter(chunk: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertRequest:
-            nonlocal is_meta_set
-            if not is_meta_set:
-                meta = MSGPackSerializer.loads(chunk.metadata) if chunk.metadata else {}
-                meta.update("__dust", bank.get_dust(chunk, rpc_name))
-                chunk.metadata = MSGPackSerializer.dumps(meta)
-                is_meta_set = True
-            return chunk
-
-        return await rpc_call(stub, amap_in_executor(_metadata_setter, input), timeout)
-
-    return rpc
-
-
-def _dustify_handler_stub(stub: StubBase, bank: DustBankBase) -> StubBase:
-    for name, method in inspect.getmembers(stub, predicate=inspect.ismethod):
-        if name.startswith("rpc"):
-            spec = inspect.getfullargspec(method)
-            # rpc callers has 3 arguments: stub, input and timeout
-            if len(spec.args) != 3:
-                continue
-
-            input_type = spec.annotations[spec.args[1]]
-
-            if input_type is AsyncIterator[runtime_pb2.ExpertRequest]:
-                setattr(stub, name, _stream_request_wrapper(method, name, bank))
-            elif input_type is runtime_pb2.ExpertRequest:
-                setattr(stub, name, _unary_request_wrapper(method, name, bank))
-    return stub
-
-
-def payment_wrapper(bank: DustBankBase) -> Callable:
-    def class_wrapper(cls):
-        d = cls.__dict__
-        if "stub" not in d or not isinstance(d["stub"], property):
-            raise TypeError('wrapped module class supposed to have property "stub"')
-        old_stub = d["stub"]
-
-        def _stub(self):
-            stub = old_stub.__get__(self)
-            return _dustify_handler_stub(stub, bank)
-
-        return type(cls.__name__, cls.__bases__, {k: v if k != "stub" else property(_stub) for k, v in d.items()})
-
-    return class_wrapper

+ 71 - 0
src/client/dusty_block.py

@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+import inspect
+from functools import wraps
+from typing import AsyncIterator, Callable, Optional
+
+from hivemind.moe.client import RemoteExpert
+from hivemind.moe.expert_uid import ExpertInfo
+from hivemind.p2p import P2P, StubBase
+from hivemind.proto import runtime_pb2
+from hivemind.utils import MSGPackSerializer, amap_in_executor
+
+from src.client.dust_bank import DustBankBase
+
+
+class DustyRemoteBlock(RemoteExpert):
+    def __init__(self, bank: DustBankBase, expert_info: ExpertInfo, p2p: P2P):
+        self._bank = bank
+        super().__init__(expert_info, p2p)
+
+    def _unary_request_wrapper(self, rpc_call: Callable, rpc_name: str):
+        @wraps(rpc_call)
+        async def rpc(input: runtime_pb2.ExpertRequest, timeout: Optional[float] = None):
+            meta = MSGPackSerializer.loads(input.metadata) if input.metadata else {}
+            meta["__dust"] = self._bank.get_dust(input, rpc_name)
+            input.metadata = MSGPackSerializer.dumps(meta)
+            return await rpc_call(input, timeout)
+
+        return rpc
+
+    def _stream_request_wrapper(self, rpc_call: Callable, rpc_name: str):
+        @wraps(rpc_call)
+        async def rpc(input: AsyncIterator[runtime_pb2.ExpertRequest], timeout: Optional[float] = None):
+            is_meta_set = False
+
+            def _metadata_setter(chunk: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertRequest:
+                nonlocal is_meta_set
+                if not is_meta_set:
+                    meta = MSGPackSerializer.loads(chunk.metadata) if chunk.metadata else {}
+                    meta["__dust"] = self._bank.get_dust(chunk, rpc_name)
+                    chunk.metadata = MSGPackSerializer.dumps(meta)
+                    is_meta_set = True
+                return chunk
+
+            return rpc_call(amap_in_executor(_metadata_setter, input), timeout)
+
+        return rpc
+
+    def _dustify_handler_stub(self, stub: StubBase) -> StubBase:
+        for name, method in inspect.getmembers(stub, predicate=inspect.ismethod):
+            if name.startswith("rpc"):
+                spec = inspect.getfullargspec(method)
+                # rpc callers has 3 arguments: stub, input and timeout
+                if len(spec.args) != 3:
+                    continue
+
+                input_type = spec.annotations[spec.args[1]]
+
+                if input_type is AsyncIterator[runtime_pb2.ExpertRequest]:
+                    setattr(stub, name, self._stream_request_wrapper(method, name))
+                elif input_type is runtime_pb2.ExpertRequest:
+                    setattr(stub, name, self._unary_request_wrapper(method, name))
+        return stub
+
+    @property
+    def _stub(self) -> StubBase:
+        return super().stub
+
+    @property
+    def stub(self) -> StubBase:
+        return self._dustify_handler_stub(self._stub)

+ 0 - 0
tests/__init__.py


+ 89 - 0
tests/test_dust_payment.py

@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+from typing import AsyncIterator, Optional
+
+import pytest
+import torch
+from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.proto.runtime_pb2 import ExpertRequest
+from hivemind.utils import MSGPackSerializer, amap_in_executor, iter_as_aiter, split_for_streaming
+
+from src.client.dust_bank import DustBankBase
+from src.client.dusty_block import DustyRemoteBlock
+
+
+class DustBankTest(DustBankBase):
+    def __init__(self):
+        self._p = {
+            "rpc_single": 1,
+            "rpc_stream": 2,
+        }
+
+    def get_dust(self, request: ExpertRequest, method_name: str) -> float:
+        return self._p.get(method_name, -1)
+
+
+class HandlerStubTest:
+    async def rpc_single(self, input: ExpertRequest, timeout: Optional[float] = None):
+        return input
+
+    async def rpc_stream(self, input: AsyncIterator[ExpertRequest], timeout: Optional[float] = None):
+        async for i in input:
+            yield i
+
+    async def rpc_info(self, input: str, timeout: Optional[float] = None):
+        return input
+
+
+class RemoteBlockTest(DustyRemoteBlock):
+    @property
+    def _stub(self):
+        return HandlerStubTest()
+
+
+@pytest.mark.asyncio
+async def test_single():
+    remote = RemoteBlockTest(DustBankTest(), None, None)
+    stub = remote.stub
+    input = torch.randn(1, 2)
+    request = ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(input)])
+
+    out: ExpertRequest = await stub.rpc_single(request)
+
+    assert out.metadata != b""
+    assert len(out.tensors) == 1
+    assert torch.allclose(input, deserialize_torch_tensor(out.tensors[0]))
+
+    meta = MSGPackSerializer.loads(out.metadata)
+    assert isinstance(meta, dict)
+    assert "__dust" in meta
+    assert meta["__dust"] == 1
+
+
+@pytest.mark.asyncio
+async def test_stream():
+    remote = RemoteBlockTest(DustBankTest(), None, None)
+    stub = remote.stub
+    input = torch.randn(2**21, 2)
+
+    split = (p for t in [serialize_torch_tensor(input)] for p in split_for_streaming(t, chunk_size_bytes=2**16))
+    output_generator = await stub.rpc_stream(
+        amap_in_executor(
+            lambda tensor_part: ExpertRequest(uid="expert2", tensors=[tensor_part]),
+            iter_as_aiter(split),
+        ),
+    )
+    outputs_list = [part async for part in output_generator]
+    assert len(outputs_list) == 2**5 * 8
+    assert outputs_list[0].metadata != b""
+    for i in range(1, len(outputs_list)):
+        assert outputs_list[i].metadata == b""
+
+    meta = MSGPackSerializer.loads(outputs_list[0].metadata)
+    assert isinstance(meta, dict)
+    assert "__dust" in meta
+    assert meta["__dust"] == 2
+
+    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)))
+    assert len(results) == 1
+    assert torch.allclose(results[0], input)