Browse source

Set dht.num_workers = n_layer, update_period = 150, expiration = 300 (#125)

Alexander Borzunov, 2 years ago
Parent
Commit
9dbf5e2e6f
3 changed files with 11 additions and 11 deletions:
  1. src/petals/cli/run_server.py (+1 −1)
  2. src/petals/client/remote_model.py (+1 −1)
  3. src/petals/server/server.py (+9 −9)

src/petals/cli/run_server.py (+1 −1)

@@ -73,7 +73,7 @@ def main():
                              'If set to "auto" (default), the script evaluates network and compute throughput '
                              'on the first run and uses these estimates for future runs. '
                              'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
-    parser.add_argument('--update_period', type=float, required=False, default=30,
+    parser.add_argument('--update_period', type=float, required=False, default=150,
                         help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')

src/petals/client/remote_model.py (+1 −1)

@@ -82,7 +82,7 @@ class DistributedBloomModel(BloomModel):
         dht = (
             config.dht
             if config.dht is not None
-            else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
+            else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, num_workers=n_layer, start=True)
         )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
         self.h = RemoteSequential(config, dht, config.dht_prefix, request_timeout=config.request_timeout)

src/petals/server/server.py (+9 −9)

@@ -61,7 +61,7 @@ class Server:
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
         custom_module_path=None,
-        update_period: float = 30,
+        update_period: float = 150,
         expiration: Optional[float] = None,
         request_timeout: float = 3 * 60,
         session_timeout: float = 30 * 60,
@@ -106,7 +106,14 @@ class Server:
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
 
-        self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
+        self.block_config = BloomConfig.from_pretrained(
+            converted_model_name_or_path,
+            use_auth_token=use_auth_token,
+            revision=revision,
+        )
+        self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
+
+        self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
             logger.info("Connecting to the public Petals swarm")
@@ -124,13 +131,6 @@ class Server:
             logger.info("Model weights will be loaded in 8-bit format")
         self.load_in_8bit = load_in_8bit
 
-        self.block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path,
-            use_auth_token=use_auth_token,
-            revision=revision,
-        )
-        self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
-
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
             num_blocks = self._choose_num_blocks()