Răsfoiți Sursa

switch to hivemind-master

justheuristic 3 ani în urmă
părinte
comite
20497f81d1
2 a modificat fișierele cu 35 adăugiri și 12 ștergeri
  1. 6 4
      cli/run_server.py
  2. 29 8
      src/server/server.py

+ 6 - 4
cli/run_server.py

@@ -15,11 +15,13 @@ def main():
     parser = configargparse.ArgParser(default_config_files=["config.yml"])
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
 
-    parser.add_argument('--block_config', type=str, default='bigscience/bloom', help="name or path of model config")
-    parser.add_argument('--num_blocks', type=int, default=1, help="The number of blocks to serve")
-    parser.add_argument('--host_maddrs', type=list, nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
+    parser.add_argument('--prefix', type=str, required=True, help="Announce all blocks with this prefix")
+    parser.add_argument('--block_config', type=str, default='bigscience/bloom-6b3', help="name or path of model config")
+    parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
+    parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
+    parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
                         help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
-    parser.add_argument('--announce_maddrs', type=list, nargs='+', default=None, required=False,
+    parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,
                         help='Visible multiaddrs the host announces for external connections from other p2p instances')
 
     parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')

+ 29 - 8
src/server/server.py

@@ -11,7 +11,7 @@ from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import use_hivemind_log_handler, get_logger
 import multiprocessing as mp
 
-from src import DistributedBloomConfig
+from src import DistributedBloomConfig, BloomForCausalLM
 from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
 from src.server.backend import TransformerBackend
@@ -81,8 +81,10 @@ class Server(threading.Thread):
     @classmethod
     def create(
         cls,
-        num_blocks: int,
+        prefix: str,
         block_config: str,
+        num_blocks: Optional[int] = None,
+        block_indices: Optional[str] = None,
         num_handlers: Optional[int] = None,
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
@@ -101,20 +103,37 @@ class Server(threading.Thread):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
-
+        assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
-        num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
         memory_cache = MemoryCache(device, cache_size_bytes)
+
+        if block_indices is not None:
+            try:
+                start, end = block_indices.split(':')
+                start, end = map(int, map(str.strip, (start, end)))
+            except Exception as e:
+                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:33)")
+                raise
+            block_indices = range(start, end)
+        else:
+            assert num_blocks is not None
+            block_indices = range(num_blocks) # TODO replace with proper load balancing
+
+        ## TODO: the code below will load the entire model in RAM. Please replace with sliced model
+        block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
+        # model = BloomForCausalLM.from_pretrained(model, use_auth_token=True)
+        ## /TODO
+
+
         # initialize modules
         blocks = {}
-        for i in range(num_blocks):
-            module_uid = f"dummy_block.{i}"
-            block = BloomBlock(block_config, layer_number=i)
+        for block_index in block_indices:
+            module_uid = f"{prefix}.{block_index}"
+            block = BloomBlock(block_config, layer_number=block_index)
             for param in block.parameters():
                 param.requires_grad = False
 
@@ -129,6 +148,8 @@ class Server(threading.Thread):
                 max_batch_size=max_batch_size,
             )
 
+        num_handlers = num_handlers if num_handlers is not None else len(blocks) * 4
+
         return cls(
             dht,
             blocks,