Browse Source

finalize diff compression

justheuristic 2 years ago
parent
commit
984adc5b3f
2 changed files with 13 additions and 10 deletions
  1. +7 -5    src/client/inference_session.py
  2. +6 -5    tests/test_remote_sequential.py

+ 7 - 5
src/client/inference_session.py

@@ -105,14 +105,14 @@ class _ServerInferenceSession:
 
         # serialize inputs and put them into the queue
         inputs = (new_hidden_states, prompts, hypo_ids)
+        flat_inference_schema = nested_flatten(self.rpc_info["inference_schema"])
+        serialized_inputs = tuple(serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
+                                  for tensor, proto in zip(inputs, flat_inference_schema))
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
-                    tensors=[
-                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
-                    ],
+                    tensors=serialized_inputs,
                     metadata=self._serialized_metadata if not self.stepped else None,
                 )
             )
@@ -120,7 +120,9 @@ class _ServerInferenceSession:
         outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
         assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
         # add back residual connections after rpc_inference
-        return outputs[0].add_(new_hidden_states)
+        inputs_are_compressed = flat_inference_schema[0].compression != runtime_pb2.CompressionType.NONE
+        residuals = deserialize_torch_tensor(serialized_inputs[0]) if inputs_are_compressed else new_hidden_states
+        return outputs[0].add_(residuals)
 
     async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""

+ 6 - 5
tests/test_remote_sequential.py

@@ -38,10 +38,10 @@ def test_remote_sequential():
     assert hidden.shape == test_inputs.shape
     assert hidden.requires_grad
     second_half_outputs = second_half(hidden)
-    assert torch.allclose(second_half_outputs, full_outputs)
+    assert torch.allclose(second_half_outputs, full_outputs, rtol=0, atol=1e-4)
 
     (second_half_outputs * grad_proj).sum().backward()
-    assert torch.allclose(test_inputs.grad, full_grad)
+    assert torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=3e-4)
 
 
 @pytest.mark.forked
@@ -79,11 +79,12 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
 
         block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
         (outputs_ref,) = block(outputs_ref)
+    outputs_ref = (outputs_ref - torch.cat([inputs, input_prompts_ref], dim=1)) + torch.cat([inputs, input_prompts_ref], dim=1)
 
-    assert torch.allclose(outputs_ref, outputs)
+    assert torch.allclose(outputs_ref, outputs)  # exact match
 
     (outputs_ref * output_proj).sum().backward()
     assert input_prompts_ref.grad is not None
-    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
+    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, rtol=0, atol=1e-5)
     assert intermediate_prompts_ref.grad is not None
-    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)
+    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, rtol=0, atol=1e-5)
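
A note on the new reference line `outputs_ref = (outputs_ref - torch.cat(...)) + torch.cat(...)`: floating-point subtraction followed by addition of the same tensor is not an identity, so the reference must perform the same (subtract, add) sequence the client performs for the exact-match assertion above to hold. A small illustration of that non-identity (the values here are arbitrary):

import torch

x = torch.randn(1000)
residual = torch.randn(1000) * 100  # a large residual magnifies the rounding

# (x - r) + r is not bitwise-identical to x in float32 ...
assert not torch.equal((x - residual) + residual, x)
# ... but the discrepancy is only a few ulps of the residual's magnitude
assert torch.allclose((x - residual) + residual, x, rtol=0, atol=1e-4)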