|
@@ -78,7 +78,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
max_length = metadata.get("max_length")
|
|
|
- points = metadata.get("points", 0.0)
|
|
|
+ points = metadata.get("points", 0)
|
|
|
|
|
|
if not requested_uids:
|
|
|
raise ValueError("User must specify at least one block for inference, but got none")
|
|
@@ -171,7 +171,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
requested_uids = self._check_uids(request.uid)
|
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
- points = metadata.get("points", 0.0)
|
|
|
+ points = metadata.get("points", 0)
|
|
|
assert isinstance(
|
|
|
points, (float, int)
|
|
|
), f"rpc_forward should have number of points as number or None, got {points}"
|
|
@@ -196,7 +196,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
|
|
|
requested_uids = self._check_uids(uid_str)
|
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
- points = metadata.get("points", 0.0)
|
|
|
+ points = metadata.get("points", 0)
|
|
|
assert isinstance(
|
|
|
points, (float, int)
|
|
|
), f"rpc_forward_stream should have number of points as number or None, got {points}"
|
|
@@ -225,7 +225,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
requested_uids = self._check_uids(request.uid)
|
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
- points = metadata.get("points", 0.0)
|
|
|
+ points = metadata.get("points", 0)
|
|
|
assert isinstance(
|
|
|
points, (float, int)
|
|
|
), f"rpc_backward should have number of points as number or None, got {points}"
|
|
@@ -257,7 +257,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
|
|
|
requested_uids = self._check_uids(uids_header)
|
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
- points = metadata.get("__points", 0.0)
|
|
|
+ points = metadata.get("points", 0)
|
|
|
assert isinstance(
|
|
|
points, (float, int)
|
|
|
), f"rpc_backward_stream should have number of points as number or None, got {points}"
|
|
@@ -321,7 +321,7 @@ async def _rpc_forward(
|
|
|
*flat_tensors: torch.Tensor,
|
|
|
requested_backends: Sequence[TransformerBackend],
|
|
|
prioritizer: TaskPrioritizerBase,
|
|
|
- points: float = 0.0,
|
|
|
+ points: int = 0,
|
|
|
) -> torch.Tensor:
|
|
|
"""
|
|
|
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
|
|
@@ -367,7 +367,7 @@ async def _rpc_backward(
|
|
|
*flat_tensors: torch.Tensor,
|
|
|
requested_backends: Sequence[TransformerBackend],
|
|
|
prioritizer: TaskPrioritizerBase,
|
|
|
- points: float = 0.0,
|
|
|
+ points: int = 0,
|
|
|
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
|
|
|
inputs, grad_outputs, prompts = flat_tensors
|
|
|
# Cast inputs & grad outputs to backend dtype
|