Pārlūkot izejas kodu

Add deep prompt inference (#66)

Add deep prompt in inference_step. Small refactoring in deep prompt code.
Artem Chumachenko 3 gadi atpakaļ
vecāks
revīzija
ada98a1b37

+ 36 - 6
src/client/inference_session.py

@@ -22,6 +22,7 @@ from hivemind.proto import runtime_pb2
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
+from src.utils.misc import DUMMY, is_dummy
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -44,6 +45,7 @@ class RemoteTransformerBlockInferenceSession:
         max_length: int,
     ):
         self.uid, self.rpc_info = uid, rpc_info
+        self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
@@ -69,19 +71,43 @@ class RemoteTransformerBlockInferenceSession:
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
 
-    def step(self, new_hidden_states: torch.Tensor):
-        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
+    def step(
+        self,
+        new_hidden_states: torch.Tensor,
+        prompts: Optional[torch.Tensor] = None,
+        hypo_ids: Optional[torch.Tensor] = None,
+    ):
+        """
+        Inference step: send a chunk of input tesors and receive a chunk of outputs
+        :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
+          if specified, deep promts should have shape [num_layers, batch_size, prefix_len, hid_size]
+        """
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4, "deep promts should have shape [num_layers, batch_size, prefix_len, hid_size]"
+            assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
+            assert prompts.shape[2] <= new_hidden_states.shape[1]
+            assert prompts.shape[3] == new_hidden_states.shape[2]
+
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = DUMMY
+        else:
+            assert len(hypo_ids) == len(new_hidden_states)
+            assert hypo_ids.dtype == torch.int64
+
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states,)
+        inputs = (new_hidden_states, prompts, hypo_ids)
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
+                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
                     ],
                     metadata=self._serialized_metadata if not self.stepped else None,
                 )
@@ -161,12 +187,16 @@ class RemoteSequentialInferenceSession:
 
         return self
 
-    def step(self, inputs: torch.Tensor):
+    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
         assert not self.closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
         for session in self.inference_sessions:
-            outputs = session.step(inputs)
+            outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
             assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
             inputs = outputs
         return inputs

+ 3 - 2
src/client/remote_generation.py

@@ -105,11 +105,12 @@ class RemoteGenerationMixin:
             hypo_ids = torch.arange(outputs[0].size(0))
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
+                intermediate_prompts = None
                 if self.config.pre_seq_len > 0 and len(outputs) == 1:
-                    prompts, _ = self.transformer.get_prompt(embs.size(0))
+                    prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
                     embs = torch.cat([prompts, embs], dim=1)
                 embs = self.transformer.word_embeddings_layernorm(embs)
-                hidden_state = sess.step(embs)[:, -1]
+                hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)
 

+ 18 - 3
src/server/backend.py

@@ -1,15 +1,16 @@
 """Code for serving bloom blocks via hivemind-server"""
 from queue import Empty
-from typing import Optional, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple
 
 import torch
-from hivemind import use_hivemind_log_handler
+from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 from hivemind.utils import InvalidStateError, get_logger
 
 from src.bloom.from_pretrained import BloomBlock
 from src.server.cache import MemoryCache
+from src.utils.misc import is_dummy
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -55,18 +56,28 @@ class TransformerBackend(ModuleBackend):
             self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
         )
         self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
+        self.inference_schema = (
+            (
+                *self.args_schema,
+                BatchTensorDescriptor((), dtype=self.dtype),
+                BatchTensorDescriptor((), dtype=torch.int64),
+            ),
+            self.kwargs_schema,
+        )
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
-            hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
+            (hidden_states, hypo_ids) = inputs
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
+                if not is_dummy(hypo_ids):
+                    cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
                 layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
                 print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
                 hidden_states, (new_k, new_v) = self.module.forward(
@@ -85,3 +96,7 @@ class TransformerBackend(ModuleBackend):
 
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(super().get_info(), inference_schema=self.inference_schema)

+ 38 - 29
src/server/handler.py

@@ -64,41 +64,56 @@ class TransformerConnectionHandler(ConnectionHandler):
             async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
-                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-                    length_increment = hidden_states[0].shape[1]  # how many tokens are added this step (in each seq)
+                    hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
 
+                    # Cast inputs to backend dtype
+                    hidden_states = hidden_states.to(requested_backends[0].dtype)
+                    assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
+
+                    # parse deep prompts (optional argument)
+                    if prompts is None or is_dummy(prompts) or is_dummy(prompts):
+                        prompts = [DUMMY] * len(requested_backends)
+                    else:
+                        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+                    if not (len(requested_backends) == len(prompts)):
+                        raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
+
+                    length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
                     if prefix_length + length_increment > max_length:
                         raise ValueError(
                             f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                             f" exceeds pre-allocated maximum {max_length}"
                         )
 
-                    # Cast inputs to backend dtype
-                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
-
                     # run request tensors through all requested modules, update caches
-                    for backend, cache_handle in zip(requested_backends, cache_handles):
+                    for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
+                        if not is_dummy(prompt):
+                            hidden_states[:, : prompt.shape[1]] += prompt
+
                         cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
+                        assert isinstance(
+                            hidden_states, torch.Tensor
+                        ), f"hidden states must be tensor, got {type(hidden_states)}"
                         assert (
-                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
+                            hidden_states.ndim == 3
                         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-
-                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
-                        assert isinstance(hidden_states, (list, tuple))
-                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+                        (hidden_states,) = await backend.inference_pool.submit_task(
+                            cache_metadata, hidden_states, hypo_ids
+                        )
 
                     # serialize and send last layer outputs
                     yield runtime_pb2.ExpertResponse(
                         tensors=[
                             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                             for result, proto in zip(
-                                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
+                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
                             )
                         ]
                     )
 
                     # prepare for next step
-                    prefix_length += hidden_states[0].shape[1]
+                    prefix_length += hidden_states.shape[1]
                     request = await (anext(requests))
         finally:
             print("CLOSED RPC_INFERENCE")
@@ -238,23 +253,20 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
-    hidden_states, *prompts = flat_tensors
+    hidden_states, prompts = flat_tensors
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
     assert hidden_states.ndim == 3
-    if not prompts or is_dummy(prompts[0]):
+    if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
-        pre_seq_len = 0
     else:
-        prompts = [prompts[0].to(requested_backends[0].dtype)]
-        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
-        pre_seq_len = prompts[0].shape[-2]
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
 
     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
         if not is_dummy(prompt):
-            hidden_states[:, :pre_seq_len] += prompt
+            hidden_states[:, : prompt.shape[1]] += prompt
         (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
         assert isinstance(hidden_states, torch.Tensor)
         assert (
@@ -268,18 +280,15 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
 async def _rpc_backward(
     *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, *prompts = flat_tensors
+    inputs, grad_outputs, prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
 
-    if not prompts or is_dummy(prompts[0]):
+    if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
-        pre_seq_len = 0
     else:
-        prompts = [prompts[0].to(requested_backends[0].dtype)]
-        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
-        pre_seq_len = prompts[0].shape[-2]
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
 
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
@@ -287,13 +296,13 @@ async def _rpc_backward(
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
-            inputs[:, :pre_seq_len] += prompt
+            inputs[:, : prompt.shape[1]] += prompt
         inter_inputs.append(inputs)
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
 
     if not is_dummy(prompts[-1]):
-        inputs[:, :pre_seq_len] += prompts[-1]
+        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
     inter_inputs.append(inputs)
 
     assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
@@ -303,7 +312,7 @@ async def _rpc_backward(
         (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
-            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
 
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
     return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape