Your Name 2 years ago
parent
commit
65e87395bc

+ 10 - 6
src/petals/server/backend.py

@@ -111,18 +111,22 @@ class TransformerBackend(ModuleBackend):
 
     def backward(
         self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs
-    ) -> Tuple[torch.Tensor, ...]:
+    ) -> Tuple[Union[torch.Tensor, Any], ...]:
         args = [x.detach().requires_grad_(True) if x.is_floating_point() else x.detach() for x in args]
-        # TODO remove this WITHIN PR#467; make sure args are passed properly and retain requires_grad
+        # ^-- TODO remove this AFTER PR#467; make sure args are passed properly and retain requires_grad
         assert any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor))
         with self._peft_module.using_adapter(active_adapter), torch.enable_grad():
             (outputs,) = self.module(*args, **kwargs)
             assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape
             torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False)
-        # flat_tensors, structure = pack_args_kwargs(nested_map(
-        #     lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None, (args, kwargs))
-        # )
-        return (args[0].grad,)  # TODO pass additional kwarg-grads back to the user WITHIN #467
+        return nested_map(self._get_grad_if_required, (args, kwargs))
+
+    @staticmethod
+    def _get_grad_if_required(input: Any) -> Optional[torch.Tensor]:
+        """Get grad w.r.t. input if input is a tensor that requires grad; otherwise return None"""
+        if isinstance(input, torch.Tensor) and input.requires_grad:
+            return input.grad if input.grad is not None else torch.zeros_like(input)
+        return None
 
     @torch.inference_mode()
     def inference_step(
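
Note: a minimal, self-contained sketch of what the new _get_grad_if_required helper returns for different inputs (the standalone function below mirrors the static method added above; everything outside the diff is illustrative only):

    import torch
    from typing import Any, Optional

    def get_grad_if_required(inp: Any) -> Optional[torch.Tensor]:
        """Grad for tensors that require it (zeros if the grad was never populated), None otherwise."""
        if isinstance(inp, torch.Tensor) and inp.requires_grad:
            return inp.grad if inp.grad is not None else torch.zeros_like(inp)
        return None

    x = torch.randn(2, 3, requires_grad=True)
    ids = torch.randint(0, 10, (2,))   # integer tensor, never requires grad
    (x * 2).sum().backward()           # fills x.grad with 2.0 everywhere

    assert torch.equal(get_grad_if_required(x), torch.full_like(x, 2.0))
    assert get_grad_if_required(ids) is None
    assert get_grad_if_required("not a tensor") is None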

+ 12 - 8
src/petals/server/block_functions.py

@@ -18,7 +18,7 @@ from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import TaskPrioritizerBase
 from petals.utils.convert_block import QuantType
 from petals.utils.misc import DUMMY, is_dummy
-from petals.utils.packaging import unpack_args_kwargs
+from petals.utils.packaging import unpack_args_kwargs, pack_args_kwargs
 
 # We prioritize short inference requests and make them use a *merged* inference pool,
 # so they are processed without interruptions and extra overheads
@@ -86,9 +86,9 @@ async def run_rpc_backward(
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
-    structure: Any,
-) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure)
+    args_structure: Any,
+) -> Tuple[Sequence[torch.Tensor], Any]:
+    (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
     # Cast inputs & grad outputs to backend dtype
     assert hidden_states.ndim == 3
     num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
@@ -124,22 +124,26 @@ async def run_rpc_backward(
 
     assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
     grad_prompts_reversed = []
+    grad_backend_kwargs_reversed = []
+
     # Run a chain of requested backends
     for inp, prompt, backend, kwargs in reversed(list(zip(inter_inputs, prompts, requested_backends, backend_kwargs))):
         assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
         priority = prioritizer.prioritize(
             inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
-        (grad_outputs,) = await backend.backward_pool.submit_task(
+        (grad_outputs,), grad_kwargs = await backend.backward_pool.submit_task(
             active_adapter, grad_outputs, inp, **kwargs, priority=priority, size=num_tokens
         )
 
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
             grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
+        grad_backend_kwargs_reversed.append(grad_kwargs)
 
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
-    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
+    grad_args = [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]
+    return pack_args_kwargs((grad_args, list(reversed(grad_backend_kwargs_reversed))))
 
 
 async def iterate_rpc_inference(
@@ -153,7 +157,7 @@ async def iterate_rpc_inference(
     prioritizer: TaskPrioritizerBase,
     points: int,
     quant_type: QuantType,
-    structure: Any = None,
+    args_structure: Any = None,
 ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
     assert len(cache_handles) == len(requested_backends)
 
@@ -162,7 +166,7 @@ async def iterate_rpc_inference(
 
     async for request, step_metadata in input_iterator:
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
-        (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure)
+        (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
         batch_size, length_increment, _ = hidden_states.shape
         num_tokens = batch_size * length_increment
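
Note: run_rpc_backward now returns (flat_tensors, structure) via pack_args_kwargs, so the handler can serialize a flat list of gradient tensors and ship the structure as metadata. A self-contained stand-in for that flatten/restore round trip is sketched below; the real petals.utils.packaging helpers may differ in details, and the names and kwargs here are illustrative only:

    import torch
    from typing import Any, List, Tuple

    _TENSOR_PLACEHOLDER = "__tensor__"  # marks where a tensor was removed from the structure

    def pack(nested: Any) -> Tuple[List[torch.Tensor], Any]:
        """Pull tensors out of a nested (args, kwargs)-like structure, left to right."""
        flat: List[torch.Tensor] = []

        def strip(x):
            if isinstance(x, torch.Tensor):
                flat.append(x)
                return _TENSOR_PLACEHOLDER
            if isinstance(x, (list, tuple)):
                return type(x)(strip(v) for v in x)
            if isinstance(x, dict):
                return {k: strip(v) for k, v in x.items()}
            return x

        return flat, strip(nested)

    def unpack(flat: List[torch.Tensor], structure: Any) -> Any:
        """Inverse of pack: put the flat tensors back into their original positions."""
        it = iter(flat)

        def fill(x):
            if x == _TENSOR_PLACEHOLDER:
                return next(it)
            if isinstance(x, (list, tuple)):
                return type(x)(fill(v) for v in x)
            if isinstance(x, dict):
                return {k: fill(v) for k, v in x.items()}
            return x

        return fill(structure)

    # gradients w.r.t. (args, per-backend kwargs); None means "this input has no grad"
    grads = ([torch.zeros(2, 3), torch.ones(2, 3)], [{"attention_mask": None}])
    flat, structure = pack(grads)
    restored = unpack(flat, structure)
    assert len(flat) == 2 and restored[1][0]["attention_mask"] is None
    assert torch.equal(restored[0][1], torch.ones(2, 3))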
 

+ 27 - 24
src/petals/server/handler.py

@@ -180,7 +180,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         prioritizer=self._prioritizer,
                         points=points,
                         quant_type=self.quant_type,
-                        structure=args_structure,
+                        args_structure=args_structure,
                     ):
                         if can_push:
                             task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
@@ -444,16 +444,18 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_backward should have number of points as number or None, got {points}"
 
-            grads = await run_rpc_backward(
+            flat_grads, grads_structure = await run_rpc_backward(
                 *flat_tensors,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
                 points=points,
-                structure=args_structure,
+                args_structure=args_structure,
             )
 
-            return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
+            serialized_flat_grads = self._serialize_grads(flat_grads, flat_tensors, metadata)
+            serialized_output_metadata = MSGPackSerializer.dumps(dict(structure=grads_structure))
+            return runtime_pb2.ExpertResponse(tensors=serialized_flat_grads, metadata=serialized_output_metadata)
 
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -471,18 +473,20 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_backward_stream should have number of points as number or None, got {points}"
 
-            grads = await run_rpc_backward(
+            flat_grads, grad_structure = await run_rpc_backward(
                 *flat_tensors,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
                 points=points,
-                structure=args_structure,
+                args_structure=args_structure,
             )
             # Split the serialized_grad_inputs for streaming and respond
-            for tensor in self._serialize_grads(grads, requested_backends, metadata):
+            serialized_output_metadata = MSGPackSerializer.dumps(dict(structure=grad_structure))
+            for tensor in self._serialize_grads(flat_grads, flat_tensors, metadata):
                 for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
-                    yield runtime_pb2.ExpertResponse(tensors=[part])
+                    yield runtime_pb2.ExpertResponse(tensors=[part], metadata=serialized_output_metadata)
+                    serialized_output_metadata = None  # attach metadata to the first response only
 
     def _get_active_adapter(self, metadata: dict) -> str:
         active_adapter = metadata.get("active_adapter", "")
@@ -492,28 +496,27 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     def _serialize_grads(
         self,
-        grads: Sequence[torch.Tensor],
-        requested_backends: Sequence[TransformerBackend],
-        metadata: Dict[str, Any],
+        flat_grads: Sequence[torch.Tensor],
+        flat_inputs: Sequence[torch.Tensor],
+        input_metadata: Dict[str, Any],
     ) -> Sequence[runtime_pb2.Tensor]:
         """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema"""
+        inputs_with_grad = tuple(input for input in flat_inputs if input.requires_grad)
+        assert len(flat_grads) == len(inputs_with_grad), f"user provided {len(inputs_with_grad)} inputs that " \
+                                                         f"require grad, but backward produced {len(flat_grads)} gradients"
         # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
-        flat_grads_schema = tuple(
-            nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema))
-        )  # TODO generalize
-
-        if metadata.get("output_compression") is not None:
-            assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
-            output_compression = tuple(metadata["output_compression"])
+        if input_metadata.get("output_compression") is not None:
+            output_compression = input_metadata["output_compression"]
+            assert isinstance(output_compression, (list, tuple)), "output_compression must be a tuple/list"
             assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
-            assert len(output_compression) == len(grads), f"output_compression should have {len(grads)} elements"
+            assert len(output_compression) == len(flat_grads), f"output_compression should have {len(flat_grads)} " \
+                                                               f"elements, one for every tensor that requires grad"
         else:
-            output_compression = tuple(tensor.compression for tensor in flat_grads_schema)
-
+            output_compression = tuple(runtime_pb2.NONE for _ in flat_grads)
+        output_compression = tuple(output_compression)
         return [
-            serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
-            for result, proto, compression in zip(grads, flat_grads_schema, output_compression)
+            serialize_torch_tensor(result.to(input.dtype), compression, allow_inplace=True)
+            for result, input, compression in zip(flat_grads, inputs_with_grad, output_compression)
         ]
 
     def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]:
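
Note: the grads structure now travels in ExpertResponse.metadata, and on the streaming path it is attached to the first message only. A rough sketch of how a consumer of that stream could pick it up follows; the import paths and message layout are assumptions based on the hivemind utilities already used in this handler:

    from typing import Iterable

    from hivemind.proto import runtime_pb2
    from hivemind.utils.serializer import MSGPackSerializer

    def collect_streamed_grads(responses: Iterable[runtime_pb2.ExpertResponse]):
        """Gather tensor parts and the grads structure from a backward stream."""
        structure = None
        tensor_parts = []
        for msg in responses:
            if structure is None and msg.metadata:
                # only the first message carries metadata (see rpc_backward_stream above)
                structure = MSGPackSerializer.loads(msg.metadata)["structure"]
            tensor_parts.extend(msg.tensors)
        # the parts still need to be recombined and deserialized (hivemind's
        # combine_from_streaming + deserialize_torch_tensor) before unpacking with `structure`
        return tensor_parts, structure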

+ 1 - 0
tests/test_server_stats.py

@@ -9,6 +9,7 @@ from petals.server.handler import CACHE_TOKENS_AVAILABLE
 from test_utils import *
 
 
+@pytest.mark.skip
 @pytest.mark.forked
 def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50):
     config = AutoDistributedConfig.from_pretrained(MODEL_NAME)