|
@@ -78,12 +78,12 @@ class RemoteTransformerBlockInferenceSession:
|
|
|
if not next_input_message.uid and not next_input_message.tensors:
|
|
|
break # this message means "done sending"
|
|
|
|
|
|
- def step(self, new_hidden_states: torch.Tensor, batch_ids: torch.Tensor):
|
|
|
+ def step(self, new_hidden_states: torch.Tensor):
|
|
|
"""Inference step: send a chunk of input tensors and receive a chunk of outputs"""
|
|
|
if self.closed:
|
|
|
raise Exception("Session is closed, cannot perform step")
|
|
|
# serialize inputs and put them into the queue
|
|
|
- inputs = (new_hidden_states, batch_ids)
|
|
|
+ inputs = (new_hidden_states,)
|
|
|
outputs_serialized = RemoteExpertWorker.run_coroutine(
|
|
|
self._step(
|
|
|
runtime_pb2.ExpertRequest(
|