浏览代码

benchmark_forward: Set MAX_TOKENS_IN_BATCH

Aleksandr Borzunov 2 年之前
父节点
当前提交
02a406129f
共有 1 个文件被更改,包括 3 次插入、0 次删除
  1. 3 0
      src/petals/cli/benchmark_forward.py

+ 3 - 0
src/petals/cli/benchmark_forward.py

@@ -5,12 +5,15 @@ import multiprocessing as mp
 from time import perf_counter
 
 import torch
+import petals.client.sequential_autograd
 from hivemind.utils.logging import get_logger
 from petals import DistributedBloomForCausalLM
 from transformers import BloomTokenizerFast
 
 logger = get_logger()
 
+petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
+
 
 def main():
     parser = argparse.ArgumentParser()