瀏覽代碼

rollback: only generic kwarg

Your Name 1 年之前
父節點
當前提交
6c7f762379
共有 2 個文件被更改,包括 15 次插入、12 次删除
  1. 3 12
      src/petals/server/backend.py
  2. 12 0
      src/petals/server/server.py

+ 3 - 12
src/petals/server/backend.py

@@ -53,22 +53,13 @@ class TransformerBackend(ModuleBackend):
         max_batch_size = self.forward_pool.max_batch_size
         device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
-            lambda args, kwargs: self.inference_step(*args, **kwargs),
-            max_batch_size=max_batch_size,
-            device=device,
-            name=f"{self.name}_inference",
+            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
         )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
-            lambda args, kwargs: self.forward(*args, **kwargs),
-            max_batch_size=max_batch_size,
-            device=device,
-            name=f"{self.name}_forward",
+            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
         )
         self.backward_pool = PrioritizedTaskPool(
-            lambda args, kwargs: self.backward(*args, **kwargs),
-            max_batch_size=max_batch_size,
-            device=device,
-            name=f"{self.name}_backward",
+            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
         )
 
         self.dtype = backend_dtype

+ 12 - 0
src/petals/server/server.py

@@ -770,3 +770,15 @@ class RuntimeWithDeduplicatedPools(Runtime):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.pools = tuple(set(self.pools))
+
+    def process_batch(
+        self, pool: TaskPoolBase, batch_index: int, args: Sequence[Any], kwargs: Dict[str, Any]
+    ) -> Tuple[Any, int]:
+        """process one batch of tasks from a given pool, return a batch of results and total batch size"""
+        outputs = pool.process_func(*args, **kwargs)
+        batch_size = 1
+        for arg in args:
+            if isinstance(arg, torch.Tensor) and arg.ndim > 2:
+                batch_size = arg.shape[0] * arg.shape[1]
+                break
+        return outputs, batch_size