|
@@ -100,7 +100,7 @@ class RemoteTransformerBlockInferenceSession:
|
|
)
|
|
)
|
|
outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
|
|
outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
|
|
assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
|
|
assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
|
|
- return outputs
|
|
|
|
|
|
+ return outputs[0]
|
|
|
|
|
|
async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
|
|
async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
|
|
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
|
|
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
|