Browse Source

Add rebalancing options to CLI args

Aleksandr Borzunov 2 năm trước cách đây
mục cha
commit
4e26b2799d
2 tập tin đã thay đổi với 11 bổ sung4 xóa
  1. 7 0
      cli/run_server.py
  2. 4 4
      src/server/server.py

+ 7 - 0
cli/run_server.py

@@ -79,6 +79,13 @@ def main():
     parser.add_argument('--custom_module_path', type=str, required=False,
                         help='Path of a file with custom nn.modules, wrapped into special decorator')
     parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
+
+    parser.add_argument("--min_balance_quality", type=float, default=0.0,
+                        help="Rebalance the swarm if its balance quality (a number in [0.0, 1.0]) "
+                             "goes below this threshold. Default: rebalancing is disabled")
+    parser.add_argument("--mean_balance_check_period", type=float, default=150,
+                        help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
+
     parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
     parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into mixed-8bit quantized model.')
 

+ 4 - 4
src/server/server.py

@@ -59,9 +59,9 @@ class Server(threading.Thread):
         expiration: Optional[float] = None,
         prefetch_batches: int = 1,
         sender_threads: int = 1,
-        mean_block_selection_delay: float = 0.5,
-        mean_balance_check_period: float = 150,
         min_balance_quality: float = 0.0,
+        mean_balance_check_period: float = 150,
+        mean_block_selection_delay: float = 0.5,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
         *,
@@ -136,9 +136,9 @@ class Server(threading.Thread):
                 raise
             block_indices = range(first_block_index, last_block_index)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
-        self.mean_block_selection_delay = mean_block_selection_delay
-        self.mean_balance_check_period = mean_balance_check_period
         self.min_balance_quality = min_balance_quality
+        self.mean_balance_check_period = mean_balance_check_period
+        self.mean_block_selection_delay = mean_block_selection_delay
 
         self.stop = threading.Event()
         if start: