|
@@ -1,5 +1,3 @@
|
|
|
-from __future__ import annotations
|
|
|
-
|
|
|
from typing import AsyncIterator, Optional
|
|
|
|
|
|
import pytest
|
|
@@ -28,8 +26,7 @@ class HandlerStubTest:
|
|
|
return input
|
|
|
|
|
|
async def rpc_stream(self, input: AsyncIterator[ExpertRequest], timeout: Optional[float] = None):
|
|
|
- async for i in input:
|
|
|
- yield i
|
|
|
+ return input
|
|
|
|
|
|
async def rpc_info(self, input: str, timeout: Optional[float] = None):
|
|
|
return input
|
|
@@ -48,6 +45,7 @@ async def test_single():
|
|
|
input = torch.randn(1, 2)
|
|
|
request = ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(input)])
|
|
|
|
|
|
+ print(stub)
|
|
|
out: ExpertRequest = await stub.rpc_single(request)
|
|
|
|
|
|
assert out.metadata != b""
|
|
@@ -87,3 +85,12 @@ async def test_stream():
|
|
|
results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)))
|
|
|
assert len(results) == 1
|
|
|
assert torch.allclose(results[0], input)
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_no_wrapper():
|
|
|
+ remote = RemoteBlockTest(DustBankTest(), None, None)
|
|
|
+ stub = remote.stub
|
|
|
+
|
|
|
+ test = await stub.rpc_info("Test")
|
|
|
+ assert test == "Test"
|