|
@@ -11,13 +11,17 @@ from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
|
|
from hivemind.moe.expert_uid import ExpertInfo
|
|
from hivemind.moe.expert_uid import ExpertInfo
|
|
from hivemind.p2p import P2P, StubBase
|
|
from hivemind.p2p import P2P, StubBase
|
|
from hivemind.proto import runtime_pb2
|
|
from hivemind.proto import runtime_pb2
|
|
-from hivemind.utils import anext, nested_flatten
|
|
|
|
|
|
+from hivemind.utils import anext, nested_flatten, use_hivemind_log_handler, get_logger
|
|
|
|
|
|
from src.data_structures import RemoteModuleInfo
|
|
from src.data_structures import RemoteModuleInfo
|
|
from src.dht_utils import ModuleUID
|
|
from src.dht_utils import ModuleUID
|
|
from src.server.handler import TransformerConnectionHandler
|
|
from src.server.handler import TransformerConnectionHandler
|
|
|
|
|
|
|
|
|
|
|
|
+use_hivemind_log_handler("in_root_logger")
|
|
|
|
+logger = get_logger(__file__)
|
|
|
|
+
|
|
|
|
+
|
|
class RemoteTransformerBlock(RemoteExpert):
|
|
class RemoteTransformerBlock(RemoteExpert):
|
|
"""A class that interacts with a remote module on a specific server for forward/backward or inference"""
|
|
"""A class that interacts with a remote module on a specific server for forward/backward or inference"""
|
|
|
|
|
|
@@ -29,11 +33,20 @@ class RemoteTransformerBlock(RemoteExpert):
|
|
def stub(self) -> StubBase:
|
|
def stub(self) -> StubBase:
|
|
return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
|
|
return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
|
|
|
|
|
|
- def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
|
|
|
|
|
|
+ def forward(self, inputs: torch.Tensor, **kwargs):
|
|
|
|
+ for k, v in kwargs.items():
|
|
|
|
+ assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
|
|
|
|
+ return super().forward(inputs)
|
|
|
|
+
|
|
|
|
+ def inference_session(self) -> RemoteTransformerBlockInferenceSession:
|
|
"""Initialize a new inference session with the specified remote server"""
|
|
"""Initialize a new inference session with the specified remote server"""
|
|
_ = self.info # create _info manually since the built-in property will not work inside RemoteExpertWorker
|
|
_ = self.info # create _info manually since the built-in property will not work inside RemoteExpertWorker
|
|
return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
|
|
return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
|
|
|
|
|
|
|
|
+ def begin_inference_session(self):
|
|
|
|
+ logger.warning("beging_inference_session was renamed to just inference_session")
|
|
|
|
+ return self.inference_session()
|
|
|
|
+
|
|
|
|
|
|
class RemoteTransformerBlockInferenceSession:
|
|
class RemoteTransformerBlockInferenceSession:
|
|
"""An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
|
|
"""An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
|
|
@@ -44,6 +57,7 @@ class RemoteTransformerBlockInferenceSession:
|
|
# using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
|
|
# using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
|
|
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
|
|
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
|
|
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
|
|
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
|
|
|
|
+ self.stepped = False
|
|
self.closed = False
|
|
self.closed = False
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
@@ -89,6 +103,7 @@ class RemoteTransformerBlockInferenceSession:
|
|
async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
|
|
async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
|
|
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
|
|
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
|
|
await self._inputs_queue.put(inputs_serialized)
|
|
await self._inputs_queue.put(inputs_serialized)
|
|
|
|
+ self.stepped = True
|
|
return await anext(self._outputs_stream)
|
|
return await anext(self._outputs_stream)
|
|
|
|
|
|
def close(self):
|
|
def close(self):
|
|
@@ -103,11 +118,12 @@ class RemoteTransformerBlockInferenceSession:
|
|
"""Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
|
|
"""Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
|
|
if self._outputs_stream is None:
|
|
if self._outputs_stream is None:
|
|
return # already closed
|
|
return # already closed
|
|
- await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
|
|
|
|
- try:
|
|
|
|
- await anext(self._outputs_stream)
|
|
|
|
- except StopAsyncIteration:
|
|
|
|
- pass
|
|
|
|
|
|
+ if self.stepped:
|
|
|
|
+ await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
|
|
|
|
+ try:
|
|
|
|
+ await anext(self._outputs_stream)
|
|
|
|
+ except StopAsyncIteration:
|
|
|
|
+ pass
|
|
|
|
|
|
def __del__(self):
|
|
def __del__(self):
|
|
self.close()
|
|
self.close()
|