ソースを参照

benchmark_forward: Set MAX_TOKENS_IN_BATCH

Aleksandr Borzunov 2 年 前
コミット
02a406129f
1 ファイル変更3 行追加0 行削除
  1. 3 0
      src/petals/cli/benchmark_forward.py

+ 3 - 0
src/petals/cli/benchmark_forward.py

@@ -5,12 +5,15 @@ import multiprocessing as mp
 from time import perf_counter
 
 import torch
+import petals.client.sequential_autograd
 from hivemind.utils.logging import get_logger
 from petals import DistributedBloomForCausalLM
 from transformers import BloomTokenizerFast
 
 logger = get_logger()
 
+petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
+
 
 def main():
     parser = argparse.ArgumentParser()