Make throughput obligatory

Aleksandr Borzunov 3 years ago
parent
commit
f73c655c82

+ 2 - 0
cli/run_server.py

@@ -41,6 +41,8 @@ def main():
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
 
+    parser.add_argument('--throughput', type=float, default=1.0,
+                        help='Expected server throughput')
     parser.add_argument('--update_period', type=float, required=False, default=30,
                         help='Server will report experts to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
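
For context, the new flag is meant to be forwarded into the server factory, whose signature now requires throughput (see src/server/server.py below). A minimal, hypothetical sketch of the parse-and-forward step; the actual forwarding code in cli/run_server.py is not part of this diff, and the argument names other than --throughput are assumptions:

    # Hypothetical sketch: pass the parsed --throughput value to Server.create(),
    # which now takes a required `throughput: float` parameter.
    args = parser.parse_args()
    server = Server.create(
        prefix=args.prefix,                                          # assumed CLI arg
        converted_model_name_or_path=args.converted_model_name_or_path,  # assumed CLI arg
        throughput=args.throughput,  # defaults to 1.0 on the CLI
    )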

+ 8 - 8
src/dht_utils.py

@@ -22,7 +22,7 @@ def declare_active_modules(
     dht: DHT,
     uids: Sequence[ModuleUID],
     expiration_time: DHTExpiration,
-    throughput: Optional[float] = None,
+    throughput: float,
     wait: bool = True,
 ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
     """
@@ -30,7 +30,7 @@ def declare_active_modules(
 
     :param uids: a list of module ids to declare
     :param wait: if True, awaits for declaration to finish, otherwise runs in background
-    :param throughput: optionally specify your performance in terms of compute throughput
+    :param throughput: specify your performance in terms of compute throughput
     :param expiration_time: declared modules will be visible for this many seconds
     :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
     """
@@ -51,7 +51,7 @@ async def _declare_active_modules(
     node: DHTNode,
     uids: List[ModuleUID],
     expiration_time: DHTExpiration,
-    throughput: Optional[float] = None,
+    throughput: float,
 ) -> Dict[ModuleUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     return await node.store_many(
@@ -124,12 +124,12 @@ async def _get_remote_module_infos(
             continue
         servers = {}
         for peer_id, throughput in metadata.value.items():
-            if throughput is None:
-                throughput = 0.0  # FIXME:
             try:
-                servers[peer_id] = ServerInfo(ServerState.ONLINE, throughput)
-            except (ValueError, TypeError):
-                logger.error(f"Incorrect peer entry for {uid}: {peer_id}")
+                if not isinstance(throughput.value, float):
+                    raise ValueError(f'Throughput expected to be a float, not {throughput.value}')
+                servers[peer_id] = ServerInfo(ServerState.ONLINE, throughput.value)
+            except (ValueError, TypeError) as e:
+                logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
         if servers:
             modules[i] = RemoteModuleInfo(uid, servers)
     return modules
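
Because throughput is no longer optional, existing call sites of declare_active_modules must pass a value explicitly. A minimal caller-side sketch under that assumption; the module UID is hypothetical, and get_dht_time is assumed to be hivemind's DHT clock helper:

    # Hypothetical caller: throughput must now be supplied explicitly; passing None
    # no longer matches the annotation, and the stricter parsing above would reject it.
    declare_active_modules(
        dht,
        uids=["bigscience/test-bloomd.0"],      # hypothetical ModuleUID
        expiration_time=get_dht_time() + 30,
        throughput=1.0,
        wait=True,
    )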

+ 1 - 1
src/server/load_balancing.py

@@ -9,7 +9,7 @@ def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[Remot
         if module is None:
             throughputs.append(0)
             continue
-        throughputs.append(sum(server.throughput for server in module.server.values()
+        throughputs.append(sum(server.throughput for server in module.servers.values()
                                if server.state != ServerState.OFFLINE))
 
     options = [(throughputs[i:i + num_blocks], i)
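
The attribute fix above (module.servers, plural) means each module's score is the sum of throughputs over servers that are not offline. A hypothetical worked example, with made-up peer ids and UID:

    # Hypothetical: offline servers are excluded from the module's aggregate throughput.
    module = RemoteModuleInfo(
        "bigscience/test-bloomd.3",                      # hypothetical uid
        {
            "peer_a": ServerInfo(ServerState.ONLINE, 2.0),
            "peer_b": ServerInfo(ServerState.OFFLINE, 5.0),  # ignored in the sum
        },
    )
    total = sum(s.throughput for s in module.servers.values() if s.state != ServerState.OFFLINE)
    assert total == 2.0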

+ 26 - 5
src/server/server.py

@@ -34,6 +34,7 @@ class Server(threading.Thread):
         *,
         device: torch.device,
         num_connection_handlers: int = 8,
+        throughput: float,
         update_period: float = 30,
         expiration: Optional[float] = None,
         start: bool,
@@ -46,7 +47,12 @@ class Server(threading.Thread):
         ]
         self.runtime = Runtime(self.module_backends, device=device, **kwargs)
         self.dht_handler_thread = ModuleAnnouncerThread(
-            self.module_backends, dht, update_period, expiration, daemon=True
+            self.module_backends,
+            dht,
+            throughput=throughput,
+            update_period=update_period,
+            expiration=expiration,
+            daemon=True,
         )
         self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
 
@@ -88,6 +94,7 @@ class Server(threading.Thread):
         cls,
         prefix: Optional[str],
         converted_model_name_or_path: str,
+        throughput: float,
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: Optional[int] = None,
@@ -177,6 +184,7 @@ class Server(threading.Thread):
         return cls(
             dht,
             blocks,
+            throughput=throughput,
             num_connection_handlers=num_handlers,
             device=device,
             stats_report_interval=stats_report_interval,
@@ -241,18 +249,31 @@ class ModuleAnnouncerThread(threading.Thread):
     """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
 
     def __init__(
-        self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
+        self,
+        module_backends: Dict[str, TransformerBackend],
+        dht: DHT,
+        throughput: float,
+        update_period: float = 30,
+        expiration: Optional[int] = None,
+        **kwargs
     ):
         super().__init__(**kwargs)
         if expiration is None:
             expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
         self.module_backends = module_backends
         self.dht = dht
+        self.throughput = throughput
         self.update_period = update_period
         self.expiration = expiration
         self.stop = threading.Event()
 
     def run(self) -> None:
-        declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
-        while not self.stop.wait(self.update_period):
-            declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
+        while True:
+            declare_active_modules(
+                self.dht,
+                self.module_backends.keys(),
+                expiration_time=get_dht_time() + self.expiration,
+                throughput=self.throughput,
+            )
+            if self.stop.wait(self.update_period):
+                break
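
With throughput now a required keyword, a minimal construction sketch for the announcer thread; module_backends and dht are assumed to come from the surrounding Server, and the values are illustrative:

    # Hypothetical: the thread re-announces the hosted modules together with the
    # server's throughput every update_period seconds until the stop event is set.
    announcer = ModuleAnnouncerThread(
        module_backends,      # Dict[str, TransformerBackend]
        dht,
        throughput=1.0,
        update_period=30,
        daemon=True,
    )
    announcer.start()
    # ... later, to stop re-announcing:
    announcer.stop.set()     # the run() loop exits after the current wait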