Explorar el Código

make points optional

justheuristic · hace 3 años · commit padre: 5b5342901f
Se han modificado 2 ficheros con 8 adiciones y 8 borrados
  1. src/client/inference_session.py (+1 −1)
  2. src/server/handler.py (+7 −7)

+ 1 - 1
src/client/inference_session.py

@@ -43,7 +43,7 @@ class RemoteTransformerBlockInferenceSession:
         outputs_aiter: AsyncIterator,
         *,
         max_length: int,
-        points: int,
+        points: int = 0,
     ):
         self.uid, self.rpc_info = uid, rpc_info
         self.num_blocks = uid.count(CHAIN_DELIMITER) + 1

+ 7 - 7
src/server/handler.py

@@ -78,7 +78,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             max_length = metadata.get("max_length")
-            points = metadata.get("points", 0.0)
+            points = metadata.get("points", 0)
 
             if not requested_uids:
                 raise ValueError("User must specify at least one block for inference, but got none")
@@ -171,7 +171,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        points = metadata.get("points", 0.0)
+        points = metadata.get("points", 0)
         assert isinstance(
             points, (float, int)
         ), f"rpc_forward should have number of points as number or None, got {points}"
@@ -196,7 +196,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-        points = metadata.get("points", 0.0)
+        points = metadata.get("points", 0)
         assert isinstance(
             points, (float, int)
         ), f"rpc_forward_stream should have number of points as number or None, got {points}"
@@ -225,7 +225,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        points = metadata.get("points", 0.0)
+        points = metadata.get("points", 0)
         assert isinstance(
             points, (float, int)
         ), f"rpc_backward should have number of points as number or None, got {points}"
@@ -257,7 +257,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-        points = metadata.get("__points", 0.0)
+        points = metadata.get("points", 0)
         assert isinstance(
             points, (float, int)
         ), f"rpc_backward_stream should have number of points as number or None, got {points}"
@@ -321,7 +321,7 @@ async def _rpc_forward(
     *flat_tensors: torch.Tensor,
     requested_backends: Sequence[TransformerBackend],
     prioritizer: TaskPrioritizerBase,
-    points: float = 0.0,
+    points: int = 0,
 ) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
@@ -367,7 +367,7 @@ async def _rpc_backward(
     *flat_tensors: torch.Tensor,
     requested_backends: Sequence[TransformerBackend],
     prioritizer: TaskPrioritizerBase,
-    points: float = 0.0,
+    points: int = 0,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
     inputs, grad_outputs, prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype