@@ -16,6 +16,7 @@ from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
+from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
@@ -25,6 +26,9 @@ from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBas
 from src.utils.misc import DUMMY, is_dummy
 
 
+logger = get_logger(__file__)
+
+
 class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
 
@@ -73,7 +77,7 @@ class TransformerConnectionHandler(ConnectionHandler):
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
         try:
-            print("OPENED RPC_INFERENCE")
+            logger.debug("Opened rpc_inference()")
             request = await anext(requests)
             requested_uids = self._check_uids(request.uid)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
@@ -164,7 +168,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 prefix_length += hidden_states.shape[1]
                 request = await (anext(requests))
         finally:
-            print("CLOSED RPC_INFERENCE")
+            logger.debug("Closed rpc_inference()")
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends
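
For context, a minimal sketch of why logger.debug() replaces print() here. It assumes get_logger(__file__) behaves like the standard logging.getLogger (with hivemind's formatting added) and that the effective default level is INFO; the setLevel call below is illustrative and not part of the patch.

    import logging
    from hivemind.utils.logging import get_logger

    logger = get_logger(__file__)

    # DEBUG records are filtered out at the assumed default INFO level,
    # so production servers stay quiet where print() was unconditional:
    logger.debug("Opened rpc_inference()")  # suppressed by default

    # Operators can opt in to the verbose trace when debugging a server:
    logger.setLevel(logging.DEBUG)
    logger.debug("Opened rpc_inference()")  # emitted, given a DEBUG-level handler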