
remove residuals in all RPCs

justheuristic 2 years ago
parent commit 2152036441
3 changed files with 12 additions and 6 deletions
  1. + 2 - 1 src/client/inference_session.py
  2. + 5 - 3 src/client/remote_forward_backward.py
  3. + 5 - 2 src/server/handler.py

+ 2 - 1
src/client/inference_session.py

@@ -119,7 +119,8 @@ class _ServerInferenceSession:
         )
         outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
         assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
-        return outputs[0]
+        # add back residual connections after rpc_inference
+        return outputs[0].add_(new_hidden_states)
 
     async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""

+ 5 - 3
src/client/remote_forward_backward.py

@@ -107,7 +107,7 @@ async def run_remote_forward(
     else:
         deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
 
-    input_was_compressed = serialized_tensors[0].compression == runtime_pb2.CompressionType.NONE
+    input_was_compressed = serialized_tensors[0].compression != runtime_pb2.CompressionType.NONE
     residual = deserialize_torch_tensor(serialized_tensors[0]) if input_was_compressed else inputs[0]
     deserialized_outputs[0].add_(residual)
     return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
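The flipped comparison is the actual fix here: with ==, the flag meant the opposite of its name, so uncompressed inputs took the deserialize branch and compressed ones did not. The branch matters because the residual the client adds back must equal the tensor the server subtracted: if the input crossed the wire compressed, that is the dequantized copy, not the client's full-precision original. A sketch of the invariant, using a hypothetical quantize/dequantize pair in place of the runtime_pb2 compression codecs:

import torch

def quantize(x: torch.Tensor):
    # hypothetical symmetric 8-bit quantizer, standing in for wire compression
    scale = x.abs().max() / 127
    return (x / scale).round().clamp(-127, 127).to(torch.int8), scale

def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.float() * scale

x = torch.randn(2, 16)              # the client's full-precision input
q, scale = quantize(x)
x_lossy = dequantize(q, scale)      # what the server actually received and subtracted

input_was_compressed = True
residual = dequantize(q, scale) if input_was_compressed else x
assert torch.equal(residual, x_lossy)  # matches the server-side subtraction exactly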
@@ -155,7 +155,9 @@ async def run_remote_backward(
     else:
         deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
 
-    grad_output_was_compressed = serialized_tensors[0].compression == runtime_pb2.CompressionType.NONE
-    residual_grad = deserialize_torch_tensor(serialized_tensors[0]) if grad_output_was_compressed else inputs[0]
+    grad_output = grad_outputs_cpu[0]
+    assert inputs_and_grad_outputs[1] is grad_output
+    grad_output_was_compressed = serialized_tensors[1].compression != runtime_pb2.CompressionType.NONE
+    residual_grad = deserialize_torch_tensor(serialized_tensors[1]) if grad_output_was_compressed else grad_output
     deserialized_grad_inputs[0].add_(residual_grad)
     return deserialized_grad_inputs
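The backward path had the same inverted flag plus an indexing bug: the request payload packs inputs first and grad outputs second, so the residual for the returned input gradients lives at serialized_tensors[1], and the uncompressed fallback is the grad output itself rather than inputs[0]. The new assert pins down that packing; a minimal illustration, assuming a single input tensor (names are illustrative):

import torch

inputs = [torch.randn(2, 16)]
grad_outputs_cpu = [torch.randn(2, 16)]

# mirror of how the payload is flattened before serialization
inputs_and_grad_outputs = tuple(inputs) + tuple(grad_outputs_cpu)

grad_output = grad_outputs_cpu[0]
assert inputs_and_grad_outputs[1] is grad_output  # index 0 = input, index 1 = grad output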

+ 5 - 2
src/server/handler.py

@@ -298,7 +298,10 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_backends: Sequence[TransformerBackend],
         metadata: Dict[str, Any],
     ) -> Sequence[runtime_pb2.Tensor]:
-        """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema"""
+        """
+        Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema
+        :note: this code expects grads to be gradients w.r.t. inputs without residuals (as returned by rpc_backward)
+        """
         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
         flat_grads_schema = tuple(
@@ -472,5 +475,5 @@ async def _rpc_backward(
             grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
 
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
-    grad_inputs = torch.sub(grad_prompts, original_grad_outputs, out=original_grad_outputs)  # remove residuals
+    grad_inputs = torch.sub(grad_outputs, original_grad_outputs, out=original_grad_outputs)  # remove residuals
     return [grad_inputs] if is_dummy(grad_prompts) else [grad_inputs, grad_prompts]  # TODO un-duct-tape
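Server-side, _rpc_backward now subtracts the original grad outputs from the computed input gradients before replying; the old line subtracted from grad_prompts instead, which the fix corrects. This mirrors the forward pass: for a residual block y = x + f(x), the input gradient decomposes as dL/dx = dL/dy + J_f^T dL/dy, so the server can ship only the f-dependent term and let the client re-add dL/dy, which it already holds in full precision. A toy autograd check of the identity:

import torch

x = torch.randn(4, 8, requires_grad=True)
f = torch.nn.Linear(8, 8)

y = x + f(x)                        # residual block
grad_outputs = torch.randn_like(y)  # dL/dy, known on both ends

(grad_inputs_full,) = torch.autograd.grad(y, x, grad_outputs)

# server: remove the residual component before sending (mirrors the torch.sub above)
wire_grad = grad_inputs_full - grad_outputs
assert torch.allclose(wire_grad, grad_outputs @ f.weight)  # only the f-dependent term remains

# client: add the residual gradient back locally (as in run_remote_backward above)
restored = wire_grad + grad_outputs
assert torch.allclose(restored, grad_inputs_full)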