Your Name, 2 years ago
Parent commit: f313730767
4 changed files with 105 additions and 65 deletions
  1. +35 -15  src/petals/server/backend.py
  2. +60 -35  src/petals/server/block_functions.py
  3. +5  -5   src/petals/server/handler.py
  4. +5  -10  src/petals/server/task_pool.py

+ 35 - 15
src/petals/server/backend.py

@@ -5,7 +5,7 @@ from itertools import chain
 from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
 import torch
-from hivemind import BatchTensorDescriptor, TensorDescriptor
+from hivemind import BatchTensorDescriptor, TensorDescriptor, nested_flatten, nested_map
 from hivemind.moe.expert_uid import ExpertUID
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
@@ -96,21 +96,26 @@ class TransformerBackend(ModuleBackend):
             cache_tensors.extend((keys, values))
         return cache_tensors
 
-    def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
-        *inputs, active_adapter = inputs
-        with self._peft_module.using_adapter(active_adapter):
-            return super().forward(*inputs)
+    def forward(self, *args: torch.Tensor, active_adapter: Optional[str], **kwargs) -> Tuple[torch.Tensor, ...]:
+        with self._peft_module.using_adapter(active_adapter), torch.no_grad():
+            return self.module(*args, **kwargs)
 
-    def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
-        *inputs, active_adapter = inputs
-        with self._peft_module.using_adapter(active_adapter):
-            return super().backward(*inputs)
+    def backward(
+        self, grad_outputs: torch.Tensor, *args, active_adapter: Optional[str], **kwargs
+    ) -> Tuple[torch.Tensor, ...]:
+        assert any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor))
+        with self._peft_module.using_adapter(active_adapter), torch.enable_grad():
+            (outputs,) = self.module(*args, **kwargs)
+            assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape
+            torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False)
+        return nested_map(lambda x: x.grad if isinstance(x, torch.Tensor) and x.requires_grad else None, (args, kwargs))
 
     @torch.inference_mode()
     def inference_step(
         self,
         hidden_states: torch.Tensor,
         hypo_ids: torch.LongTensor,
+        kwargs: Dict[str, torch.Tensor],
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
@@ -129,8 +134,9 @@ class TransformerBackend(ModuleBackend):
             layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
             for offset in range(0, seq_len, max_chunk_length):
                 hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]
+                kwargs_chunk = self._select_kwargs_chunk(kwargs, seq_len, offset, max_chunk_length)
                 output_hidden_states_chunk, new_kvs = self.module.forward(
-                    hidden_states_chunk, layer_past=layer_past, use_cache=True
+                    hidden_states_chunk, layer_past=layer_past, use_cache=True, **kwargs_chunk
                 )
                 if seq_len > max_chunk_length:
                     output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk
@@ -178,6 +184,17 @@ class TransformerBackend(ModuleBackend):
             new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim)
             cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
 
+    @staticmethod
+    def _select_kwargs_chunk(kwargs: Dict[str, Any], seq_len: int, offset: int, max_chunk_length: int):
+        if offset == 0 and max_chunk_length >= seq_len:
+            return kwargs
+        kwargs_chunk = {}
+        for key, value in kwargs.items():
+            if isinstance(value, torch.Tensor) and value.ndim >= 2 and value.shape[-2] == seq_len:
+                value = value[:, offset : offset + max_chunk_length]
+            kwargs_chunk[key] = value
+        return kwargs_chunk
+
     def get_pools(self) -> Sequence[PrioritizedTaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
 
@@ -220,14 +237,17 @@ class _MergedInferenceStep:
         self,
         hidden_states: torch.Tensor,
         hypo_ids: torch.LongTensor,
+        backend_kwargs: Sequence[Dict[str, torch.Tensor]],
         inference_infos: Sequence[InferenceMetadata],
         *optional_prompts: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, ...]:
-        assert len(inference_infos) == len(
-            optional_prompts
-        ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
-        for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
+        assert (
+            len(inference_infos) == len(optional_prompts) == len(backend_kwargs)
+        ), f"mismatch: got {len(inference_infos)} infos, {len(optional_prompts)} prompts, {len(backend_kwargs)} kwargs"
+        for inference_info, optional_prompt, kwargs in zip(inference_infos, optional_prompts, backend_kwargs):
             if optional_prompt is not None:
                 hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
-            (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
+            (hidden_states,) = self.backends[inference_info.uid].inference_step(
+                hidden_states, hypo_ids, kwargs, inference_info
+            )
         return (hidden_states,)

+ 60 - 35
src/petals/server/block_functions.py

@@ -35,7 +35,7 @@ async def run_rpc_forward(
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
-    args_structure: Any = None,
+    structure: Any,
 ) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
@@ -45,22 +45,19 @@ async def run_rpc_forward(
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
-    if args_structure is not None:
-        # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
-        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
-    hidden_states, prompts, *_ = flat_tensors
-
+    (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure)
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
     assert hidden_states.ndim == 3
+    num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
     if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
     else:
         prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
 
     # Run a chain of requested backends
-    for backend, prompt in zip(requested_backends, prompts):
+    for backend, prompt, kwargs in zip(requested_backends, prompts, backend_kwargs):
         if not is_dummy(prompt):
             hidden_states[:, : prompt.shape[1]] += prompt
 
@@ -70,8 +67,10 @@ async def run_rpc_forward(
         )
         (hidden_states,) = await backend.forward_pool.submit_task(
             hidden_states,
-            active_adapter,
+            active_adapter=active_adapter,
+            **kwargs,
             priority=priority,
+            size=num_tokens,
         )
         assert isinstance(hidden_states, torch.Tensor)
         assert (
@@ -87,15 +86,13 @@ async def run_rpc_backward(
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
-    args_structure: Any = None,
+    structure: Any,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    if args_structure is not None:
-        # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
-        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
-    inputs, grad_outputs, prompts, *_ = flat_tensors
-
+    (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure)
     # Cast inputs & grad outputs to backend dtype
-    inputs = inputs.to(requested_backends[0].dtype)
+    assert hidden_states.ndim == 3
+    num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
+    hidden_states = hidden_states.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
 
     if prompts is None or is_dummy(prompts):
@@ -106,32 +103,36 @@ async def run_rpc_backward(
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
     inter_inputs = []
-    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
-        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+    for backend, prompt, kwargs in zip(requested_backends[:-1], prompts[:-1], backend_kwargs):
+        assert hidden_states.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
-            inputs[:, : prompt.shape[1]] += prompt
-        inter_inputs.append(inputs)
+            hidden_states[:, : prompt.shape[1]] += prompt
+        inter_inputs.append(hidden_states)
         assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
         priority = prioritizer.prioritize(
-            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
+            hidden_states, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
+        )
+        (hidden_states,) = await backend.forward_pool.submit_task(
+            hidden_states, active_adapter=active_adapter, **kwargs, priority=priority, size=num_tokens
         )
-        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
 
-        assert isinstance(inputs, torch.Tensor)
+        assert isinstance(hidden_states, torch.Tensor)
 
     if not is_dummy(prompts[-1]):
-        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
-    inter_inputs.append(inputs)
+        hidden_states[:, : prompts[-1].shape[1]] += prompts[-1]
+    inter_inputs.append(hidden_states)
 
     assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
     grad_prompts_reversed = []
     # Run a chain of requested backends
-    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
+    for inp, prompt, backend, kwargs in reversed(list(zip(inter_inputs, prompts, requested_backends, backend_kwargs))):
         assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
         priority = prioritizer.prioritize(
             inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
+        (grad_outputs,) = await backend.backward_pool.submit_task(
+            grad_outputs, inp, active_adapter=active_adapter, **kwargs, priority=priority, size=num_tokens
+        )
 
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
@@ -152,7 +153,7 @@ async def iterate_rpc_inference(
     prioritizer: TaskPrioritizerBase,
     points: int,
     quant_type: QuantType,
-    args_structure: Any = None,
+    structure: Any = None,
 ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
     assert len(cache_handles) == len(requested_backends)
 
@@ -161,12 +162,9 @@ async def iterate_rpc_inference(
 
     async for request, step_metadata in input_iterator:
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
-        if args_structure is not None:
-            # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
-            flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
-
-        hidden_states, prompts, hypo_ids, *_ = flat_tensors
+        (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure)
         batch_size, length_increment, _ = hidden_states.shape
+        num_tokens = batch_size * length_increment
 
         # Cast inputs to backend dtype
         hidden_states = hidden_states.to(requested_backends[0].dtype)
@@ -209,13 +207,21 @@ async def iterate_rpc_inference(
                     for uid, handles in zip(requested_uids, cache_handles)
                 )
                 (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
-                    hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
+                    hidden_states,
+                    hypo_ids,
+                    backend_kwargs,
+                    inference_infos,
+                    *prompts,
+                    priority=priority,
+                    size=num_tokens,
                 )
             else:
-                for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
+                for backend, uid, handles, prompt, kwargs in zip(
+                    requested_backends, requested_uids, cache_handles, prompts, backend_kwargs
+                ):
                     inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
                     (hidden_states,) = await backend.inference_pool.submit_task(
-                        hidden_states, hypo_ids, inference_infos, prompt, priority=priority
+                        hidden_states, hypo_ids, (kwargs,), inference_infos, prompt, priority=priority, size=num_tokens
                     )
 
         # serialize and send last layer outputs
@@ -228,3 +234,22 @@ async def iterate_rpc_inference(
 
         # prepare for next step
         prefix_length += length_increment
+
+
+def _check_inputs(
+    requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], structure: Any
+):
+    if structure is not None:
+        args, *backend_kwargs = unpack_args_kwargs(flat_tensors, structure)
+    else:
+        args, *backend_kwargs = flat_tensors, {}  # backward compatibility
+
+    if len(backend_kwargs) not in (1, len(requested_backends)):
+        raise RuntimeError(
+            f"Server expected either one dict of keyword arguments or {len(requested_backends)} dicts "
+            f"(one for each block). Found {len(backend_kwargs)} instead."
+        )
+    if len(backend_kwargs) == 1:
+        backend_kwargs = backend_kwargs * len(requested_backends)
+    assert len(backend_kwargs) == len(requested_backends)
+    return args, backend_kwargs
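
The broadcast rule in _check_inputs is easiest to see in isolation. Below is a minimal sketch (the backend list is a placeholder; the real code passes TransformerBackend instances): a single kwargs dict sent by the client is reused for every requested block, otherwise there must be exactly one dict per block.

# Placeholder backend list; the real code passes TransformerBackend instances.
requested_backends = ["block0", "block1", "block2"]
backend_kwargs = [{"attention_mask": None}]  # a single dict shared by all blocks

if len(backend_kwargs) not in (1, len(requested_backends)):
    raise RuntimeError(
        f"expected 1 kwargs dict or {len(requested_backends)} dicts, got {len(backend_kwargs)}"
    )
if len(backend_kwargs) == 1:
    backend_kwargs = backend_kwargs * len(requested_backends)  # broadcast the single dict

assert len(backend_kwargs) == len(requested_backends)
assert all(kwargs is backend_kwargs[0] for kwargs in backend_kwargs)  # same dict object, not copies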

+ 5 - 5
src/petals/server/handler.py

@@ -180,7 +180,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         prioritizer=self._prioritizer,
                         points=points,
                         quant_type=self.quant_type,
-                        args_structure=args_structure,
+                        structure=args_structure,
                     ):
                         if can_push:
                             task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
@@ -368,7 +368,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
                 points=points,
-                args_structure=args_structure,
+                structure=args_structure,
             )
             return runtime_pb2.ExpertResponse(
                 tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
@@ -397,7 +397,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
                 points=points,
-                args_structure=args_structure,
+                structure=args_structure,
             )
 
             # Split the serialized_output for streaming and respond to client
@@ -450,7 +450,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
                 points=points,
-                args_structure=args_structure,
+                structure=args_structure,
             )
 
             return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
@@ -477,7 +477,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
                 points=points,
-                args_structure=args_structure,
+                structure=args_structure,
             )
             # Split the serialized_grad_inputs for streaming and respond
             for tensor in self._serialize_grads(grads, requested_backends, metadata):

+ 5 - 10
src/petals/server/task_pool.py

@@ -20,7 +20,8 @@ class Task:
     priority: float
     time_submitted: float
     future: MPFuture = field(compare=False)
-    args: Sequence[torch.Tensor] = field(compare=False)
+    args: Sequence[Any] = field(compare=False)
+    size: int = 1
 
     @property
     def uid(self) -> int:
@@ -104,15 +105,15 @@ class PrioritizedTaskPool(TaskPoolBase):
             logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
             self.terminate()
 
-    def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
+    def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
         future = MPFuture()
         # Remove shmem from MPFuture. This disables the .cancel() feature but
         # saves the server from "could not unlink the shared memory file" crashes during rebalancing
         future._shared_state_code = torch.tensor([ALL_STATES.index(PENDING)], dtype=torch.uint8)
 
-        task = Task(priority, time.monotonic(), future, args)
-        if self.get_task_size(task) > self.max_batch_size:
+        task = Task(priority, time.monotonic(), future, args, size=size)
+        if task.size > self.max_batch_size:
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
             task.future.set_exception(exc)
         else:
@@ -122,12 +123,6 @@ class PrioritizedTaskPool(TaskPoolBase):
                 self.priority = (task.priority, task.time_submitted)
         return task.future
 
-    def get_task_size(self, task: Task) -> int:
-        """compute task processing complexity; defaults to the total number of tokens"""
-        if task.args and task.args[0].ndim >= 2:
-            return task.args[0].shape[0] * task.args[0].shape[1]
-        return 1
-
     def load_batch_to_runtime(
         self, timeout: Optional[float] = None, device: Optional[torch.device] = None
     ) -> Tuple[Any, List[torch.Tensor]]:
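
With get_task_size gone, callers report the task size (in tokens) themselves. A hedged usage sketch, assuming a pool that is already running inside a server; the tensor shapes are made up:

import torch

# Illustrative tensors; on the server these come from the deserialized request.
hidden_states = torch.randn(2, 5, 1024)  # [batch_size, seq_len, hid_size]
num_tokens = hidden_states.shape[0] * hidden_states.shape[1]  # 2 * 5 = 10 tokens

# On a live server, `pool` would be e.g. backend.forward_pool (a PrioritizedTaskPool):
# future = pool.submit_task(hidden_states, priority=1.0, size=num_tokens)
# If the declared size exceeds pool.max_batch_size, the future is failed with a ValueError
# instead of the task being queued.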