Просмотр исходного кода

Replace print() with logger.debug()

Aleksandr Borzunov 2 года назад
Родитель
Commit
c8697dc8ed
2 изменённых файла: 7 добавлений и 3 удаления
  1. 1 1
      src/server/backend.py
  2. 6 2
      src/server/handler.py

+ 1 - 1
src/server/backend.py

@@ -61,7 +61,7 @@ class TransformerBackend(ModuleBackend):
                 if not is_dummy(hypo_ids):
                     cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
                 layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
-                print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
+                logger.debug(f"Metadata: {cache_metadata}, past_k.shape={past_k.shape}, past_v.shape={past_v.shape}")
                 hidden_states, (new_k, new_v) = self.module.forward(
                     hidden_states, layer_past=layer_past, use_cache=True
                 )

+ 6 - 2
src/server/handler.py

@@ -16,6 +16,7 @@ from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
+from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
@@ -25,6 +26,9 @@ from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBas
 from src.utils.misc import DUMMY, is_dummy
 
 
+logger = get_logger(__file__)
+
+
 class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
 
@@ -73,7 +77,7 @@ class TransformerConnectionHandler(ConnectionHandler):
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
         try:
-            print("OPENED RPC_INFERENCE")
+            logger.debug("Opened rpc_inference()")
             request = await anext(requests)
             requested_uids = self._check_uids(request.uid)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
@@ -164,7 +168,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                     prefix_length += hidden_states.shape[1]
                     request = await (anext(requests))
         finally:
-            print("CLOSED RPC_INFERENCE")
+            logger.debug("Closed rpc_inference()")
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends