|
@@ -141,7 +141,7 @@ class RemoteSequentialInferenceSession:
|
|
|
def step(self, inputs: torch.Tensor):
|
|
|
assert not self.closed
|
|
|
for session in self.active_sessions:
|
|
|
- outputs = session.step(inputs)[0]
|
|
|
+ outputs = session.step(inputs)
|
|
|
assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
|
|
|
inputs = outputs
|
|
|
return inputs
|