justheuristic 3 年之前
父節點
當前提交
d94549eda8
共有 2 個文件被更改,包括 3 次插入5 次删除
  1. 2 2
      cli/run_server.py
  2. 1 3
      src/server/server.py

+ 2 - 2
cli/run_server.py

@@ -34,8 +34,8 @@ def main():
                         help='Minimum required batch size for all expert operations')
     parser.add_argument('--max_batch_size', type=int, default=16384,
                         help='The total number of tokens in the same batch will not exceed this value')
-    parser.add_argument('--inference_max_length', type=int, default=None,
-                        help='Maximum total sequence length permitted per inference, defaults to max_batch_size tokens')
+    parser.add_argument('--inference_max_length', type=int, default=16384,
+                        help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
     parser.add_argument('--cache_dir', type=str, default=None, 
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
     parser.add_argument('--device', type=str, default=None, required=False,

+ 1 - 3
src/server/server.py

@@ -106,7 +106,7 @@ class Server(threading.Thread):
         num_handlers: int = 8,
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
-        inference_max_length: Optional[int] = None,
+        inference_max_length: int = 4096,
         torch_dtype: str = "auto",
         revision: str = "main",
         cache_dir: Optional[str] = None,
@@ -138,8 +138,6 @@ class Server(threading.Thread):
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         if expiration is None:
             expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
-        if inference_max_length is None:
-            inference_max_length = max_batch_size
 
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]