瀏覽代碼

Add checks for forward() inputs on the client side (#123)

justheuristic 2 年之前
父節點
當前提交
8491ed2bd3
共有 1 個文件被更改,包括 2 次插入0 次删除
  1. 2 0
      src/petals/client/remote_sequential.py

+ 2 - 0
src/petals/client/remote_sequential.py

@@ -53,6 +53,8 @@ class RemoteSequential(nn.Module):
             self.is_subsequence = self.sequence_manager.sequence_info.block_uids != block_uids
 
     def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
+        assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
+        assert inputs.shape[1] <= 2048, "The sequence length is capped at 2048 tokens in this version"
         outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
         return outputs