|
@@ -13,7 +13,7 @@ from hivemind.p2p import P2P, StubBase
|
|
|
from hivemind.proto import runtime_pb2
|
|
|
from hivemind.utils import anext, get_logger, nested_flatten, use_hivemind_log_handler
|
|
|
|
|
|
-from src.data_structures import RemoteModuleInfo
|
|
|
+from src.data_structures import RemoteModuleInfo, RPCInfo
|
|
|
from src.dht_utils import ModuleUID
|
|
|
from src.server.handler import TransformerConnectionHandler
|
|
|
|
|
@@ -50,8 +50,8 @@ class RemoteTransformerBlock(RemoteExpert):
|
|
|
class RemoteTransformerBlockInferenceSession:
|
|
|
"""An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
|
|
|
|
|
|
- def __init__(self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
|
|
|
- self.uid, self.info = uid, info
|
|
|
+ def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
|
|
|
+ self.uid, self.rpc_info = uid, rpc_info
|
|
|
# warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
|
|
|
# using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
|
|
|
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
|
|
@@ -61,14 +61,14 @@ class RemoteTransformerBlockInferenceSession:
|
|
|
|
|
|
@classmethod
|
|
|
async def _create(
|
|
|
- cls, remote_module: RemoteTransformerBlock, timeout: Optional[float] = None
|
|
|
+ cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None
|
|
|
) -> RemoteTransformerBlockInferenceSession:
|
|
|
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
|
|
|
inputs_queue = asyncio.Queue()
|
|
|
- outputs_stream = await remote_module.stub.rpc_inference(
|
|
|
+ outputs_stream = await stub.rpc_inference(
|
|
|
cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout
|
|
|
)
|
|
|
- return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
|
|
|
+ return cls(uid, rpc_info, inputs_queue, outputs_stream)
|
|
|
|
|
|
@staticmethod
|
|
|
async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
|
|
@@ -90,7 +90,7 @@ class RemoteTransformerBlockInferenceSession:
|
|
|
uid=self.uid,
|
|
|
tensors=[
|
|
|
serialize_torch_tensor(tensor, proto.compression)
|
|
|
- for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
|
|
|
+ for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
|
|
|
],
|
|
|
)
|
|
|
)
|