
basic backend

justheuristic 3 years ago
parent commit 1c49bcb741

+ 77 - 0
cli/run_server.py

@@ -0,0 +1,77 @@
+import os, sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  # add path to src
+
+import configargparse
+
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.limits import increase_file_limit
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+
+from src.server.server import BloomServer
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
+
+
+def main():
+    # fmt:off
+    parser = configargparse.ArgParser(default_config_files=["config.yml"])
+    parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
+
+    parser.add_argument('--block_config', type=str, default='bigscience/bloom', help="name or path of model config")
+    parser.add_argument('--num_blocks', type=int, default=1, help="The number of blocks to serve")
+    parser.add_argument('--host_maddrs', type=str, nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
+                        help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
+    parser.add_argument('--announce_maddrs', type=str, nargs='+', default=None, required=False,
+                        help='Visible multiaddrs the host announces for external connections from other p2p instances')
+
+    parser.add_argument('--compression', type=str, default='NONE', required=False,
+                        help='Tensor compression used for network communication (a hivemind CompressionType name, e.g. NONE)')
+
+    parser.add_argument('--num_handlers', type=int, default=None, required=False,
+                        help='server will use this many processes to handle incoming requests')
+    parser.add_argument('--min_batch_size', type=int, default=1,
+                        help='Minimum required batch size for all expert operations')
+    parser.add_argument('--max_batch_size', type=int, default=16384,
+                        help='The total number of examples in the same batch will not exceed this value')
+    parser.add_argument('--cache_size_bytes', type=int, default=None,
+                        help='The size of memory cache for storing past attention keys/values between inference steps')
+    parser.add_argument('--device', type=str, default=None, required=False,
+                        help='all experts will use this device in torch notation; default: cuda if available else cpu')
+
+    parser.add_argument('--update_period', type=float, required=False, default=30,
+                        help='Server will report experts to DHT once in this many seconds')
+    parser.add_argument('--expiration', type=float, required=False, default=None,
+                        help='DHT entries will expire after this many seconds')
+    parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
+                        help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
+    parser.add_argument('--increase_file_limit', action='store_true',
+                        help='On *nix, this will increase the max number of processes '
+                             'a server can spawn before hitting "Too many open files"; Use at your own risk.')
+    parser.add_argument('--stats_report_interval', type=int, required=False,
+                        help='Interval between two reports of batch processing performance statistics')
+
+    parser.add_argument('--custom_module_path', type=str, required=False,
+                        help='Path of a file with custom nn.modules, wrapped into special decorator')
+
+    # fmt:on
+    args = vars(parser.parse_args())
+    args.pop("config", None)
+
+    if args.pop("increase_file_limit"):
+        increase_file_limit()
+
+    compression_type = args.pop("compression")
+    compression = getattr(CompressionType, compression_type)
+
+    server = BloomServer.create(**args, start=True, compression=compression)
+
+    try:
+        server.join()
+    except KeyboardInterrupt:
+        logger.info("Caught KeyboardInterrupt, shutting down")
+    finally:
+        server.shutdown()
+
+
+if __name__ == "__main__":
+    main()
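
For reference, the CLI above reduces to a single BloomServer.create(...) call. A minimal programmatic sketch, not part of the commit; argument values are illustrative and mirror the parser defaults, and it assumes the repository root is on sys.path (as cli/run_server.py arranges):

from hivemind.proto.runtime_pb2 import CompressionType
from src.server.server import BloomServer

# Sketch: launch a server without the CLI wrapper; values mirror the defaults above.
server = BloomServer.create(
    num_blocks=1,                      # how many bloom blocks to serve
    block_config="bigscience/bloom",   # name or path of the model config
    initial_peers=[],                  # multiaddrs of existing DHT peers, if any
    compression=CompressionType.NONE,  # tensor compression for network communication
    cache_size_bytes=None,             # attention cache size; None keeps the backend default
    start=True,
)
try:
    server.join()                      # block until the server thread exits
except KeyboardInterrupt:
    pass
finally:
    server.shutdown()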

+ 1 - 1
src/bloom/__init__.py

@@ -1 +1 @@
-from src.bloom.model import BloomModel, BloomForCausalLM, MemoryEfficientBloomConfig
+from src.bloom.model import BloomModel, BloomForCausalLM, DistributedBloomConfig

+ 4 - 2
src/bloom/model.py

@@ -8,6 +8,7 @@ from typing import Tuple
 
 import torch
 import torch.utils.checkpoint
+from hivemind import use_hivemind_log_handler
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
 from transformers.file_utils import (
@@ -23,6 +24,7 @@ from transformers.utils import logging
 from src.bloom.block import BloomBlock
 from src.bloom.ops import build_alibi_tensor
 
+use_hivemind_log_handler("in_root_logger")
 logger = logging.get_logger(__name__)
 
 _CHECKPOINT_FOR_DOC = "bigscience/Bloom"
@@ -30,7 +32,7 @@ _CONFIG_FOR_DOC = "MemoryEfficientBloomConfig"
 _TOKENIZER_FOR_DOC = "BloomTokenizer"
 
 
-class MemoryEfficientBloomConfig(_VanillaBloomConfig):
+class DistributedBloomConfig(_VanillaBloomConfig):
     compression: str = "none"
     slow_but_exact: bool = False
 
@@ -42,7 +44,7 @@ class BloomPreTrainedModel(PreTrainedModel):
     models.
     """
 
-    config_class = MemoryEfficientBloomConfig
+    config_class = DistributedBloomConfig
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BloomBlock"]
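
A small usage sketch for the renamed config, not part of the commit; it assumes _VanillaBloomConfig is the upstream transformers Bloom config, so from_pretrained is inherited:

from src.bloom.model import DistributedBloomConfig

# Load the base Bloom config from the Hub; the distributed-inference fields keep their defaults.
config = DistributedBloomConfig.from_pretrained("bigscience/bloom")
print(config.compression)      # "none" unless overridden
print(config.slow_but_exact)   # False unless overridden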

+ 0 - 73
src/node/backend.py

@@ -1,73 +0,0 @@
-"""Code for serving bloom blocks via hivemind-server"""
-import threading
-from typing import AsyncIterator, Tuple, Optional
-
-import torch
-from hivemind import P2PContext, DHT
-from hivemind.moe.server.connection_handler import ConnectionHandler
-from hivemind.moe.server.dht_handler import DHTHandlerThread
-from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.moe.server.runtime import Runtime
-from hivemind.moe.server.server import Server
-from hivemind.proto import runtime_pb2
-from torch import nn
-
-from src.node.cache import AttentionCache
-
-
-class BloomServer(Server):
-    """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
-    def __init__(
-            self, dht: DHT, device=torch.device, num_connection_handlers: int = 8, update_period: int = 30,
-            attention_cache_size: Optional[int] = None, start=False, **kwargs,
-    ):
-        threading.Thread.__init__(self)
-        self.attention_cache = AttentionCache(attention_cache_size, dtype=torch.bfloat16, device=torch.)
-        expert_blocks = dict(LOAD_BLOOM_LAYERS_HERE)
-
-        expert_backends = {name: _BloomBlockBackend(name, block, ..., self.attention_kv_cache) for name, block in expert_blocks.items()}
-        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
-        self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(num_connection_handlers)]
-        self.runtime = Runtime(self.experts, **kwargs)
-        self.dht_handler_thread = DHTHandlerThread(self.experts, dht, update_period=update_period, daemon=True)
-        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
-
-        if start:
-            self.run_in_background(await_ready=True)
-
-
-class _BloomConnectionHandler(ConnectionHandler):
-    """Handles three request types: forward, backward and forward-incremental (inference)"""
-
-    async def rpc_forward_incremental(
-        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
-    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
-        # encode expert_uid as @model_name[starting_layer:finishing_layer]
-        # - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
-        # - receive and maintain a handle for attention cache here
-
-        raise NotImplementedError()
-
-
-class _BloomBlockBackend(ExpertBackend):
-    def __init__(self, name: str, expert: nn.Module, *, attention_cache: AttentionCache, **kwargs):
-        self.attention_cache = attention_cache
-        super().__init__(name, expert, **kwargs)
-        #TODO
-        # BloomBackend serves a single layer
-        # - ensure that parameters do not require grad!
-        # - ensure that TaskPool for inference is NOT batched
-        # - ensure that optimizer/scheduler is not created
-
-    def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
-        with self.attention_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
-            raise NotImplementedError("TODO")
-
-# later:
-# - do not worry about OOM in cache for now! - just make sure that nothing except cache could oom.
-# - contiguous attention cache with max size
-# - select a subset of experts
-# - priorities
-# - option to backtrack a few tokens
-# - ensure that backprop is performed optimally, does not accumulate grads wrt parameters
-# - forget about length-adaptive forward/backward for now, use fixed length, maybe several fixed lengths - or better yet, forget finetuning for now

+ 0 - 0
src/node/__init__.py → src/server/__init__.py


+ 53 - 0
src/server/backend.py

@@ -0,0 +1,53 @@
+"""Code for serving bloom blocks via hivemind-server"""
+from typing import Tuple
+
+import torch
+from hivemind import BatchTensorDescriptor
+from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.task_pool import TaskPool
+
+from src.bloom.block import BloomBlock
+from src.server.cache import MemoryCache
+
+
+# TODO
+# BloomBackend serves a single layer
+# - ensure that parameters do not require grad!
+# - ensure that TaskPool for inference is NOT batched
+# - ensure that optimizer/scheduler is not created
+
+HARDCODED_LENGTH = 2048
+
+
+class BloomBlockBackend(ExpertBackend):
+    """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
+    def __init__(self, name: str, module: BloomBlock, *, memory_cache: MemoryCache, **kwargs):
+        # note: we deliberately do not call ExpertBackend.__init__ here; it would create an optimizer/scheduler and schemas that this backend sets up manually below
+        self.name, self.module = name, module
+        self.memory_cache = memory_cache
+
+        for param_name, param in module.named_parameters():
+            assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {param_name} does"
+        for buf_name, buf in module.named_buffers():
+            assert not buf.requires_grad, f"Bloom layer buffers must not require gradients, but {buf_name} does"
+
+        self.args_schema = (BatchTensorDescriptor(HARDCODED_LENGTH, module.hidden_size),)
+        self.kwargs_schema = {}
+        self.outputs_schema = (BatchTensorDescriptor(HARDCODED_LENGTH, module.hidden_size),)
+
+        self.forward_schema = (self.args_schema, self.kwargs_schema)  # inputs for forward
+        self.backward_schema = (self.forward_schema, self.outputs_schema)  # inputs to backward
+
+        self.grad_inputs_schema = self.forward_schema  # outputs from backward have same shape as inputs for forward
+        self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
+        self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
+
+    @property
+    def expert(self):
+        #TODO un-hardcode this naming from hivemind
+        return self.module
+
+    def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
+        with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
+            raise NotImplementedError("TODO")
+
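A minimal wiring sketch for the backend above, not part of the commit; the config values are made-up small numbers and the field names are assumed to follow the upstream Bloom config:

import torch
from src.bloom.block import BloomBlock
from src.bloom.model import DistributedBloomConfig
from src.server.backend import BloomBlockBackend
from src.server.cache import MemoryCache

# Tiny, illustrative config; a real server would load one via DistributedBloomConfig.from_pretrained.
config = DistributedBloomConfig(hidden_size=64, n_head=8, n_layer=1)
block = BloomBlock(config, layer_number=0)
for param in block.parameters():
    param.requires_grad_(False)  # the backend asserts that no parameter requires grad

memory_cache = MemoryCache(device=torch.device("cpu"), max_size_bytes=None)
backend = BloomBlockBackend("bloom_block.0", block, memory_cache=memory_cache,
                            min_batch_size=1, max_batch_size=16)
# backend.forward_pool / backend.backward_pool are the hivemind TaskPools created in __init__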

+ 4 - 1
src/node/cache.py → src/server/cache.py

@@ -5,6 +5,7 @@ For now, the only purpose of this code is to ensure that allocated memory will b
 
 TODO In future, one could modify cache to implement, among other things,
 - in allocate_cache, if there is not enough memory, wait for memory to be freed by existing tasks up to a given timeout.
+-- note: this can be done using mp.Condition
 - allocate cache as one contigous buffer to avoid fragmentation
 - quantize cached values using bitsandbytes
 - LRU offloading from gpu to ram
@@ -18,9 +19,11 @@ from typing import Dict, Optional, Union
 
 import hivemind
 import torch
+from hivemind import use_hivemind_log_handler
 from hivemind.utils import TensorDescriptor, get_logger
 
-logger = get_logger(__file__)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
 
 Handle = int
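The mp.Condition note added in the TODO above refers to a standard wait-with-timeout pattern; a rough sketch, with hypothetical helper and attribute names and not part of this commit, of how allocate_cache could block until memory is freed:

import multiprocessing as mp
import time

# Whoever frees cache memory calls memory_freed.notify_all(); allocation waits on it up to a timeout.
memory_freed = mp.Condition()

def wait_for_free_memory(cache, nbytes: int, timeout: float) -> bool:
    """Return True once `nbytes` fit into the cache, False if the timeout expires first."""
    deadline = time.monotonic() + timeout
    with memory_freed:
        while cache.current_size_bytes + nbytes > cache.max_size_bytes:  # assumed attributes
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            memory_freed.wait(remaining)
    return True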
 

+ 18 - 0
src/server/handler.py

@@ -0,0 +1,18 @@
+from typing import AsyncIterator
+
+from hivemind import P2PContext
+from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.proto import runtime_pb2
+
+
+class BloomConnectionHandler(ConnectionHandler):
+    """Handles three request types: forward, backward and forward-incremental (inference)"""
+
+    async def rpc_forward_incremental(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+        # encode expert_uid as @model_name[starting_layer:finishing_layer]
+        # - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
+        # - receive and maintain a handle for attention cache here
+
+        raise NotImplementedError()
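The comment above encodes the served range as @model_name[starting_layer:finishing_layer]; a hypothetical parser for that uid format, not part of the commit, could look like:

import re
from typing import Tuple

# Hypothetical helper: split "@bigscience/bloom[3:7]" into ("bigscience/bloom", 3, 7).
# The uid format itself comes from the comment in rpc_forward_incremental above.
UID_PATTERN = re.compile(r"^@(?P<model>[\w./-]+)\[(?P<start>\d+):(?P<end>\d+)\]$")

def parse_block_uid(uid: str) -> Tuple[str, int, int]:
    match = UID_PATTERN.fullmatch(uid)
    if match is None:
        raise ValueError(f"Malformed block uid: {uid!r}")
    return match["model"], int(match["start"]), int(match["end"])

assert parse_block_uid("@bigscience/bloom[3:7]") == ("bigscience/bloom", 3, 7)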

+ 108 - 0
src/server/server.py

@@ -0,0 +1,108 @@
+import threading
+from typing import Optional, Dict, Union, Sequence
+
+import torch
+from hivemind import Server, DHT
+from hivemind.moe.server.dht_handler import DHTHandlerThread
+from hivemind.moe.server.layers import add_custom_models_from_file
+from hivemind.moe.server.runtime import Runtime
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.logging import use_hivemind_log_handler, get_logger
+
+from src import DistributedBloomConfig
+from src.bloom.block import BloomBlock
+from src.server.cache import MemoryCache
+from src.server.backend import BloomBlockBackend
+from src.server.handler import BloomConnectionHandler
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
+
+
+class BloomServer(Server):
+    """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
+    def __init__(
+            self, dht: DHT, module_backends: Dict[str, BloomBlockBackend], *,
+            device: torch.device, num_connection_handlers: int = 8, update_period: float = 30,
+            cache_size_bytes: Optional[int] = None, start: bool, **kwargs,
+    ):
+        threading.Thread.__init__(self)
+        self.attention_cache = MemoryCache(device=device, max_size_bytes=cache_size_bytes)
+
+        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
+        self.conn_handlers = [BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
+        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
+        self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period=update_period, daemon=True)
+        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
+
+        if start:
+            self.run_in_background(await_ready=True)
+
+    # noinspection PyMethodOverriding
+    @classmethod
+    def create(
+            cls,
+            num_blocks: int,
+            block_config: str,
+            num_handlers: Optional[int] = None,
+            min_batch_size: int = 1,
+            max_batch_size: int = 4096,
+            cache_size_bytes: Optional[int] = None,
+            device: Union[str, torch.device] = None,
+            initial_peers: Sequence[str] = (),
+            compression=CompressionType.NONE,
+            stats_report_interval: Optional[int] = None,
+            custom_module_path=None,
+            update_period: float = 30,
+            expiration: Optional[float] = None,
+            *,
+            start: bool,
+            **kwargs,
+    ) -> Server:
+        """Create a server with one or more bloom blocks. See run_server.py for documentation."""
+        if custom_module_path is not None:
+            add_custom_models_from_file(custom_module_path)
+
+        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
+        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
+        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+
+        num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        if isinstance(block_config, str):
+            block_config = DistributedBloomConfig.from_pretrained(block_config)
+
+        # initialize modules
+        memory_cache = MemoryCache(device=device, max_size_bytes=cache_size_bytes)
+        module_backends = {}
+        for i in range(num_blocks):
+            module_uid = f"dummy_block.{i}"
+            block = BloomBlock(block_config, layer_number=i)
+            #TODO run the actual model
+
+            module_backends[module_uid] = BloomBlockBackend(
+                name=module_uid,
+                module=block,
+                memory_cache=memory_cache,
+                min_batch_size=min_batch_size,
+                max_batch_size=max_batch_size,
+            )
+
+
+        return cls(
+            dht,
+            module_backends,
+            cache_size_bytes=cache_size_bytes,
+            num_connection_handlers=num_handlers,
+            device=device,
+            stats_report_interval=stats_report_interval,
+            update_period=update_period,
+            start=start,
+        )
+