Sfoglia il codice sorgente

Enable rebalancing by default (#84)

Alexander Borzunov 2 anni fa
parent
commit
ee4e69c254

+ 9 - 5
cli/run_server.py

@@ -1,3 +1,5 @@
+import argparse
+
 import configargparse
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
@@ -12,7 +14,8 @@ logger = get_logger(__file__)
 
 def main():
     # fmt:off
-    parser = configargparse.ArgParser(default_config_files=["config.yml"])
+    parser = configargparse.ArgParser(default_config_files=["config.yml"],
+                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
 
     group = parser.add_mutually_exclusive_group(required=True)
@@ -80,10 +83,11 @@ def main():
                         help='Path of a file with custom nn.modules, wrapped into special decorator')
     parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
 
-    parser.add_argument("--min_balance_quality", type=float, default=0.0,
-                        help="Rebalance the swarm if its balance quality (a number in [0.0, 1.0]) "
-                             "goes below this threshold. Default: rebalancing is disabled")
-    parser.add_argument("--mean_balance_check_period", type=float, default=150,
+    parser.add_argument("--balance_quality", type=float, default=0.75,
+                        help="Rebalance the swarm if its throughput is worse than this share of the optimal "
+                             "throughput. Use 0.0 to disable rebalancing, values > 1.0 to force rebalancing "
+                             "on each check for debugging purposes.")
+    parser.add_argument("--mean_balance_check_period", type=float, default=60,
                         help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
 
     parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")

+ 1 - 1
src/server/backend.py

@@ -61,7 +61,7 @@ class TransformerBackend(ModuleBackend):
                 if not is_dummy(hypo_ids):
                     cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
                 layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
-                print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
+                logger.debug(f"Metadata: {cache_metadata}, past_k.shape={past_k.shape}, past_v.shape={past_v.shape}")
                 hidden_states, (new_k, new_v) = self.module.forward(
                     hidden_states, layer_past=layer_past, use_cache=True
                 )

+ 5 - 5
src/server/block_selection.py

@@ -62,9 +62,9 @@ def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModule
 
 
 def should_choose_other_blocks(
-    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float
+    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
 ) -> bool:
-    if min_balance_quality > 1.0:
+    if balance_quality > 1.0:
         return True  # Forces rebalancing on each check (may be used for debugging purposes)
 
     spans, throughputs = _compute_spans(module_infos)
@@ -99,8 +99,8 @@ def should_choose_other_blocks(
             throughputs[span.start : span.end] += span.throughput
 
     new_throughput = throughputs.min()
-    balance_quality = initial_throughput / new_throughput
-    logger.info(f"Swarm balance quality: {balance_quality * 100:.1f}%")
+    actual_quality = initial_throughput / new_throughput
+    logger.info(f"Swarm balance quality: {actual_quality * 100:.1f}%")
 
     eps = 1e-6
-    return balance_quality < min_balance_quality - eps
+    return actual_quality < balance_quality - eps

+ 5 - 2
src/server/handler.py

@@ -16,6 +16,7 @@ from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
+from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
@@ -24,6 +25,8 @@ from src.server.task_pool import PrioritizedTaskPool
 from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
 from src.utils.misc import DUMMY, is_dummy
 
+logger = get_logger(__file__)
+
 
 class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
@@ -73,7 +76,7 @@ class TransformerConnectionHandler(ConnectionHandler):
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
         try:
-            print("OPENED RPC_INFERENCE")
+            logger.debug("Opened rpc_inference()")
             request = await anext(requests)
             requested_uids = self._check_uids(request.uid)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
@@ -164,7 +167,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                     prefix_length += hidden_states.shape[1]
                     request = await (anext(requests))
         finally:
-            print("CLOSED RPC_INFERENCE")
+            logger.debug("Closed rpc_inference()")
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends

+ 4 - 4
src/server/server.py

@@ -61,8 +61,8 @@ class Server(threading.Thread):
         expiration: Optional[float] = None,
         prefetch_batches: int = 1,
         sender_threads: int = 1,
-        min_balance_quality: float = 0.0,
-        mean_balance_check_period: float = 150,
+        balance_quality: float = 0.75,
+        mean_balance_check_period: float = 60,
         mean_block_selection_delay: float = 0.5,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
@@ -138,7 +138,7 @@ class Server(threading.Thread):
                 raise
             block_indices = range(first_block_index, last_block_index)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
-        self.min_balance_quality = min_balance_quality
+        self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
 
@@ -215,7 +215,7 @@ class Server(threading.Thread):
             return False
 
         module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
-        return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.min_balance_quality)
+        return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)
 
     def shutdown(self):
         self.stop.set()