소스 검색

pass args/kwargs via forward

Your Name 1 년 전
부모
커밋
aacd8b2f9d
1개의 변경된 파일3개의 추가작업 그리고 3개의 파일을 삭제
  1. 3 3
      src/petals/client/remote_sequential.py

+ 3 - 3
src/petals/client/remote_sequential.py

@@ -49,13 +49,13 @@ class RemoteSequential(nn.Module):
 
         self._active_session = ContextVar("active_session", default=None)
 
-    def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, *args, **kwargs) -> torch.Tensor:
         assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
         if self.active_session is None:
             assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
-            return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
+            return _RemoteSequentialAutogradFunction.apply(self.sequence_manager, inputs, prompts, *args, **kwargs)
         else:
-            return self.active_session.step(inputs, prompts, **kwargs)
+            return self.active_session.step(inputs, prompts, *args, **kwargs)
 
     @property
     def active_session(self) -> Optional[InferenceSession]: