Parcourir la source

Improve server's logging (#96)

Log all RPC calls with block indices and shortened peer IDs, print attention cache stats.
Alexander Borzunov il y a 2 ans
Parent
commit
d8ef09146e
3 fichiers modifiés avec 38 ajouts et 15 suppressions
  1. 1 1
      requirements.txt
  2. 37 9
      src/server/handler.py
  3. 0 5
      src/server/server.py

+ 1 - 1
requirements.txt

@@ -4,5 +4,5 @@ accelerate==0.10.0
 huggingface-hub==0.7.0
 transformers==4.21.3
 protobuf>=3.20.3,<4.0dev
-git+https://github.com/learning-at-home/hivemind@1e4af434f35ad43208e7e5df569c5ff5eb79681b
+git+https://github.com/learning-at-home/hivemind@be88b4280cdd87432168e1da238e532f1364078b
 humanfriendly

+ 37 - 9
src/server/handler.py

@@ -75,10 +75,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         context: P2PContext,
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
+
+        request = await anext(requests)
+        requested_uids = self._check_uids(request.uid)
+        self._log_request("rpc_inference.open", requested_uids, context)
         try:
-            logger.debug("Opened rpc_inference()")
-            request = await anext(requests)
-            requested_uids = self._check_uids(request.uid)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             max_length = metadata.get("max_length")
@@ -167,12 +168,14 @@ class TransformerConnectionHandler(ConnectionHandler):
                     prefix_length += hidden_states.shape[1]
                     request = await (anext(requests))
         finally:
-            logger.debug("Closed rpc_inference()")
+            self._log_request("rpc_inference.close", requested_uids, context)
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends
         flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_uids(request.uid)
+        self._log_request("rpc_forward", requested_uids, context)
+
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
         points = metadata.get("points", 0)
@@ -199,6 +202,8 @@ class TransformerConnectionHandler(ConnectionHandler):
         # Parse requests and prepare backends
         uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uid_str)
+        self._log_request("rpc_forward_stream", requested_uids, context)
+
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         points = metadata.get("points", 0)
         assert isinstance(
@@ -227,6 +232,8 @@ class TransformerConnectionHandler(ConnectionHandler):
         # Parse requests and prepare backends
         flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_uids(request.uid)
+        self._log_request("rpc_backward", requested_uids, context)
+
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
         points = metadata.get("points", 0)
@@ -257,9 +264,10 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
-
         uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uids_header)
+        self._log_request("rpc_backward_stream", requested_uids, context)
+
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         points = metadata.get("points", 0)
         assert isinstance(
@@ -307,19 +315,39 @@ class TransformerConnectionHandler(ConnectionHandler):
         """Allocate memory caches for each transformer block, return cache handles"""
         async with contextlib.AsyncExitStack() as stack:
             handles = []
+            total_size = 0
+            backend = None
             for backend in backends:
                 num_heads = backend.module.self_attention.num_heads
                 head_dim = backend.module.self_attention.head_dim
 
-                cache_descriptor = TensorDescriptor(
-                    size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
-                )
+                descr = TensorDescriptor(size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype)
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
 
-                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
+                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(descr)))
+                total_size += descr.numel() * torch.finfo(descr.dtype).bits // 8
+
+            gib = 1024**3
+            if backend is not None:
+                cur_size = backend.memory_cache.current_size_bytes
+                max_size = backend.memory_cache.max_size_bytes
+                friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
+                cache_stats = f"used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
+            else:
+                cache_stats = f"cache stats n/a"
+            logger.info(f"rpc_inference.alloc(total_size={total_size / gib:.2f} GiB), {cache_stats}")
 
             yield handles
 
+    def _log_request(self, method: str, uids: List[ModuleUID], context: P2PContext) -> None:
+        friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
+        friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
+        friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
+
+        friendly_remote_id = "..." + str(context.remote_id)[-6:]
+
+        logger.info(f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})")
+
 
 async def _rpc_forward(
     *flat_tensors: torch.Tensor,

+ 0 - 5
src/server/server.py

@@ -379,11 +379,6 @@ class ModuleContainer(threading.Thread):
         Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
-        logger.info(f"Serving {len(self.module_backends)} blocks:")
-        for expert_name, backend in self.module_backends.items():
-            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
-            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
-
         if not self.dht.is_alive():
             self.dht.run_in_background(await_ready=True)