Browse Source

black, isort

Your Name 2 years ago
parent
commit
4529471f3f
2 files changed with 15 additions and 7 deletions
  1. + 7 - 3 src/petals/server/block_functions.py
  2. + 8 - 4 src/petals/server/handler.py

+ 7 - 3
src/petals/server/block_functions.py

@@ -18,7 +18,7 @@ from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import TaskPrioritizerBase
 from petals.utils.convert_block import QuantType
 from petals.utils.misc import DUMMY, is_dummy
-from petals.utils.packaging import unpack_args_kwargs, pack_args_kwargs
+from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs
 
 # We prioritize short inference requests and make them use a *merged* inference pool,
 # so they are processed without interruptions and extra overheads
@@ -88,7 +88,9 @@ async def run_rpc_backward(
     points: int = 0,
     args_structure: Any,
 ) -> Tuple[Sequence[torch.Tensor], Any]:
-    (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
+    (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(
+        requested_backends, flat_tensors, args_structure
+    )
     # Cast inputs & grad outputs to backend dtype
     assert hidden_states.ndim == 3
     num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
@@ -166,7 +168,9 @@ async def iterate_rpc_inference(
 
     async for request, step_metadata in input_iterator:
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
-        (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
+        (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(
+            requested_backends, flat_tensors, args_structure
+        )
         batch_size, length_increment, _ = hidden_states.shape
         num_tokens = batch_size * length_increment
 

+ 8 - 4
src/petals/server/handler.py

@@ -502,15 +502,19 @@ class TransformerConnectionHandler(ConnectionHandler):
     ) -> Sequence[runtime_pb2.Tensor]:
         """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema"""
         inputs_with_grad = tuple(input for input in flat_inputs if input.requires_grad)
-        assert len(flat_grads) == len(inputs_with_grad), f"user provides {len(inputs_with_grad)} inputs with grad, " \
-                                                         f"but backward produced {len(flat_grads)} gradients"
+        assert len(flat_grads) == len(inputs_with_grad), (
+            f"user provides {len(inputs_with_grad)} inputs with grad, "
+            f"but backward produced {len(flat_grads)} gradients"
+        )
         # Modify grad_inputs_schema to support grad_prompts
         if input_metadata.get("output_compression") is not None:
             output_compression = input_metadata["output_compression"]
             assert isinstance(output_compression, (list, tuple)), "output_compression must be a tuple/list"
             assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
-            assert len(output_compression) == len(flat_grads), f"output_compression should have {len(flat_grads)} " \
-                                                               f"elements, one for every tensor thar requires grad"
+            assert len(output_compression) == len(flat_grads), (
+                f"output_compression should have {len(flat_grads)} "
+                f"elements, one for every tensor thar requires grad"
+            )
         else:
             output_compression = tuple(runtime_pb2.NONE for _ in flat_grads)
         output_compression = tuple(output_compression)
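
The handler.py hunks apply the same idea to assert messages: black replaces backslash continuations with a parenthesized pair of implicitly concatenated f-strings. A self-contained illustration of the resulting pattern (the values are hypothetical; only the layout mirrors the diff above):

    # Hypothetical values; only the assert-message layout follows the diff.
    flat_grads = [0.1, 0.2]
    inputs_with_grad = [0.1, 0.2]

    assert len(flat_grads) == len(inputs_with_grad), (
        f"user provides {len(inputs_with_grad)} inputs with grad, "
        f"but backward produced {len(flat_grads)} gradients"
    )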