Explorar o código

unbreak everything

Your Name hai 1 ano
pai
achega
721f7d2db3

+ 1 - 3
src/petals/client/inference_session.py

@@ -43,7 +43,7 @@ class _ServerInferenceSession:
         **metadata,
     ):
         self.config = config
-        self.span, self.span_uids = span, span_uids
+        self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
         self.num_blocks = len(span_uids)
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
@@ -283,7 +283,6 @@ class InferenceSession:
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         *block_kwargs: Sequence[Dict[str, torch.Tensor]],
-        **kwargs,
     ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
@@ -328,7 +327,6 @@ class InferenceSession:
                         prompts[server_session.span.start : server_session.span.end],
                         *block_kwargs[server_session.span.start : server_session.span.end],
                         step_id=step_id,
-                        **kwargs,
                     )
 
                     server_idx += 1

+ 1 - 1
src/petals/client/remote_sequential.py

@@ -53,7 +53,7 @@ class RemoteSequential(nn.Module):
         assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
         if self.active_session is None:
             assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
-            return _RemoteSequentialAutogradFunction.apply(self.sequence_manager, inputs, prompts, *args, **kwargs)
+            return _RemoteSequentialAutogradFunction.apply(self.sequence_manager, inputs, prompts, *args)
         else:
             return self.active_session.step(inputs, prompts, *args, **kwargs)