commit ed8d7f41b8

+ 22 - 14
src/petals/server/backend.py

@@ -53,13 +53,22 @@ class TransformerBackend(ModuleBackend):
         max_batch_size = self.forward_pool.max_batch_size
         device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
-            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
+            lambda args, kwargs: self.inference_step(*args, **kwargs),
+            max_batch_size=max_batch_size,
+            device=device,
+            name=f"{self.name}_inference",
         )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
-            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
+            lambda args, kwargs: self.forward(*args, **kwargs),
+            max_batch_size=max_batch_size,
+            device=device,
+            name=f"{self.name}_forward",
         )
         self.backward_pool = PrioritizedTaskPool(
-            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
+            lambda args, kwargs: self.backward(*args, **kwargs),
+            max_batch_size=max_batch_size,
+            device=device,
+            name=f"{self.name}_backward",
         )
 
         self.dtype = backend_dtype
@@ -96,27 +105,25 @@ class TransformerBackend(ModuleBackend):
             cache_tensors.extend((keys, values))
         return cache_tensors
 
-    def forward(self, *args: torch.Tensor, active_adapter: Optional[str], **kwargs) -> Tuple[torch.Tensor, ...]:
+    def forward(self, active_adapter: Optional[str], *args: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, ...]:
         with self._peft_module.using_adapter(active_adapter), torch.no_grad():
             return self.module(*args, **kwargs)
 
     def backward(
-        self, grad_outputs: torch.Tensor, *args, active_adapter: Optional[str], **kwargs
+        self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs
     ) -> Tuple[torch.Tensor, ...]:
         assert any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor))
         with self._peft_module.using_adapter(active_adapter), torch.enable_grad():
             (outputs,) = self.module(*args, **kwargs)
             assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape
             torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False)
-        return nested_map(lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None)
+        return nested_map(
+            lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None, (args, kwargs)
+        )
 
     @torch.inference_mode()
     def inference_step(
-        self,
-        hidden_states: torch.Tensor,
-        hypo_ids: torch.LongTensor,
-        kwargs: Dict[str, torch.Tensor],
-        inference_info: InferenceMetadata,
+        self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, inference_info: InferenceMetadata, **kwargs
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
         seq_len = hidden_states.shape[1]
@@ -217,8 +224,9 @@ def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend])
     """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""
     assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
     first_pool = next(iter(backends.values())).inference_pool
+    merged_inference_func = _MergedInferenceStep(backends)
     merged_pool = PrioritizedTaskPool(
-        _MergedInferenceStep(backends),
+        lambda args, kwargs: merged_inference_func(*args, **kwargs),
         max_batch_size=first_pool.max_batch_size,
         device=first_pool.device,
         name=f"merged_inference",
@@ -237,9 +245,9 @@ class _MergedInferenceStep:
         self,
         hidden_states: torch.Tensor,
         hypo_ids: torch.LongTensor,
-        backend_kwargs: Sequence[Dict[str, torch.Tensor]],
         inference_infos: Sequence[InferenceMetadata],
         *optional_prompts: Optional[torch.Tensor],
+        backend_kwargs: Sequence[Dict[str, torch.Tensor]],
     ) -> Tuple[torch.Tensor, ...]:
         assert (
             len(inference_infos) == len(optional_prompts) == len(backend_kwargs)
@@ -248,6 +256,6 @@ class _MergedInferenceStep:
             if optional_prompt is not None:
                 hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
             (hidden_states,) = self.backends[inference_info.uid].inference_step(
-                hidden_states, hypo_ids, kwargs, inference_info
+                hidden_states, hypo_ids, inference_info, **kwargs
             )
         return (hidden_states,)
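
Reviewer note, not part of the diff: the lambdas above exist because PrioritizedTaskPool now hands each handler a single (args, kwargs) pair (see the task_pool.py changes below), so the bound methods are re-expanded into ordinary calls, and active_adapter moves into the first positional slot so that the remaining positionals are plain tensors. A minimal toy sketch of the convention, with made-up values:

```python
import torch

# Toy stand-in for TransformerBackend.forward with the new argument order:
# the adapter name comes first, followed by tensor positionals.
def forward(active_adapter, *args, **kwargs):
    return tuple(x * 2 for x in args)

# The pool delivers an (args, kwargs) pair, so the lambda re-expands it into a normal call.
handler = lambda args, kwargs: forward(*args, **kwargs)

args, kwargs = ("base_adapter", torch.ones(2, 3)), {}
(doubled,) = handler(args, kwargs)  # equivalent to forward("base_adapter", torch.ones(2, 3))
print(doubled.shape)  # torch.Size([2, 3])
```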

+ 15 - 6
src/petals/server/block_functions.py

@@ -66,8 +66,8 @@ async def run_rpc_forward(
             hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
         )
         (hidden_states,) = await backend.forward_pool.submit_task(
+            active_adapter,
             hidden_states,
-            active_adapter=active_adapter,
             **kwargs,
             priority=priority,
             size=num_tokens,
@@ -113,7 +113,7 @@ async def run_rpc_backward(
             hidden_states, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
         )
         (hidden_states,) = await backend.forward_pool.submit_task(
-            hidden_states, active_adapter, **kwargs, priority=priority, size=num_tokens
+            active_adapter, hidden_states, **kwargs, priority=priority, size=num_tokens
         )
 
         assert isinstance(hidden_states, torch.Tensor)
@@ -131,7 +131,7 @@ async def run_rpc_backward(
             inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
         (grad_outputs,) = await backend.backward_pool.submit_task(
-            inp, grad_outputs, active_adapter, **kwargs, priority=priority, size=num_tokens
+            active_adapter, grad_outputs, inp, **kwargs, priority=priority, size=num_tokens
         )
 
         assert isinstance(grad_outputs, torch.Tensor)
@@ -211,7 +211,7 @@ async def iterate_rpc_inference(
                     hypo_ids,
                     inference_infos,
                     *prompts,
-                    backend_kwargs,
+                    backend_kwargs=backend_kwargs,
                     priority=priority,
                     size=num_tokens,
                 )
@@ -221,7 +221,13 @@ async def iterate_rpc_inference(
                 ):
                     inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
                     (hidden_states,) = await backend.inference_pool.submit_task(
-                        hidden_states, hypo_ids, inference_infos, prompt, **kwargs, priority=priority, size=num_tokens
+                        hidden_states,
+                        hypo_ids,
+                        inference_infos,
+                        prompt,
+                        backend_kwargs=(kwargs,),
+                        priority=priority,
+                        size=num_tokens,
                     )
 
         # serialize and send last layer outputs
@@ -250,6 +256,9 @@ def _check_inputs(
             f"(one for each block). Found {len(backend_kwargs)} instead."
         )
     if len(backend_kwargs) == 1:
-        backend_kwargs = (backend_kwargs,) * len(requested_backends)
+        backend_kwargs = backend_kwargs * len(requested_backends)
     assert len(backend_kwargs) == len(requested_backends)
+    for i, kwargs in enumerate(backend_kwargs):
+        if not isinstance(kwargs, dict):
+            raise RuntimeError(f"Expected kwargs for block {i} to be a dictionary, got {type(kwargs)}")
     return args, backend_kwargs
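
Side note on the _check_inputs fix, illustration only with a hypothetical kwargs payload: the old (backend_kwargs,) * n wrapped the single-element tuple a second time, producing nested tuples instead of one kwargs dict per block; repeating the tuple itself broadcasts correctly, and the new loop rejects any non-dict entry early. A standalone check of the two expressions:

```python
backend_kwargs = ({"attention_mask": None},)  # one kwargs dict wrapped in a tuple
num_blocks = 3

fixed = backend_kwargs * num_blocks       # ({...}, {...}, {...}): one dict per block
buggy = (backend_kwargs,) * num_blocks    # (({...},), ({...},), ({...},)): nested tuples

assert all(isinstance(kw, dict) for kw in fixed)
assert not any(isinstance(kw, dict) for kw in buggy)
```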

+ 14 - 11
src/petals/server/task_pool.py

@@ -4,14 +4,17 @@ import threading
 import time
 from concurrent.futures._base import PENDING
 from dataclasses import dataclass, field
+from functools import partial
 from queue import PriorityQueue
 from typing import Any, List, Optional, Sequence, Tuple, Union
 
 import torch
-from hivemind import get_logger
+from hivemind import get_logger, nested_map
 from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.utils.mpfuture import ALL_STATES, MPFuture
 
+from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs
+
 logger = get_logger(__name__)
 
 
@@ -19,9 +22,10 @@ logger = get_logger(__name__)
 class Task:
     priority: float
     time_submitted: float
+    size: int
     future: MPFuture = field(compare=False)
-    args: Sequence[Union[torch.Tensor, Any]] = field(compare=False)
-    size: int = 1
+    flat_tensors: Sequence[torch.Tensor] = field(compare=False)
+    structure: Any
 
     @property
     def uid(self) -> int:
@@ -105,14 +109,13 @@ class PrioritizedTaskPool(TaskPoolBase):
             logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
             self.terminate()
 
-    def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1) -> MPFuture:
+    def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1, **kwargs: Any) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
         future = MPFuture()
         # Remove shmem from MPFuture. This disables the .cancel() feature but
         # saves the server from "could not unlink the shared memory file" crashes during rebalancing
         future._shared_state_code = torch.tensor([ALL_STATES.index(PENDING)], dtype=torch.uint8)
-
-        task = Task(priority, time.monotonic(), future, args, size=size)
+        task = Task(priority, time.monotonic(), size, future, *pack_args_kwargs(*args, **kwargs))
         if task.size > self.max_batch_size:
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
             task.future.set_exception(exc)
@@ -125,25 +128,25 @@ class PrioritizedTaskPool(TaskPoolBase):
 
     def load_batch_to_runtime(
         self, timeout: Optional[float] = None, device: Optional[torch.device] = None
-    ) -> Tuple[Any, List[torch.Tensor]]:
+    ) -> Tuple[int, Any]:
         """receive next batch of arrays"""
         device = device if device is not None else self.device
         task = self._ordered_tasks.get(block=True, timeout=timeout)
-        batch_inputs = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.args]
+        device_flat_tensors = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.flat_tensors]
         self._dispatched_tasks[task.uid] = task
         self.batch_receiver.recv()  # reduce the number of active batches
         if not self._ordered_tasks.empty():
             first_remaining_task: Task = self._ordered_tasks.queue[0]
             self.priority = (first_remaining_task.priority, first_remaining_task.time_submitted)
-        return task.uid, batch_inputs
+        return task.uid, unpack_args_kwargs(device_flat_tensors, task.structure)
 
     def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
         """send results for a processed batch, previously loaded through load_batch_to_runtime"""
-        batch_outputs = [_move_to_device_if_tensor(output, device="cpu", share_memory=True) for output in batch_outputs]
+        batch_outputs = nested_map(partial(_move_to_device_if_tensor, device="cpu", share_memory=True), batch_outputs)
         task = self._dispatched_tasks.pop(uid, None)
         if task is None:
             logger.error(
-                f"Internal error: task task with index {uid} is missing from the dictionary; " f"Could not set result"
+                f"Internal error: task task with index {uid} is missing from the dictionary; Could not set result"
             )
         else:
             task.future.set_result(batch_outputs)
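
For context on the new Task layout, here is a toy re-implementation of the packing helpers (the real ones live in src/petals/utils/packaging.py and may differ in detail): submit_task flattens the call's positional and keyword arguments into a flat list of tensors plus a tensor-free structure, so load_batch_to_runtime can move only the tensors to the device and rebuild the original (args, kwargs) for the handler. A sketch of the round trip, assuming hivemind's nested_flatten/nested_pack:

```python
import dataclasses
from typing import Any, Sequence, Tuple

import torch
from hivemind import nested_flatten, nested_pack


@dataclasses.dataclass(frozen=True)
class _TensorSlot:
    index: int  # position of the tensor in the flat list


def pack_args_kwargs(*args, **kwargs) -> Tuple[Sequence[torch.Tensor], Any]:
    """Split (args, kwargs) into a flat list of tensors and a tensor-free, picklable structure."""
    flat_tensors, leaves = [], []
    for value in nested_flatten((args, kwargs)):
        if isinstance(value, torch.Tensor):
            leaves.append(_TensorSlot(len(flat_tensors)))
            flat_tensors.append(value)
        else:
            leaves.append(value)
    return flat_tensors, nested_pack(leaves, (args, kwargs))


def unpack_args_kwargs(flat_tensors: Sequence[torch.Tensor], structure: Any):
    """Inverse of pack_args_kwargs: substitute tensors back into their original slots."""
    leaves = [
        flat_tensors[leaf.index] if isinstance(leaf, _TensorSlot) else leaf
        for leaf in nested_flatten(structure)
    ]
    return nested_pack(leaves, structure)


# Round trip: a Task stores (flat_tensors, structure); the runtime moves the tensors
# to its device and unpacks them back into (args, kwargs) before calling the handler.
flat, structure = pack_args_kwargs("base_adapter", torch.zeros(1, 4), prompt=torch.ones(1, 2))
args, kwargs = unpack_args_kwargs(flat, structure)
assert args[0] == "base_adapter" and kwargs["prompt"].shape == (1, 2)
```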

+ 1 - 1
src/petals/utils/packaging.py

@@ -1,4 +1,4 @@
-from typing import Any, Tuple, Sequence
+from typing import Any, Sequence, Tuple
 
 import torch
 from hivemind import nested_flatten, nested_pack