
Improve server's logging (#96)

Log all RPC calls with block indices and shortened peer IDs, print attention cache stats.
Alexander Borzunov 2 years ago
parent
commit
d8ef09146e

+ 1 - 1
requirements.txt

@@ -4,5 +4,5 @@ accelerate==0.10.0
 huggingface-hub==0.7.0
 transformers==4.21.3
 protobuf>=3.20.3,<4.0dev
-git+https://github.com/learning-at-home/hivemind@1e4af434f35ad43208e7e5df569c5ff5eb79681b
+git+https://github.com/learning-at-home/hivemind@be88b4280cdd87432168e1da238e532f1364078b
 humanfriendly

+ 37 - 9
src/server/handler.py

@@ -75,10 +75,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         context: P2PContext,
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
+
+        request = await anext(requests)
+        requested_uids = self._check_uids(request.uid)
+        self._log_request("rpc_inference.open", requested_uids, context)
         try:
-            logger.debug("Opened rpc_inference()")
-            request = await anext(requests)
-            requested_uids = self._check_uids(request.uid)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             max_length = metadata.get("max_length")
@@ -167,12 +168,14 @@ class TransformerConnectionHandler(ConnectionHandler):
                     prefix_length += hidden_states.shape[1]
                     request = await (anext(requests))
         finally:
-            logger.debug("Closed rpc_inference()")
+            self._log_request("rpc_inference.close", requested_uids, context)

     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends
         flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_uids(request.uid)
+        self._log_request("rpc_forward", requested_uids, context)
+
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
         points = metadata.get("points", 0)
@@ -199,6 +202,8 @@ class TransformerConnectionHandler(ConnectionHandler):
         # Parse requests and prepare backends
         uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uid_str)
+        self._log_request("rpc_forward_stream", requested_uids, context)
+
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         points = metadata.get("points", 0)
         assert isinstance(
@@ -227,6 +232,8 @@ class TransformerConnectionHandler(ConnectionHandler):
         # Parse requests and prepare backends
         flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_uids(request.uid)
+        self._log_request("rpc_backward", requested_uids, context)
+
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
         points = metadata.get("points", 0)
@@ -257,9 +264,10 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
-
         uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uids_header)
+        self._log_request("rpc_backward_stream", requested_uids, context)
+
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         points = metadata.get("points", 0)
         assert isinstance(
@@ -307,19 +315,39 @@ class TransformerConnectionHandler(ConnectionHandler):
         """Allocate memory caches for each transformer block, return cache handles"""
         """Allocate memory caches for each transformer block, return cache handles"""
         async with contextlib.AsyncExitStack() as stack:
         async with contextlib.AsyncExitStack() as stack:
             handles = []
             handles = []
+            total_size = 0
+            backend = None
             for backend in backends:
                 num_heads = backend.module.self_attention.num_heads
                 head_dim = backend.module.self_attention.head_dim

-                cache_descriptor = TensorDescriptor(
-                    size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
-                )
+                descr = TensorDescriptor(size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype)
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]

-                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
+                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(descr)))
+                total_size += descr.numel() * torch.finfo(descr.dtype).bits // 8
+
+            gib = 1024**3
+            if backend is not None:
+                cur_size = backend.memory_cache.current_size_bytes
+                max_size = backend.memory_cache.max_size_bytes
+                friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
+                cache_stats = f"used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
+            else:
+                cache_stats = f"cache stats n/a"
+            logger.info(f"rpc_inference.alloc(total_size={total_size / gib:.2f} GiB), {cache_stats}")

             yield handles

+    def _log_request(self, method: str, uids: List[ModuleUID], context: P2PContext) -> None:
+        friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
+        friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
+        friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
+
+        friendly_remote_id = "..." + str(context.remote_id)[-6:]
+
+        logger.info(f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})")
+

 async def _rpc_forward(
     *flat_tensors: torch.Tensor,
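
For scale, here is a rough worked example of the total_size arithmetic behind the new rpc_inference.alloc log line above. The per-block attention shape and dtype are assumptions chosen for illustration (roughly BLOOM-176B-sized), not values taken from this repository:

# Hypothetical numbers; only the formula (2 * batch * max_length * num_heads * head_dim
# elements, times bytes per element) mirrors the _allocate_caches code above.
import torch

batch_size, max_length = 1, 2048
num_heads, head_dim = 112, 128        # assumed attention shape of a single block
dtype = torch.bfloat16

numel = 2 * batch_size * max_length * num_heads * head_dim   # key and value caches
size_bytes = numel * torch.finfo(dtype).bits // 8
print(f"{size_bytes / 1024**3:.2f} GiB per block")           # ~0.11 GiB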

+ 0 - 5
src/server/server.py

@@ -379,11 +379,6 @@ class ModuleContainer(threading.Thread):
         Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
-        logger.info(f"Serving {len(self.module_backends)} blocks:")
-        for expert_name, backend in self.module_backends.items():
-            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
-            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
-
         if not self.dht.is_alive():
             self.dht.run_in_background(await_ready=True)