@@ -12,7 +12,7 @@ from transformers import BloomTokenizerFast
logger = get_logger()
-petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
+# petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
def main():
@@ -20,7 +20,7 @@ from petals.utils.misc import DUMMY, is_dummy
logger = get_logger(__name__)
-MAX_TOKENS_IN_BATCH = 512
+MAX_TOKENS_IN_BATCH = 1024
async def sequential_forward(