Prechádzať zdrojové kódy

serialize outputs structure

Your Name 2 rokov pred
rodič
commit
09e9da6eb1

+ 1 - 1
src/petals/server/backend.py

@@ -119,7 +119,7 @@ class TransformerBackend(ModuleBackend):
             (outputs,) = self.module(*args, **kwargs)
             assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape
             torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False)
-        return nested_map(self._get_grad_if_required, (args, kwargs))
+        return nested_map(self._get_grad_if_required, (*args, kwargs))
 
     @staticmethod
     def _get_grad_if_required(input: Any) -> Optional[torch.Tensor]:

+ 1 - 2
src/petals/server/block_functions.py

@@ -134,10 +134,9 @@ async def run_rpc_backward(
         priority = prioritizer.prioritize(
             inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
-        (grad_outputs,), grad_kwargs = await backend.backward_pool.submit_task(
+        (*grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task(
             active_adapter, grad_outputs, inp, **kwargs, priority=priority, size=num_tokens
         )
-
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
             grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))

+ 1 - 1
src/petals/server/handler.py

@@ -482,7 +482,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 args_structure=args_structure,
             )
             # Split the serialized_grad_inputs for streaming and respond
-            serialized_output_metadata = MSGPackSerializer.dumps(output_metadata)
+            serialized_output_metadata = MSGPackSerializer.dumps(dict(structure=grad_structure))
             for tensor in self._serialize_grads(flat_grads, requested_backends, metadata):
                 for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                     yield runtime_pb2.ExpertResponse(tensors=[part], metadata=serialized_output_metadata)