Your Name 2 سال پیش
والد
کامیت
13c13d347a
2 فایلهای تغییر یافته به همراه 7 افزوده شده و 7 حذف شده
  1. 5 5
      src/petals/server/block_functions.py
  2. 2 2
      src/petals/server/handler.py

+ 5 - 5
src/petals/server/block_functions.py

@@ -35,7 +35,7 @@ async def run_rpc_forward(
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
-    structure: Any,
+    args_structure: Any,
 ) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
@@ -45,7 +45,7 @@ async def run_rpc_forward(
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
-    (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure)
+    (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
@@ -247,10 +247,10 @@ async def iterate_rpc_inference(
 
 
 def _check_inputs(
-    requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], structure: Any
+    requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], args_structure: Any
 ):
-    if structure is not None:
-        args, *backend_kwargs = unpack_args_kwargs(flat_tensors, structure)
+    if args_structure is not None:
+        args, *backend_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
     else:
         args, *backend_kwargs = flat_tensors, {}  # backward compatibility
 

+ 2 - 2
src/petals/server/handler.py

@@ -368,7 +368,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
                 points=points,
-                structure=args_structure,
+                args_structure=args_structure,
             )
             return runtime_pb2.ExpertResponse(
                 tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
@@ -397,7 +397,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
                 points=points,
-                structure=args_structure,
+                args_structure=args_structure,
             )
 
             # Split the serialized_output for streaming and respond to client