Artem Chumachenko před 3 roky
rodič
revize
1d7c550485
2 změnil soubory, kde provedl 2 přidání a 2 odebrání
  1. 1 1
      src/client/remote_block.py
  2. 1 1
      src/client/remote_sequential.py

+ 1 - 1
src/client/remote_block.py

@@ -100,7 +100,7 @@ class RemoteTransformerBlockInferenceSession:
         )
         outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
         assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
-        return outputs
+        return outputs[0]
 
     async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""

+ 1 - 1
src/client/remote_sequential.py

@@ -141,7 +141,7 @@ class RemoteSequentialInferenceSession:
     def step(self, inputs: torch.Tensor):
         assert not self.closed
         for session in self.active_sessions:
-            outputs = session.step(inputs)[0]
+            outputs = session.step(inputs)
             assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
             inputs = outputs
         return inputs