Переглянути джерело

serialize points in inference session

justheuristic 2 роки тому
батько
коміт
a86145bbc2
2 змінених файлів з 6 додано та 5 видалено
  1. 2 1
      src/client/inference_session.py
  2. 4 4
      src/server/handler.py

+ 2 - 1
src/client/inference_session.py

@@ -43,6 +43,7 @@ class RemoteTransformerBlockInferenceSession:
         outputs_aiter: AsyncIterator,
         *,
         max_length: int,
+        points: int,
     ):
         self.uid, self.rpc_info = uid, rpc_info
         self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
@@ -50,7 +51,7 @@ class RemoteTransformerBlockInferenceSession:
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
-        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length))
+        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
         self.stepped = False
         self.closed = False
 

+ 4 - 4
src/server/handler.py

@@ -77,7 +77,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             max_length = metadata.get("max_length")
-            points = metadata.get("__points", 0.0)
+            points = metadata.get("points", 0.0)
 
             if not requested_uids:
                 raise ValueError("User must specify at least one block for inference, but got none")
@@ -170,7 +170,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        points = metadata.get("__points", 0.0)
+        points = metadata.get("points", 0.0)
         assert isinstance(
             points, (float, int)
         ), f"rpc_forward should have number of points as number or None, got {points}"
@@ -195,7 +195,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-        points = metadata.get("__points", 0.0)
+        points = metadata.get("points", 0.0)
         assert isinstance(
             points, (float, int)
         ), f"rpc_forward_stream should have number of points as number or None, got {points}"
@@ -224,7 +224,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        points = metadata.get("__points", 0.0)
+        points = metadata.get("points", 0.0)
         assert isinstance(
             points, (float, int)
         ), f"rpc_backward should have number of points as number or None, got {points}"