
enforce max_length

justheuristic committed 3 years ago
commit a8d9159db6

+ 1 - 1
README.md

@@ -53,7 +53,7 @@ loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()
 
 # test inference, one block
-with layer3.inference_session() as sess:
+with layer3.inference_session(max_length=10) as sess:
     for i in range(10):
         res = sess.step(torch.ones(1, 1, 4096))
 ```
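
`max_length` is now a required keyword for inference sessions: it tells the server how many token positions of attention cache to pre-allocate, and any step that would push the session past that budget is rejected (see the `ValueError` added in `src/server/handler.py` below). A minimal sketch of the failure mode, reusing the `layer3` block from the snippet above:

```python
import torch

# The cache is sized for max_length positions up front; exceeding it fails.
with layer3.inference_session(max_length=2) as sess:
    sess.step(torch.ones(1, 1, 4096))  # position 1 of 2
    sess.step(torch.ones(1, 1, 4096))  # position 2 of 2
    # a third step would exceed the pre-allocated cache; the server raises
    # "Maximum length exceeded: ..." and the step is expected to fail
```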

+ 10 - 6
src/client/inference_session.py

@@ -12,7 +12,7 @@ from hivemind import (
     get_logger,
     nested_flatten,
     serialize_torch_tensor,
-    use_hivemind_log_handler,
+    use_hivemind_log_handler, MSGPackSerializer,
 )
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
@@ -33,23 +33,25 @@ class RemoteTransformerBlockInferenceSession:
     :note: this inference session is *not* fault-tolerant out of the box
     """
 
-    def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
+    def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator,
+                 *, max_length: int):
         self.uid, self.rpc_info = uid, rpc_info
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length))
         self.stepped = False
         self.closed = False
 
     @classmethod
     async def _create(
-        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None
+        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
     ) -> RemoteTransformerBlockInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
         outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
-        return cls(uid, rpc_info, inputs_queue, outputs_stream)
+        return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)
 
     @staticmethod
     async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
@@ -73,6 +75,7 @@ class RemoteTransformerBlockInferenceSession:
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                         for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
                     ],
+                    metadata=self._serialized_metadata if not self.stepped else None
                 )
             )
         )
@@ -121,13 +124,14 @@ class RemoteSequentialInferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None):
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata):
         self.sequence_manager = sequence_manager
         self.p2p = p2p
         self.closed = False
         self.chosen_spans: List[RemoteSpanInfo] = []
         self.stack = contextlib.ExitStack()
         self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
+        self.metadata = metadata
         self.timeout = timeout
 
     def __enter__(self):
@@ -141,7 +145,7 @@ class RemoteSequentialInferenceSession:
             span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
             inference_session = RemoteExpertWorker.run_coroutine(
                 RemoteTransformerBlockInferenceSession._create(
-                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout
+                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
                 )
             )
             self.inference_sessions.append(inference_session)
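
The per-block session now serializes its options with hivemind's `MSGPackSerializer` and sends them in the request `metadata` field only on the first step (`if not self.stepped`); later steps carry no metadata. A short sketch of the round trip, assuming only `max_length` is sent, as in this commit:

```python
from hivemind import MSGPackSerializer

# Client side: pack the session options once, at session creation.
serialized = MSGPackSerializer.dumps(dict(max_length=10))

# Server side (rpc_inference): unpack, falling back to an empty dict.
metadata = MSGPackSerializer.loads(serialized) if serialized else {}
assert metadata.get("max_length") == 10
```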

+ 2 - 6
src/client/remote_block.py

@@ -33,12 +33,8 @@ class RemoteTransformerBlock(RemoteExpert):
             assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
         return super().forward(inputs)
 
-    def inference_session(self) -> RemoteTransformerBlockInferenceSession:
+    def inference_session(self, **kwargs) -> RemoteTransformerBlockInferenceSession:
         """Initialize a new inference session with the specified remote server"""
         return RemoteExpertWorker.run_coroutine(
-            RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info)
+            RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info, **kwargs)
         )
-
-    def begin_inference_session(self):
-        logger.warning("beging_inference_session was renamed to just inference_session")
-        return self.inference_session()

+ 8 - 2
src/client/remote_generation.py

@@ -60,14 +60,20 @@ class RemoteGenerationMixin:
         assert (
             model_kwargs.get("stopping_criteria", None) is None
         ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
+        if inputs is not None:
+            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
+        prefix_length = (0 if inputs is None else inputs.size(1))
 
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
 
         if max_length is not None and max_new_tokens is None:
-            max_new_tokens = max_length - inputs.size(1)
+            max_new_tokens = max_length - prefix_length
             assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {prefix_length}"
+        elif max_length is None and max_new_tokens is not None:
+            max_length = prefix_length + max_new_tokens
+        assert max_length is not None and max_new_tokens is not None
 
         if inputs is None:
             assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
@@ -87,7 +93,7 @@ class RemoteGenerationMixin:
             provided_constraints=provided_constraints,
         )
 
-        with self.transformer.h.inference_session() as sess:
+        with self.transformer.h.inference_session(max_length=max_length) as sess:
             outputs = []
             if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
                 outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
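
With the new prefix-length bookkeeping, `max_length` and `max_new_tokens` are mutually derivable: whichever one the caller provides, the other is filled in from the prompt length before the session is opened with `max_length`. A standalone sketch of that arithmetic (`resolve_lengths` is a hypothetical helper for illustration, not part of the codebase):

```python
def resolve_lengths(prefix_length: int, max_length=None, max_new_tokens=None):
    # Mirrors the logic above: derive the missing bound from the prompt length.
    if max_length is not None and max_new_tokens is None:
        max_new_tokens = max_length - prefix_length
        assert max_new_tokens > 0, f"max_length {max_length} <= prefix {prefix_length}"
    elif max_length is None and max_new_tokens is not None:
        max_length = prefix_length + max_new_tokens
    assert max_length is not None and max_new_tokens is not None
    return max_length, max_new_tokens

assert resolve_lengths(5, max_length=16) == (16, 11)
assert resolve_lengths(5, max_new_tokens=11) == (16, 11)
```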

+ 2 - 2
src/client/remote_sequential.py

@@ -79,9 +79,9 @@ class RemoteSequential(nn.Module):
     def __len__(self):
         return len(self.sequence_manager)
 
-    def inference_session(self) -> RemoteSequentialInferenceSession:
+    def inference_session(self, **kwargs) -> RemoteSequentialInferenceSession:
         self.sequence_manager.update_()
-        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p)
+        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p, **kwargs)
 
     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

+ 0 - 2
src/server/backend.py

@@ -14,8 +14,6 @@ from src.server.cache import MemoryCache
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
-MAX_LENGTH = 2048
-
 
 class InferenceTaskPool(TaskPool):
     def __init__(self, *args, **kwargs):

+ 20 - 6
src/server/handler.py

@@ -19,7 +19,7 @@ from hivemind.utils.asyncio import anext
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
-from src.server.backend import MAX_LENGTH, TransformerBackend
+from src.server.backend import TransformerBackend
 from src.utils.misc import DUMMY, is_dummy
 
 
@@ -28,10 +28,11 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     module_backends: Dict[ModuleUID, TransformerBackend]
 
-    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
+    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int):
         super().__init__(dht, module_backends)
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
+        self.inference_max_length = inference_max_length
 
     async def rpc_inference(
         self,
@@ -43,7 +44,15 @@ class TransformerConnectionHandler(ConnectionHandler):
             print("OPENED RPC_INFERENCE")
             request = await anext(requests)
             requested_uids = self._check_uids(request.uid)
+            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            max_length = metadata.get("max_length")
+
+            if not requested_uids:
+                raise ValueError("User must specify at least one block for inference, but got none")
+            assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
+            if not 0 <= max_length <= self.inference_max_length:
+                raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
 
             batch_size = request.tensors[0].size[0] if request.tensors else 1
 
@@ -52,10 +61,15 @@ class TransformerConnectionHandler(ConnectionHandler):
             )  # [cache_handle, prefix_length]
             prefix_length = 0
 
-            async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
+            async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+                    length_increment = hidden_states[0].shape[1]  # how many tokens are added this step (in each seq)
+
+                    if prefix_length + length_increment > max_length:
+                        raise ValueError(f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
+                                         f" exceeds pre-allocated maximum {max_length}")
 
                     # Cast inputs to backend dtype
                     hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
@@ -113,7 +127,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
         hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
-        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
 
         # Serialize the overall output
         serialized_output = [
@@ -193,7 +207,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         return tuple(uids)
 
     @contextlib.asynccontextmanager
-    async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int) -> Sequence[int]:
+    async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int) -> Sequence[int]:
         """Allocate memory caches for each transformer block, return cache handles"""
         async with contextlib.AsyncExitStack() as stack:
             handles = []
@@ -202,7 +216,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 head_dim = backend.module.self_attention.head_dim
 
                 cache_descriptor = TensorDescriptor(
-                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
+                    size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
                 )
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
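
The per-session cache is now sized by the client-requested `max_length` instead of the removed global `MAX_LENGTH = 2048`. For a sense of scale, a back-of-the-envelope estimate of one block's key/value cache using the descriptor layout above; the attention dimensions are assumptions for illustration (a 4096-dim model with 32 heads, as in the README example), not values taken from this commit:

```python
# [key_or_value, batch_size, max_length, num_heads, head_dim], 16-bit values assumed
batch_size, max_length, num_heads, head_dim = 1, 2048, 32, 128
bytes_per_value = 2  # e.g. bfloat16
cache_bytes = 2 * batch_size * max_length * num_heads * head_dim * bytes_per_value
print(f"{cache_bytes / 2**20:.0f} MiB per block")  # 32 MiB with these numbers
```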
 

+ 6 - 1
src/server/server.py

@@ -36,6 +36,7 @@ class Server(threading.Thread):
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
         *,
+        inference_max_length: int,
         num_connection_handlers: int = 8,
         throughput: float,
         update_period: float = 30,
@@ -47,7 +48,8 @@ class Server(threading.Thread):
         self.dht, self.module_backends = dht, module_backends
         self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
         self.conn_handlers = [
-            TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
+            TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
+            for _ in range(num_connection_handlers)
         ]
         self.runtime = Runtime(self.module_backends, **kwargs)
         self.dht_handler_thread = ModuleAnnouncerThread(
@@ -104,6 +106,7 @@ class Server(threading.Thread):
         num_handlers: int = 8,
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
+        inference_max_length: Optional[int] = None,
         torch_dtype: str = "auto",
         revision: str = "main",
         cache_dir: Optional[str] = None,
@@ -135,6 +138,8 @@ class Server(threading.Thread):
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         if expiration is None:
             expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+        if inference_max_length is None:
+            inference_max_length = max_batch_size
 
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
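
When `inference_max_length` is not given, it defaults to `max_batch_size`, so existing launch configurations keep working without a new flag. A tiny sketch of that default resolution, with the names and default taken from the hunk above:

```python
max_batch_size = 4096          # default from the signature above
inference_max_length = None    # i.e. the caller did not pass a limit

# Same fallback as in the hunk: cap inference sessions at max_batch_size tokens.
if inference_max_length is None:
    inference_max_length = max_batch_size
assert inference_max_length == 4096
```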

+ 1 - 1
tests/test_block_exact_match.py

@@ -27,7 +27,7 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
         (outputs_forward,) = remote_block(inputs)
 
         outputs_inference = []
-        with remote_block.inference_session() as sess:
+        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
             for i in range(inputs.shape[1]):
                 outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
         outputs_inference = torch.cat(outputs_inference, dim=1)

+ 1 - 1
tests/test_chained_calls.py

@@ -63,7 +63,7 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
     inputs = torch.randn(1, 8, config.hidden_size)
 
     outputs_inference = []
-    with remote_block.inference_session() as sess:
+    with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
         for i in range(inputs.shape[1]):
             outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
     outputs_inference = torch.cat(outputs_inference, dim=1)

+ 1 - 1
tests/test_full_model.py

@@ -31,7 +31,7 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
         embs = model.transformer.word_embeddings(test_inputs)
         embs = model.transformer.word_embeddings_layernorm(embs)
         recurrent_outputs = []
-        with model.transformer.h.inference_session() as sess:
+        with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
             for t in range(embs.shape[1]):
                 recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
         recurrent_outputs = torch.cat(recurrent_outputs, dim=1)