Your Name 2 years ago
commit ed8d7f41b8

+ 22 - 14
src/petals/server/backend.py

@@ -53,13 +53,22 @@ class TransformerBackend(ModuleBackend):
         max_batch_size = self.forward_pool.max_batch_size
         device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
-            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
+            lambda args, kwargs: self.inference_step(*args, **kwargs),
+            max_batch_size=max_batch_size,
+            device=device,
+            name=f"{self.name}_inference",
         )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
-            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
+            lambda args, kwargs: self.forward(*args, **kwargs),
+            max_batch_size=max_batch_size,
+            device=device,
+            name=f"{self.name}_forward",
         )
         self.backward_pool = PrioritizedTaskPool(
-            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
+            lambda args, kwargs: self.backward(*args, **kwargs),
+            max_batch_size=max_batch_size,
+            device=device,
+            name=f"{self.name}_backward",
         )
 
         self.dtype = backend_dtype
@@ -96,27 +105,25 @@ class TransformerBackend(ModuleBackend):
             cache_tensors.extend((keys, values))
         return cache_tensors
 
-    def forward(self, *args: torch.Tensor, active_adapter: Optional[str], **kwargs) -> Tuple[torch.Tensor, ...]:
+    def forward(self, active_adapter: Optional[str], *args: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, ...]:
         with self._peft_module.using_adapter(active_adapter), torch.no_grad():
             return self.module(*args, **kwargs)
 
     def backward(
-        self, grad_outputs: torch.Tensor, *args, active_adapter: Optional[str], **kwargs
+        self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs
     ) -> Tuple[torch.Tensor, ...]:
         assert any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor))
         with self._peft_module.using_adapter(active_adapter), torch.enable_grad():
             (outputs,) = self.module(*args, **kwargs)
             assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape
             torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False)
-        return nested_map(lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None)
+        return nested_map(
+            lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None, (args, kwargs)
+        )
 
     @torch.inference_mode()
     def inference_step(
-        self,
-        hidden_states: torch.Tensor,
-        hypo_ids: torch.LongTensor,
-        kwargs: Dict[str, torch.Tensor],
-        inference_info: InferenceMetadata,
+        self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, inference_info: InferenceMetadata, **kwargs
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
         seq_len = hidden_states.shape[1]
@@ -217,8 +224,9 @@ def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend])
     """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""
     assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
     first_pool = next(iter(backends.values())).inference_pool
+    merged_inference_func = _MergedInferenceStep(backends)
     merged_pool = PrioritizedTaskPool(
-        _MergedInferenceStep(backends),
+        lambda args, kwargs: merged_inference_func(*args, **kwargs),
         max_batch_size=first_pool.max_batch_size,
         device=first_pool.device,
         name=f"merged_inference",
@@ -237,9 +245,9 @@ class _MergedInferenceStep:
         self,
         hidden_states: torch.Tensor,
         hypo_ids: torch.LongTensor,
-        backend_kwargs: Sequence[Dict[str, torch.Tensor]],
         inference_infos: Sequence[InferenceMetadata],
         *optional_prompts: Optional[torch.Tensor],
+        backend_kwargs: Sequence[Dict[str, torch.Tensor]],
     ) -> Tuple[torch.Tensor, ...]:
         assert (
             len(inference_infos) == len(optional_prompts) == len(backend_kwargs)
@@ -248,6 +256,6 @@ class _MergedInferenceStep:
             if optional_prompt is not None:
                 hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
             (hidden_states,) = self.backends[inference_info.uid].inference_step(
-                hidden_states, hypo_ids, kwargs, inference_info
+                hidden_states, hypo_ids, inference_info, **kwargs
             )
         return (hidden_states,)
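Note: with this commit, PrioritizedTaskPool hands each dequeued task to its processing function as a single (args, kwargs) pair rebuilt by the pool, which is why every pool above now wraps its target in a lambda that re-spreads the pair. A minimal sketch of that calling convention, where ToyPool is a hypothetical stand-in for the pool/runtime and not the real hivemind implementation:

from typing import Any, Callable, Dict, Tuple

class ToyPool:
    """Hypothetical pool that mimics the new (args, kwargs) calling convention."""

    def __init__(self, process_func: Callable[[Tuple[Any, ...], Dict[str, Any]], Any]):
        self.process_func = process_func  # invoked as process_func(args, kwargs)

    def run(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        # the runtime rebuilds (args, kwargs) from the packed task and passes them as two objects
        return self.process_func(args, kwargs)

def inference_step(hidden_states, hypo_ids, inference_info, **kwargs):
    ...  # the backend's actual work would happen here

# Passing inference_step directly would call it as inference_step((h, ids, info), {...}) and fail;
# the lambda re-spreads the tuple and dict into a normal positional/keyword call:
pool = ToyPool(lambda args, kwargs: inference_step(*args, **kwargs))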

+ 15 - 6
src/petals/server/block_functions.py

@@ -66,8 +66,8 @@ async def run_rpc_forward(
             hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
         )
         (hidden_states,) = await backend.forward_pool.submit_task(
+            active_adapter,
             hidden_states,
-            active_adapter=active_adapter,
             **kwargs,
             priority=priority,
             size=num_tokens,
@@ -113,7 +113,7 @@ async def run_rpc_backward(
             hidden_states, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
         )
         (hidden_states,) = await backend.forward_pool.submit_task(
-            hidden_states, active_adapter, **kwargs, priority=priority, size=num_tokens
+            active_adapter, hidden_states, **kwargs, priority=priority, size=num_tokens
         )
 
         assert isinstance(hidden_states, torch.Tensor)
@@ -131,7 +131,7 @@ async def run_rpc_backward(
             inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
         (grad_outputs,) = await backend.backward_pool.submit_task(
-            inp, grad_outputs, active_adapter, **kwargs, priority=priority, size=num_tokens
+            active_adapter, grad_outputs, inp, **kwargs, priority=priority, size=num_tokens
        )
 
         assert isinstance(grad_outputs, torch.Tensor)
@@ -211,7 +211,7 @@ async def iterate_rpc_inference(
                     hypo_ids,
                     inference_infos,
                     *prompts,
-                    backend_kwargs,
+                    backend_kwargs=backend_kwargs,
                     priority=priority,
                     size=num_tokens,
                 )
@@ -221,7 +221,13 @@ async def iterate_rpc_inference(
                 ):
                     inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
                     (hidden_states,) = await backend.inference_pool.submit_task(
-                        hidden_states, hypo_ids, inference_infos, prompt, **kwargs, priority=priority, size=num_tokens
+                        hidden_states,
+                        hypo_ids,
+                        inference_infos,
+                        prompt,
+                        backend_kwargs=(kwargs,),
+                        priority=priority,
+                        size=num_tokens,
                     )
 
         # serialize and send last layer outputs
@@ -250,6 +256,9 @@ def _check_inputs(
             f"(one for each block). Found {len(backend_kwargs)} instead."
             f"(one for each block). Found {len(backend_kwargs)} instead."
         )
         )
     if len(backend_kwargs) == 1:
     if len(backend_kwargs) == 1:
-        backend_kwargs = (backend_kwargs,) * len(requested_backends)
+        backend_kwargs = backend_kwargs * len(requested_backends)
     assert len(backend_kwargs) == len(requested_backends)
     assert len(backend_kwargs) == len(requested_backends)
+    for i, kwargs in enumerate(backend_kwargs):
+        if not isinstance(kwargs, dict):
+            raise RuntimeError(f"Expected kwargs for block {i} to be a dictionary, got {type(kwargs)}")
     return args, backend_kwargs
     return args, backend_kwargs
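Note: the one-line fix in _check_inputs changes how a single kwargs dict is broadcast to every requested block; the old code wrapped the length-1 tuple in another tuple, so each block received a tuple instead of a dict. A standalone illustration of the difference (the dict contents are made up for the example):

backend_kwargs = ({"attention_mask": None},)  # length-1 tuple holding one kwargs dict
num_blocks = 3

old = (backend_kwargs,) * num_blocks  # each element is the whole tuple, not a dict
new = backend_kwargs * num_blocks     # ({...}, {...}, {...}): one dict per block

assert not all(isinstance(kw, dict) for kw in old)
assert all(isinstance(kw, dict) for kw in new)  # passes the new per-block type check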

+ 14 - 11
src/petals/server/task_pool.py

@@ -4,14 +4,17 @@ import threading
 import time
 from concurrent.futures._base import PENDING
 from dataclasses import dataclass, field
+from functools import partial
 from queue import PriorityQueue
 from typing import Any, List, Optional, Sequence, Tuple, Union
 
 import torch
-from hivemind import get_logger
+from hivemind import get_logger, nested_map
 from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.utils.mpfuture import ALL_STATES, MPFuture
 
+from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs
+
 logger = get_logger(__name__)
 
 
@@ -19,9 +22,10 @@ logger = get_logger(__name__)
 class Task:
     priority: float
     time_submitted: float
+    size: int
     future: MPFuture = field(compare=False)
-    args: Sequence[Union[torch.Tensor, Any]] = field(compare=False)
-    size: int = 1
+    flat_tensors: Sequence[torch.Tensor] = field(compare=False)
+    structure: Any
 
     @property
     def uid(self) -> int:
@@ -105,14 +109,13 @@ class PrioritizedTaskPool(TaskPoolBase):
             logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
             self.terminate()
 
-    def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1) -> MPFuture:
+    def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1, **kwargs: Any) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
         future = MPFuture()
         # Remove shmem from MPFuture. This disables the .cancel() feature but
         # saves the server from "could not unlink the shared memory file" crashes during rebalancing
         future._shared_state_code = torch.tensor([ALL_STATES.index(PENDING)], dtype=torch.uint8)
-
-        task = Task(priority, time.monotonic(), future, args, size=size)
+        task = Task(priority, time.monotonic(), size, future, *pack_args_kwargs(*args, **kwargs))
         if task.size > self.max_batch_size:
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
             task.future.set_exception(exc)
@@ -125,25 +128,25 @@ class PrioritizedTaskPool(TaskPoolBase):
 
     def load_batch_to_runtime(
         self, timeout: Optional[float] = None, device: Optional[torch.device] = None
-    ) -> Tuple[Any, List[torch.Tensor]]:
+    ) -> Tuple[int, Any]:
         """receive next batch of arrays"""
         device = device if device is not None else self.device
         task = self._ordered_tasks.get(block=True, timeout=timeout)
-        batch_inputs = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.args]
+        device_flat_tensors = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.flat_tensors]
         self._dispatched_tasks[task.uid] = task
         self.batch_receiver.recv()  # reduce the number of active batches
         if not self._ordered_tasks.empty():
             first_remaining_task: Task = self._ordered_tasks.queue[0]
             self.priority = (first_remaining_task.priority, first_remaining_task.time_submitted)
-        return task.uid, batch_inputs
+        return task.uid, unpack_args_kwargs(device_flat_tensors, task.structure)
 
     def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
         """send results for a processed batch, previously loaded through load_batch_to_runtime"""
-        batch_outputs = [_move_to_device_if_tensor(output, device="cpu", share_memory=True) for output in batch_outputs]
+        batch_outputs = nested_map(partial(_move_to_device_if_tensor, device="cpu", share_memory=True), batch_outputs)
         task = self._dispatched_tasks.pop(uid, None)
         if task is None:
             logger.error(
-                f"Internal error: task task with index {uid} is missing from the dictionary; " f"Could not set result"
+                f"Internal error: task task with index {uid} is missing from the dictionary; Could not set result"
             )
         else:
             task.future.set_result(batch_outputs)
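Note: Task now stores a flat list of tensors plus a structure object instead of raw args. submit_task packs the call with pack_args_kwargs, and load_batch_to_runtime moves only the tensors to the target device before restoring (args, kwargs) with unpack_args_kwargs. A sketch of the round trip, assuming the two helpers split a call into tensors plus a picklable structure and rebuild it exactly (signatures inferred from how they are used above; the example tensors are made up):

import torch

from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

hidden_states = torch.randn(1, 4, 16)
attention_mask = torch.ones(1, 4)

# pack: every tensor goes into the flat list, everything else into the structure
flat_tensors, structure = pack_args_kwargs("adapter-name", hidden_states, attention_mask=attention_mask)

# the pool keeps (flat_tensors, structure) in the Task; only the tensors are moved between devices
flat_on_device = [t.to("cpu") for t in flat_tensors]  # stand-in for _move_to_device_if_tensor

# unpack: rebuild positional args and keyword args for the processing function
args, kwargs = unpack_args_kwargs(flat_on_device, structure)
assert args[0] == "adapter-name" and torch.equal(kwargs["attention_mask"], attention_mask)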

+ 1 - 1
src/petals/utils/packaging.py

@@ -1,4 +1,4 @@
-from typing import Any, Tuple, Sequence
+from typing import Any, Sequence, Tuple
 
 import torch
 from hivemind import nested_flatten, nested_pack
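Note: packaging.py itself only gets its imports reordered in this commit, but task_pool.py now leans on its pack_args_kwargs / unpack_args_kwargs pair. A hedged sketch of helpers with that shape, built from the nested_flatten / nested_pack utilities the file imports; this is an illustration of the idea, not the file's actual implementation, and the placeholder marker name is made up:

from typing import Any, List, Tuple

import torch
from hivemind import nested_flatten, nested_pack

_TENSOR_MARKER = "__TENSOR__"  # hypothetical placeholder written into the structure

def pack_args_kwargs_sketch(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
    """Split (args, kwargs) into a flat tensor list and a picklable structure."""
    flat_tensors, structure_items = [], []
    for item in nested_flatten((args, kwargs)):
        if isinstance(item, torch.Tensor):
            flat_tensors.append(item)
            structure_items.append(_TENSOR_MARKER)
        else:
            structure_items.append(item)
    return flat_tensors, nested_pack(structure_items, (args, kwargs))

def unpack_args_kwargs_sketch(flat_tensors: List[torch.Tensor], structure: Any) -> Tuple[tuple, dict]:
    """Rebuild (args, kwargs), substituting tensors back in flattening order."""
    tensors = iter(flat_tensors)
    restored = [next(tensors) if item == _TENSOR_MARKER else item for item in nested_flatten(structure)]
    return nested_pack(restored, structure)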