Ver código fonte

black-isort-clarify

Your Name 1 ano atrás
pai
commit
17d278e88a

+ 5 - 1
src/petals/client/remote_forward_backward.py

@@ -104,9 +104,13 @@ async def run_remote_forward(
     size = sum(t.element_size() * t.nelement() for t in flat_tensors)
     forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
     # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR
-    return await forward_fn(
+    output_tensors = await forward_fn(
         merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata)
     )
+    # backward compatibility: ensure requires_grad; remove after https://github.com/learning-at-home/hivemind/pull/591
+    requires_grad = any(tensor.requires_grad for tensor in flat_tensors)
+    output_tensors = [tensor.requires_grad_(requires_grad) for tensor in output_tensors]
+    return output_tensors
 
 
 async def run_remote_backward(

+ 1 - 0
src/petals/client/routing/sequence_manager.py

@@ -493,6 +493,7 @@ class RemoteSequenceManager:
         self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
     ) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]:
         """
+        return a sequence of compression codecs for client-side compression (applied to tensors sent to remote server)
         :param peer_id: remote server's PeerID
         :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
         :param args: request-specific input tensors

+ 17 - 7
src/petals/server/block_functions.py

@@ -31,20 +31,27 @@ logger = get_logger(__name__)
 
 async def run_rpc_forward(
     *flat_tensors: torch.Tensor,
+    args_structure: Any,
     requested_backends: Sequence[TransformerBackend],
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
-    args_structure: Any,
 ) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
 
     :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
-    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :param args_structure: a schema that defines which of flat_tensors corresponds to which arg / kwarg
+    :note: see pack_args_kwargs function for the definition of args_structure
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :param active_adapter: the name of LoRA adapter to use; defaults to no adapter
+    :param prioritizer: assigns priorities to each sub-request based on the number of points
+    :param points: client-specified number of points, used to assign priorities
+    :param args_structure:
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
+    requires_grad = any(tensor.requires_grad for tensor in flat_tensors)
+    flat_tensors = tuple(tensor.detach() for tensor in flat_tensors)
     (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
@@ -77,7 +84,7 @@ async def run_rpc_forward(
             hidden_states.ndim == 3
         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
 
-    return hidden_states
+    return hidden_states.requires_grad_(requires_grad)
 
 
 async def run_rpc_backward(
@@ -88,19 +95,22 @@ async def run_rpc_backward(
     points: int = 0,
     args_structure: Any,
 ) -> Tuple[Sequence[torch.Tensor], Any]:
+    """A custom backward pass used by the server to service rpc_backward and rpc_backward_stream requests"""
     assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
     ((grad_outputs,), hidden_states, prompts), backend_kwargs = _check_inputs(
         requested_backends, flat_tensors, args_structure
     )
+    input_requires_grad, prompts_requires_grad = hidden_states.requires_grad, prompts.requires_grad
+
     # Cast inputs & grad outputs to backend dtype
     num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
-    hidden_states = hidden_states.to(requested_backends[0].dtype)
-    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+    hidden_states = hidden_states.detach().to(requested_backends[0].dtype)
+    grad_outputs = grad_outputs.detach().to(requested_backends[-1].dtype)
 
     if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
     else:
-        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+        prompts = [p.squeeze(0).detach() for p in prompts.detach().to(requested_backends[0].dtype).split(1, dim=0)]
 
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
@@ -140,7 +150,7 @@ async def run_rpc_backward(
             active_adapter, grad_outputs, hidden_states, **kwargs, priority=priority, size=num_tokens
         )
         assert isinstance(grad_outputs, torch.Tensor)
-        if not is_dummy(prompt):
+        if not is_dummy(prompt) and prompts_requires_grad:
             grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
         grad_backend_kwargs_reversed.append(grad_kwargs)
 

+ 8 - 7
src/petals/server/handler.py

@@ -361,18 +361,19 @@ class TransformerConnectionHandler(ConnectionHandler):
             active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             args_structure = metadata.get("args_structure")
+
             assert isinstance(
                 points, (float, int)
             ), f"rpc_forward should have number of points as number or None, got {points}"
-
             hidden_states = await run_rpc_forward(
                 *flat_inputs,
+                args_structure=args_structure,
                 requested_backends=requested_backends,
-                prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
+                prioritizer=self._prioritizer,
                 points=points,
-                args_structure=args_structure,
             )
+
             return runtime_pb2.ExpertResponse(
                 tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
             )
@@ -396,11 +397,11 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             hidden_states = await run_rpc_forward(
                 *flat_inputs,
+                args_structure=args_structure,
                 requested_backends=requested_backends,
-                prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
+                prioritizer=self._prioritizer,
                 points=points,
-                args_structure=args_structure,
             )
 
             # Split the serialized_output for streaming and respond to client
@@ -450,8 +451,8 @@ class TransformerConnectionHandler(ConnectionHandler):
             flat_grads, grads_structure = await run_rpc_backward(
                 *flat_tensors,
                 requested_backends=requested_backends,
-                prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
+                prioritizer=self._prioritizer,
                 points=points,
                 args_structure=args_structure,
             )
@@ -479,8 +480,8 @@ class TransformerConnectionHandler(ConnectionHandler):
             flat_grads, grad_structure = await run_rpc_backward(
                 *flat_tensors,
                 requested_backends=requested_backends,
-                prioritizer=self._prioritizer,
                 active_adapter=active_adapter,
+                prioritizer=self._prioritizer,
                 points=points,
                 args_structure=args_structure,
             )

+ 2 - 2
tests/test_remote_sequential.py

@@ -73,8 +73,8 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
         rpc_info["forward_schema"] = (compressed_input_schema,), dict()  # (args, kwargs)
         return rpc_info
 
-    def get_request_metadata(self, protocol: str, *args, **kwargs):
-        metadata = super().get_request_metadata(protocol, *args, **kwargs)
+    def get_request_metadata(self, peer_id, protocol, block_uids, *args, **kwargs):
+        metadata = super().get_request_metadata(peer_id, protocol, block_uids, *args, **kwargs)
         if protocol == "rpc_forward":
             metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,)
         elif protocol == "rpc_backward":