
Merge branch 'generation-inference' into deep_prompt_inference

justheuristic, 3 years ago
commit e0bb3762b4

+ 3 - 3
src/client/inference_session.py

@@ -69,19 +69,19 @@ class RemoteTransformerBlockInferenceSession:
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
 
-    def step(self, new_hidden_states: torch.Tensor):
+    def step(self, new_hidden_states: torch.Tensor, prompts: Optional[torch.Tensor] = None):
         """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states,)
+        inputs = (new_hidden_states, prompts, torch.arange(len(new_hidden_states)))
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
+                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
                     ],
                     metadata=self._serialized_metadata if not self.stepped else None,
                 )
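
For orientation, a minimal sketch of how a caller is expected to drive the extended step() signature; the session object sess, the hidden size, and the prompt shape are illustrative assumptions, not part of this change:

import torch

hid_size, prefix_len = 1024, 16                       # illustrative sizes
embs = torch.randn(1, 1, hid_size)                    # one new token: [batch, seq_len, hid_size]
deep_prompts = torch.randn(1, prefix_len, hid_size)   # assumed prompt shape

# `sess` is assumed to be an already-open RemoteTransformerBlockInferenceSession
hidden = sess.step(embs, prompts=deep_prompts)
# step() now serializes the triple (hidden_states, prompts, hypo_ids) against
# rpc_info["inference_schema"], where hypo_ids = torch.arange(batch_size)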

+ 3 - 2
src/client/remote_generation.py

@@ -105,11 +105,12 @@ class RemoteGenerationMixin:
             hypo_ids = torch.arange(outputs[0].size(0))
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
+                intermediate_prompts = None
                 if self.config.pre_seq_len > 0 and len(outputs) == 1:
-                    prompts, _ = self.transformer.get_prompt(embs.size(0))
+                    prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
                     embs = torch.cat([prompts, embs], dim=1)
                 embs = self.transformer.word_embeddings_layernorm(embs)
-                hidden_state = sess.step(embs)[:, -1]
+                hidden_state = sess.step(embs, prompts=intermediate_prompts)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)
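
The loop above unpacks two values from get_prompt: input-level prompts and per-layer intermediate prompts. Below is a self-contained sketch of a module with that interface; the class name, attribute names, and shapes are assumptions (note that DistributedBloomPrefix.get_prompt added further down still returns a single tensor):

import torch
import torch.nn as nn

class DeepPromptEmbeddings(nn.Module):
    """Hypothetical sketch: learned input prompts plus per-layer ("deep") prompts."""

    def __init__(self, num_layers: int, prefix_length: int, hidden_size: int):
        super().__init__()
        self.prefix_length = prefix_length
        self.input_prompts = nn.Embedding(prefix_length, hidden_size)
        self.deep_prompts = nn.Embedding(prefix_length, num_layers * hidden_size)

    def get_prompt(self, batch_size: int):
        ids = torch.arange(self.prefix_length).unsqueeze(0).expand(batch_size, -1)
        prompts = self.input_prompts(ids)              # [batch, prefix_length, hidden_size]
        intermediate_prompts = self.deep_prompts(ids)  # [batch, prefix_length, num_layers * hidden_size]
        return prompts, intermediate_prompts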
 

+ 49 - 0
src/client/remote_model.py

@@ -151,6 +151,55 @@ class DistributedBloomModel(BloomModel):
         )
 
 
+class DistributedBloomPrefix(DistributedBloomModel):
+    """DistributedBloomModel with prefix tokens for prompt tuning"""
+
+    def __init__(self, config):
+        super().__init__(config)
+        assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0"
+        self.prefix_length = config.num_prefix_tokens
+
+        self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size)
+        self.prefix_tokens = torch.arange(self.prefix_length).long()
+
+    def get_prompt(self, batch_size):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
+        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
+        prompts = self.prompt_embeddings(prefix_tokens)
+        return prompts
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        assert (
+            input_ids is None or inputs_embeds is None
+        ), "You cannot specify both input_ids and inputs_embeds at the same time"
+        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        batch_size = inputs_embeds.shape[0]
+
+        if attention_mask is not None:
+            prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
+            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+        prompts = self.get_prompt(batch_size)
+        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+
+        transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
+
+        # Remove prefix
+        last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
+        transformer_outputs["last_hidden_state"] = last_hidden_state
+        return transformer_outputs
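
A usage sketch for the DistributedBloomPrefix class above; `config` is assumed to be a Bloom config with num_prefix_tokens > 0 plus whatever fields the distributed base class needs, and the token ids are made up:

import torch

model = DistributedBloomPrefix(config)                 # `config` is assumed, see above

input_ids = torch.tensor([[2, 17, 42, 8]])             # [batch=1, seq_len=4], made-up ids
attention_mask = torch.ones_like(input_ids)

out = model(input_ids=input_ids, attention_mask=attention_mask)
# forward() prepends num_prefix_tokens prompt embeddings, extends the attention
# mask to match, then slices the prefix off the returned hidden states, so the
# output stays aligned with the original tokens:
assert out["last_hidden_state"].shape[1] == input_ids.shape[1]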
+
+
 class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 

+ 50 - 5
src/server/backend.py

@@ -1,9 +1,13 @@
 """Code for serving bloom blocks via hivemind-server"""
 from queue import Empty
-from typing import Optional, Sequence, Tuple
+from typing import Sequence, Tuple, Dict, Any, Optional
 
 import torch
-from hivemind import use_hivemind_log_handler
+from hivemind import use_hivemind_log_handler, BatchTensorDescriptor
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 from hivemind.utils import InvalidStateError, get_logger
@@ -14,6 +18,34 @@ from src.server.cache import MemoryCache
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
+MAX_LENGTH = 2048
+
 
 class InferenceTaskPool(TaskPool):
     def __init__(self, *args, **kwargs):
@@ -55,22 +87,31 @@ class TransformerBackend(ModuleBackend):
             self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
         )
         self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
+        self.inference_schema = (
+            (
+                *self.args_schema,
+                BatchTensorDescriptor((), dtype=self.dtype),
+                BatchTensorDescriptor((), dtype=torch.int64),
+            ),
+            self.kwargs_schema,
+        )
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
-            hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
+            hidden_states, prompts, hypo_ids = inputs  # todo: in future, it would be best to support attention mask here
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
-                layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
+                layer_past = past_k, past_v = cache[0, hypo_ids, :prefix_length], cache[1, hypo_ids, :prefix_length]
                 print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
                 hidden_states, (new_k, new_v) = self.module.forward(
-                    hidden_states, layer_past=layer_past, use_cache=True
+                    hidden_states, layer_past=layer_past, use_cache=True, prompts=prompts
                 )
 
                 # todo remove these asserts once we pass all tests
@@ -85,3 +126,7 @@ class TransformerBackend(ModuleBackend):
 
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(super().get_info(), inference_schema=self.inference_schema)
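
The hypo_ids tensor exists so the server can re-order its attention cache when the client re-ranks hypotheses between steps (e.g. during beam search). A self-contained sketch of that indexing, with made-up cache dimensions:

import torch

# Illustrative cache layout: [key/value, batch, max_length, num_heads, head_dim]
cache = torch.randn(2, 3, 128, 4, 8)
prefix_length = 5
hypo_ids = torch.tensor([2, 0, 0])  # hypothesis 2 is copied into slot 0, etc.

# Select the surviving hypotheses and the already-computed prefix positions
past_k = cache[0, hypo_ids, :prefix_length]  # [batch, prefix_length, num_heads, head_dim]
past_v = cache[1, hypo_ids, :prefix_length]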

+ 1 - 2
src/server/handler.py

@@ -64,7 +64,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
-                    hidden_states, *prompts = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+                    hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
 
                     # parse deep prompts (optional argument)
                     if not prompts or is_dummy(prompts[0]):
@@ -77,7 +77,6 @@ class TransformerConnectionHandler(ConnectionHandler):
                         raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
 
                     length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
-
                     if prefix_length + length_increment > max_length:
                         raise ValueError(
                             f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
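
The check above guards the attention caches that were pre-allocated for at most max_length tokens: every inference step appends length_increment tokens to the running prefix. A toy illustration of that accounting (the numbers are made up):

max_length = 2048      # requested by the client when the caches were allocated
prefix_length = 0

for length_increment in (16, 1, 1, 1):   # e.g. a 16-token prompt, then one token per step
    if prefix_length + length_increment > max_length:
        raise ValueError(f"Maximum length exceeded: prefix {prefix_length} + current {length_increment} > {max_length}")
    prefix_length += length_increment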