ソースを参照

s/expert/block/g

justheuristic 2 年 前
コミット
97dd3c874a
3 ファイル変更6 行追加6 行削除
  1. 3 3
      cli/run_server.py
  2. 1 1
      src/server/backend.py
  3. 2 2
      src/server/server.py

+ 3 - 3
cli/run_server.py

@@ -31,7 +31,7 @@ def main():
     parser.add_argument('--num_handlers', type=int, default=8, required=False,
                         help='server will use this many processes to handle incoming requests')
     parser.add_argument('--min_batch_size', type=int, default=1,
-                        help='Minimum required batch size for all expert operations')
+                        help='Minimum required batch size for all operations (in total tokens)')
     parser.add_argument('--max_batch_size', type=int, default=16384,
                         help='The total number of tokens in the same batch will not exceed this value')
     parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
@@ -43,7 +43,7 @@ def main():
     parser.add_argument('--cache_dir', type=str, default=None, 
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
     parser.add_argument('--device', type=str, default=None, required=False,
-                        help='all experts will use this device in torch notation; default: cuda if available else cpu')
+                        help='all blocks will use this device in torch notation; default: cuda if available else cpu')
     parser.add_argument("--torch_dtype", type=str, default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
@@ -62,7 +62,7 @@ def main():
                              'on the first run and uses these estimates for future runs. '
                              'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
     parser.add_argument('--update_period', type=float, required=False, default=30,
-                        help='Server will report experts to DHT once in this many seconds')
+                        help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],

+ 1 - 1
src/server/backend.py

@@ -80,5 +80,5 @@ class TransformerBackend(ModuleBackend):
         return self.forward_pool, self.backward_pool, self.inference_pool
 
     def get_info(self) -> Dict[str, Any]:
-        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        """Get module parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
         return dict(super().get_info(), inference_schema=self.inference_schema)

+ 2 - 2
src/server/server.py

@@ -71,9 +71,9 @@ class Server(threading.Thread):
         runs Runtime (self.runtime) to process incoming requests.
         """
         logger.info(f"Serving {len(self.module_backends)} blocks:")
-        for expert_name, backend in self.module_backends.items():
+        for block_name, backend in self.module_backends.items():
             num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
-            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
+            logger.info(f"{block_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
 
         if not self.dht.is_alive():
             self.dht.run_in_background(await_ready=True)