|
@@ -107,7 +107,7 @@ async def run_remote_forward(
|
|
|
else:
|
|
|
deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
|
|
|
|
|
|
- input_was_compressed = serialized_tensors[0].compression == runtime_pb2.CompressionType.NONE
|
|
|
+ input_was_compressed = serialized_tensors[0].compression != runtime_pb2.CompressionType.NONE
|
|
|
residual = deserialize_torch_tensor(serialized_tensors[0]) if input_was_compressed else inputs[0]
|
|
|
deserialized_outputs[0].add_(residual)
|
|
|
return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
|
|
@@ -155,7 +155,9 @@ async def run_remote_backward(
|
|
|
else:
|
|
|
deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
|
|
|
|
|
|
- grad_output_was_compressed = serialized_tensors[0].compression == runtime_pb2.CompressionType.NONE
|
|
|
- residual_grad = deserialize_torch_tensor(serialized_tensors[0]) if grad_output_was_compressed else inputs[0]
|
|
|
+ grad_output = grad_outputs_cpu[0]
|
|
|
+ assert inputs_and_grad_outputs[1] is grad_output
|
|
|
+ grad_output_was_compressed = serialized_tensors[1].compression != runtime_pb2.CompressionType.NONE
|
|
|
+ residual_grad = deserialize_torch_tensor(serialized_tensors[1]) if grad_output_was_compressed else grad_output
|
|
|
deserialized_grad_inputs[0].add_(residual_grad)
|
|
|
return deserialized_grad_inputs
|