|
@@ -137,16 +137,11 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
assert isinstance(
|
|
|
backend.inference_pool, PrioritizedTaskPool
|
|
|
), "petals support only prioritized pools"
|
|
|
- priority = self._prioritizer(
|
|
|
+ priority = self._prioritizer.prioritize(
|
|
|
cache_metadata, hidden_states, hypo_ids, points=point_per_piece / len(requested_backends)
|
|
|
)
|
|
|
(hidden_states,) = await backend.inference_pool.submit_task(
|
|
|
- cache_metadata,
|
|
|
- hidden_states,
|
|
|
- hypo_ids,
|
|
|
- priority=priority,
|
|
|
- backend=backend,
|
|
|
- type="inference",
|
|
|
+ cache_metadata, hidden_states, hypo_ids, priority=priority
|
|
|
)
|
|
|
|
|
|
# serialize and send last layer outputs
|