Prechádzať zdrojové kódy

Let users specify sequence length instead of assuming 2048 (#52)

- Maximum length is now provided in `.inference_session(max_length=100)`
   - previously, we would always assume max length = 2048
- added a generic way to forward **kwargs to inference session
  - for compatibility with #47 
  - Note to @borzunov : it does *not* pass them arbitrarily, but instead checks for kwarg names at the bottom level
- run_server can be started with a custom max_length for inference
- renamed --cache_size_bytes to --attention_cache_bytes (to avoid collision with --cache_dir)
- --attn_cache_bytes can now support humane file sizes (e.g. 300MB instead of 314572800)
- made some server-side errors more human-readable to user (e.g. when max length is exceeded)

Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic 3 rokov pred
rodič
commit
d271b75dd4

+ 1 - 1
.github/workflows/run-tests.yaml

@@ -81,7 +81,7 @@ jobs:
 
           python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
             --torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 \
-            --throughput 1 &> server1.log &
+            --throughput 1 --attn_cache_size 0.2GiB &> server1.log &
           SERVER1_PID=$!
           
           sleep 5  # wait for the first server to initialize DHT

+ 1 - 1
README.md

@@ -53,7 +53,7 @@ loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()
 
 # test inference, one block
-with layer3.inference_session() as sess:
+with layer3.inference_session(max_length=10) as sess:
     for i in range(10):
         res = sess.step(torch.ones(1, 1, 4096))
 ```

+ 15 - 4
cli/run_server.py

@@ -2,6 +2,7 @@ import configargparse
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from humanfriendly import parse_size
 
 from src.server.server import Server
 
@@ -32,16 +33,19 @@ def main():
     parser.add_argument('--min_batch_size', type=int, default=1,
                         help='Minimum required batch size for all expert operations')
     parser.add_argument('--max_batch_size', type=int, default=16384,
-                        help='The total number of examples in the same batch will not exceed this value')
+                        help='The total number of tokens in the same batch will not exceed this value')
+    parser.add_argument('--inference_max_length', type=int, default=16384,
+                        help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
     parser.add_argument('--cache_dir', type=str, default=None, 
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
-    parser.add_argument('--cache_size_bytes', type=int, default=None,
-                        help='The size of memory cache for storing past attention keys/values between inference steps')
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all experts will use this device in torch notation; default: cuda if available else cpu')
     parser.add_argument("--torch_dtype", type=str, default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
+    parser.add_argument('--attn_cache_size', type=str, default=None,
+                        help='The size of GPU memory allocated for storing past attention keys/values between inference'
+                             ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB')
     parser.add_argument('--revision', type=str, default='main',
                         help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
@@ -81,10 +85,17 @@ def main():
     compression_type = args.pop("compression")
     compression = getattr(CompressionType, compression_type)
 
+    attn_cache_size = args.pop("attn_cache_size")
+    if attn_cache_size is not None:
+        attn_cache_size = parse_size(attn_cache_size)
+    assert isinstance(
+        attn_cache_size, (int, type(None))
+    ), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)"
+
     use_auth_token = args.pop("use_auth_token")
     args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
 
-    server = Server.create(**args, start=True, compression=compression)
+    server = Server.create(**args, start=True, compression=compression, attn_cache_size=attn_cache_size)
 
     try:
         server.join()

+ 1 - 0
requirements.txt

@@ -1,5 +1,6 @@
 torch==1.12.0
 accelerate==0.10.0
 huggingface-hub==0.7.0
+humanfriendly
 https://github.com/learning-at-home/hivemind/archive/20b3b3d5f225ed525515a5383a008a8f9fad8173.zip
 https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip

+ 17 - 5
src/client/inference_session.py

@@ -7,6 +7,7 @@ from typing import AsyncIterator, List, Optional
 import torch
 from hivemind import (
     P2P,
+    MSGPackSerializer,
     anext,
     deserialize_torch_tensor,
     get_logger,
@@ -33,23 +34,32 @@ class RemoteTransformerBlockInferenceSession:
     :note: this inference session is *not* fault-tolerant out of the box
     """
 
-    def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
+    def __init__(
+        self,
+        uid: ModuleUID,
+        rpc_info: RPCInfo,
+        inputs_queue: asyncio.Queue,
+        outputs_aiter: AsyncIterator,
+        *,
+        max_length: int,
+    ):
         self.uid, self.rpc_info = uid, rpc_info
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length))
         self.stepped = False
         self.closed = False
 
     @classmethod
     async def _create(
-        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None
+        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
     ) -> RemoteTransformerBlockInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
         outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
-        return cls(uid, rpc_info, inputs_queue, outputs_stream)
+        return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)
 
     @staticmethod
     async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
@@ -73,6 +83,7 @@ class RemoteTransformerBlockInferenceSession:
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                         for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
                     ],
+                    metadata=self._serialized_metadata if not self.stepped else None,
                 )
             )
         )
@@ -121,13 +132,14 @@ class RemoteSequentialInferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None):
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata):
         self.sequence_manager = sequence_manager
         self.p2p = p2p
         self.closed = False
         self.chosen_spans: List[RemoteSpanInfo] = []
         self.stack = contextlib.ExitStack()
         self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
+        self.metadata = metadata
         self.timeout = timeout
 
     def __enter__(self):
@@ -141,7 +153,7 @@ class RemoteSequentialInferenceSession:
             span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
             inference_session = RemoteExpertWorker.run_coroutine(
                 RemoteTransformerBlockInferenceSession._create(
-                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout
+                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
                 )
             )
             self.inference_sessions.append(inference_session)

+ 2 - 6
src/client/remote_block.py

@@ -33,12 +33,8 @@ class RemoteTransformerBlock(RemoteExpert):
             assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
         return super().forward(inputs)
 
-    def inference_session(self) -> RemoteTransformerBlockInferenceSession:
+    def inference_session(self, **kwargs) -> RemoteTransformerBlockInferenceSession:
         """Initialize a new inference session with the specified remote server"""
         return RemoteExpertWorker.run_coroutine(
-            RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info)
+            RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info, **kwargs)
         )
-
-    def begin_inference_session(self):
-        logger.warning("beging_inference_session was renamed to just inference_session")
-        return self.inference_session()

+ 8 - 2
src/client/remote_generation.py

@@ -60,14 +60,20 @@ class RemoteGenerationMixin:
         assert (
             model_kwargs.get("stopping_criteria", None) is None
         ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
+        if inputs is not None:
+            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
+        prefix_length = 0 if inputs is None else inputs.size(1)
 
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
 
+        assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
         if max_length is not None and max_new_tokens is None:
-            max_new_tokens = max_length - inputs.size(1)
+            max_new_tokens = max_length - prefix_length
             assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
+        elif max_length is None and max_new_tokens is not None:
+            max_length = prefix_length + max_new_tokens
 
         if inputs is None:
             assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
@@ -87,7 +93,7 @@ class RemoteGenerationMixin:
             provided_constraints=provided_constraints,
         )
 
-        with self.transformer.h.inference_session() as sess:
+        with self.transformer.h.inference_session(max_length=max_length) as sess:
             outputs = []
             if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
                 outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]

+ 2 - 2
src/client/remote_sequential.py

@@ -79,9 +79,9 @@ class RemoteSequential(nn.Module):
     def __len__(self):
         return len(self.sequence_manager)
 
-    def inference_session(self) -> RemoteSequentialInferenceSession:
+    def inference_session(self, **kwargs) -> RemoteSequentialInferenceSession:
         self.sequence_manager.update_()
-        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p)
+        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p, **kwargs)
 
     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

+ 0 - 2
src/server/backend.py

@@ -14,8 +14,6 @@ from src.server.cache import MemoryCache
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
-MAX_LENGTH = 2048
-
 
 class InferenceTaskPool(TaskPool):
     def __init__(self, *args, **kwargs):

+ 24 - 6
src/server/handler.py

@@ -19,7 +19,7 @@ from hivemind.utils.asyncio import anext
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
-from src.server.backend import MAX_LENGTH, TransformerBackend
+from src.server.backend import TransformerBackend
 from src.utils.misc import DUMMY, is_dummy
 
 
@@ -28,10 +28,11 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     module_backends: Dict[ModuleUID, TransformerBackend]
 
-    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
+    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int):
         super().__init__(dht, module_backends)
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
+        self.inference_max_length = inference_max_length
 
     async def rpc_inference(
         self,
@@ -43,7 +44,15 @@ class TransformerConnectionHandler(ConnectionHandler):
             print("OPENED RPC_INFERENCE")
             request = await anext(requests)
             requested_uids = self._check_uids(request.uid)
+            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            max_length = metadata.get("max_length")
+
+            if not requested_uids:
+                raise ValueError("User must specify at least one block for inference, but got none")
+            assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
+            if not 0 <= max_length <= self.inference_max_length:
+                raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
 
             batch_size = request.tensors[0].size[0] if request.tensors else 1
 
@@ -52,10 +61,17 @@ class TransformerConnectionHandler(ConnectionHandler):
             )  # [cache_handle, prefix_length]
             prefix_length = 0
 
-            async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
+            async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+                    length_increment = hidden_states[0].shape[1]  # how many tokens are added this step (in each seq)
+
+                    if prefix_length + length_increment > max_length:
+                        raise ValueError(
+                            f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
+                            f" exceeds pre-allocated maximum {max_length}"
+                        )
 
                     # Cast inputs to backend dtype
                     hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
@@ -113,7 +129,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
         hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
-        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
 
         # Serialize the overall output
         serialized_output = [
@@ -193,7 +209,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         return tuple(uids)
 
     @contextlib.asynccontextmanager
-    async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int) -> Sequence[int]:
+    async def _allocate_caches(
+        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
+    ) -> Sequence[int]:
         """Allocate memory caches for each transformer block, return cache handles"""
         async with contextlib.AsyncExitStack() as stack:
             handles = []
@@ -202,7 +220,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 head_dim = backend.module.self_attention.head_dim
 
                 cache_descriptor = TensorDescriptor(
-                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
+                    size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
                 )
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
 

+ 7 - 3
src/server/server.py

@@ -36,6 +36,7 @@ class Server(threading.Thread):
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
         *,
+        inference_max_length: int,
         num_connection_handlers: int = 8,
         throughput: float,
         update_period: float = 30,
@@ -47,7 +48,8 @@ class Server(threading.Thread):
         self.dht, self.module_backends = dht, module_backends
         self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
         self.conn_handlers = [
-            TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
+            TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
+            for _ in range(num_connection_handlers)
         ]
         self.runtime = Runtime(self.module_backends, **kwargs)
         self.dht_handler_thread = ModuleAnnouncerThread(
@@ -104,10 +106,11 @@ class Server(threading.Thread):
         num_handlers: int = 8,
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
+        inference_max_length: int = 4096,
         torch_dtype: str = "auto",
         revision: str = "main",
         cache_dir: Optional[str] = None,
-        cache_size_bytes: Optional[int] = None,
+        attn_cache_size: Optional[int] = None,
         device: Optional[Union[str, torch.device]] = None,
         initial_peers: Sequence[str] = (),
         compression=CompressionType.NONE,
@@ -141,7 +144,7 @@ class Server(threading.Thread):
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        memory_cache = MemoryCache(device, cache_size_bytes)
+        memory_cache = MemoryCache(device, attn_cache_size)
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
@@ -228,6 +231,7 @@ class Server(threading.Thread):
             blocks,
             throughput=throughput,
             num_connection_handlers=num_handlers,
+            inference_max_length=inference_max_length,
             device=device,
             stats_report_interval=stats_report_interval,
             update_period=update_period,

+ 8 - 1
tests/test_block_exact_match.py

@@ -4,6 +4,7 @@ import hivemind
 import pytest
 import torch
 import transformers
+from hivemind import P2PHandlerError
 from test_utils import *
 
 from src.bloom.from_pretrained import load_pretrained_block
@@ -27,9 +28,15 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
         (outputs_forward,) = remote_block(inputs)
 
         outputs_inference = []
-        with remote_block.inference_session() as sess:
+        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
             for i in range(inputs.shape[1]):
                 outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+
+            # test that max length is respected
+            with pytest.raises(P2PHandlerError) as exc_info:
+                sess.step(inputs[:, -1:, :])
+            assert "Maximum length exceeded" in repr(exc_info.value)
+
         outputs_inference = torch.cat(outputs_inference, dim=1)
 
         ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)

+ 1 - 1
tests/test_chained_calls.py

@@ -63,7 +63,7 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
     inputs = torch.randn(1, 8, config.hidden_size)
 
     outputs_inference = []
-    with remote_block.inference_session() as sess:
+    with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
         for i in range(inputs.shape[1]):
             outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
     outputs_inference = torch.cat(outputs_inference, dim=1)

+ 1 - 1
tests/test_full_model.py

@@ -31,7 +31,7 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
         embs = model.transformer.word_embeddings(test_inputs)
         embs = model.transformer.word_embeddings_layernorm(embs)
         recurrent_outputs = []
-        with model.transformer.h.inference_session() as sess:
+        with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
             for t in range(embs.shape[1]):
                 recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
         recurrent_outputs = torch.cat(recurrent_outputs, dim=1)