justheuristic — 2 years ago
parent
commit
aa6badc49c
2 files changed, 3 insertions and 8 deletions
  1. +2 −7
      src/server/handler.py
  2. +1 −1
      src/server/task_pool.py

+ 2 - 7
src/server/handler.py

@@ -137,16 +137,11 @@ class TransformerConnectionHandler(ConnectionHandler):
                         assert isinstance(
                             backend.inference_pool, PrioritizedTaskPool
                         ), "petals support only prioritized pools"
-                        priority = self._prioritizer(
+                        priority = self._prioritizer.prioritize(
                             cache_metadata, hidden_states, hypo_ids, points=point_per_piece / len(requested_backends)
                         )
                         (hidden_states,) = await backend.inference_pool.submit_task(
-                            cache_metadata,
-                            hidden_states,
-                            hypo_ids,
-                            priority=priority,
-                            backend=backend,
-                            type="inference",
+                            cache_metadata, hidden_states, hypo_ids, priority=priority
                         )
 
                     # serialize and send last layer outputs

+ 1 - 1
src/server/task_pool.py

@@ -96,7 +96,7 @@ class PrioritizedTaskPool(TaskPoolBase):
         self.terminate()
         self._prioritizer_thread.join(timeout)
 
-    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> Future:
+    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
         task = Task(priority, time.monotonic(), MPFuture(), args)
         if self.get_task_size(task) > self.max_batch_size: