5
0
Эх сурвалжийг харах

benchmark_forward: Set MAX_TOKENS_IN_BATCH

Aleksandr Borzunov 2 жил өмнө
parent
commit
02a406129f

+ 3 - 0
src/petals/cli/benchmark_forward.py

@@ -5,12 +5,15 @@ import multiprocessing as mp
 from time import perf_counter
 
 import torch
+import petals.client.sequential_autograd
 from hivemind.utils.logging import get_logger
 from petals import DistributedBloomForCausalLM
 from transformers import BloomTokenizerFast
 
 logger = get_logger()
 
+petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
+
 
 def main():
     parser = argparse.ArgumentParser()