artek0chumak 3 years ago
Commit
12d874f195

+ 2 - 2
src/client/remote_block.py

@@ -78,12 +78,12 @@ class RemoteTransformerBlockInferenceSession:
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
 
-    def step(self, new_hidden_states: torch.Tensor):
+    def step(self, new_hidden_states: torch.Tensor, batch_ids: torch.Tensor):
         """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states,)
+        inputs = (new_hidden_states, batch_ids)
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(

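The session's step() now takes a second tensor alongside the hidden states; judging by the name and the checks added in remote_sequential.py below, batch_ids indexes the sequences of the original batch that a chunk corresponds to. A minimal sketch of the new call convention, with shapes that are assumptions rather than part of this diff:

    import torch

    def run_block_step(sess, new_hidden_states: torch.Tensor, batch_ids: torch.Tensor) -> torch.Tensor:
        # sess: an open RemoteTransformerBlockInferenceSession (obtained elsewhere, not shown here)
        # new_hidden_states: assumed shape [batch, seq_chunk, hidden_size]
        # batch_ids: assumed 1-D LongTensor with the indices of the still-active sequences
        return sess.step(new_hidden_states, batch_ids)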
+ 66 - 0
src/client/remote_generation.py

@@ -0,0 +1,66 @@
+import torch
+
+from typing import List, Optional
+
+from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, TopKAlgorithm, NucleusAlgorithm
+from src.utils.generation_constraints import ABConstraint, MaxNewTokensConstraint
+
+from transformers.modeling_utils import PreTrainedModel
+
+
+class RemoteGenerationMixin(PreTrainedModel):
+    def generation(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        do_sample: Optional[bool] = None,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        eos_token_id: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        decoding_algorithm: Optional[DecodingAlgorithm] = None,
+        provided_constraints: List[ABConstraint] = [],
+        **model_kwargs,
+    ) -> torch.Tensor:
+        if decoding_algorithm is None:
+            if do_sample:
+                if (top_k is None) == (top_p is None):
+                    raise ValueError("You have to provide exactly one of top_k or top_p for sampling")
+                if top_k:
+                    decoding_algorithm = TopKAlgorithm(top_k, temperature)
+                elif top_p:
+                    decoding_algorithm = NucleusAlgorithm(top_p, temperature)
+            else:
+                decoding_algorithm = GreedyAlgorithm()
+        
+        constraints = []
+        constraints.extend(provided_constraints)
+                
+        if max_new_tokens and eos_token_id:
+            constraints.append(MaxNewTokensConstraint(max_new_tokens, eos_token_id))
+            
+        for constraint in constraints:
+            constraint.consume_prefix(inputs)
+       
+        word_embeddings = self.transformer.word_embeddings.weight.t()
+
+        with self.transformer.h.inference_session() as sess:
+            last_token_id = inputs[:, -1]
+            outputs = []
+            while torch.any(last_token_id != eos_token_id):
+                embs = self.transformer.word_embeddings(inputs)
+                embs = self.transformer.word_embeddings_layernorm(embs)
+                for emb_ids in range(embs.size(1)):
+                    recurrent_output = sess.step(embs[:, emb_ids:emb_ids+1])
+                recurrent_output = self.transformer.ln_f(recurrent_output)
+                lm_logits = (recurrent_output @ word_embeddings).float()
+                for constraint in constraints:
+                    lm_logits = constraint.calculate_transation(lm_logits)
+                last_token_id, _ = decoding_algorithm(lm_logits)
+                for constraint in constraints:
+                    constraint.update(last_token_id, torch.ones_like(last_token_id))
+                outputs.append(last_token_id)
+                inputs = last_token_id
+            
+        return torch.cat(outputs, dim=-1)
+

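For orientation, a hedged usage sketch of the new RemoteGenerationMixin.generation() method. The wrapper model below is an assumption (any class that inherits the mixin and exposes self.transformer with remote blocks); the token ids are illustrative only:

    import torch

    def continue_prompt(model, input_ids: torch.Tensor) -> torch.Tensor:
        # model: assumed to inherit RemoteGenerationMixin; input_ids: assumed shape [batch, prefix_length]
        # With do_sample unset, GreedyAlgorithm is selected; passing both max_new_tokens and
        # eos_token_id also activates MaxNewTokensConstraint.
        new_tokens = model.generation(
            inputs=input_ids,
            max_new_tokens=10,
            eos_token_id=2,  # assumption: the tokenizer's end-of-sequence id
        )
        # generation() returns only the newly sampled tokens, so append them to the prompt
        return torch.cat([input_ids, new_tokens], dim=-1)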
+ 6 - 3
src/client/remote_sequential.py

@@ -5,6 +5,8 @@ import logging
 import random
 from typing import Optional, Union
 
+from typing import Optional
+
 import torch
 from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
@@ -138,11 +140,12 @@ class RemoteSequentialInferenceSession:
 
         return self
 
-    def step(self, inputs: torch.Tensor):
+    def step(self, inputs: torch.Tensor, batch_ids: torch.Tensor):
         assert not self.closed
         for session in self.active_sessions:
-            outputs = session.step(inputs)
-            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
+            outputs = session.step(inputs, batch_ids)
+            assert outputs.shape[1:] == inputs.shape[1:], f"expected {inputs.shape[1:]}, got {outputs.shape[1:]}"
+            assert outputs.shape[0] == batch_ids.shape[0], f"expected {batch_ids.shape[0]}, got {outputs.shape[0]}"
             inputs = outputs
         return inputs
 

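The relaxed assertions reflect that the leading (batch) dimension of the returned hidden states may now differ from the inputs: only the trailing dimensions have to match, while the batch dimension has to line up with batch_ids. A small illustration of the invariant being checked, with shapes picked for the example:

    import torch

    inputs = torch.randn(4, 1, 14336)    # hidden states for a batch of 4 sequences (shapes assumed)
    batch_ids = torch.tensor([0, 1, 3])  # one sequence has finished generating
    outputs = torch.randn(3, 1, 14336)   # a server may return hidden states only for the active rows

    assert outputs.shape[1:] == inputs.shape[1:]   # per-token hidden shape is unchanged
    assert outputs.shape[0] == batch_ids.shape[0]  # batch dimension follows the active set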
+ 1 - 1
src/server/backend.py

@@ -1,5 +1,5 @@
 """Code for serving bloom blocks via hivemind-server"""
-from typing import Sequence, Tuple
+from typing import Optional, Sequence, Tuple
 
 import torch
 from hivemind.moe.server.module_backend import ModuleBackend

+ 8 - 0
src/server/cache.py

@@ -122,6 +122,14 @@ class MemoryCache:
         assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
         yield self._allocated_tensors[handle]
 
+    @staticmethod
+    def update_cache_via_batch_ids(cache: torch.Tensor, batch_ids: torch.Tensor) -> None:
+        new_cache_shape = list(cache.shape)  # torch.Size is immutable, so copy it into a list first
+        new_cache_shape[1] = batch_ids.size(0)
+        new_cache = torch.zeros(new_cache_shape, dtype=cache.dtype, device=cache.device)
+        new_cache.scatter_(1, batch_ids, cache)
+        cache.copy_(new_cache)
+
 
 class AllocationFailed(Exception):
     pass
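Note that Tensor.scatter_ along dim 1 expects an index tensor shaped like the source, so a plain 1-D batch_ids has to be broadcast across the remaining cache dimensions before it can be used this way. One possible reading of the helper, given only as a sketch of the intended reshuffle of per-sequence cache entries and not as the committed implementation:

    import torch

    def reorder_cache_by_batch_ids(cache: torch.Tensor, batch_ids: torch.Tensor) -> torch.Tensor:
        # cache: assumed layout with the batch dimension at index 1, e.g. [keys_and_values, batch, ...]
        # batch_ids: 1-D LongTensor giving, for each incoming row j, its position in the new batch
        index = batch_ids.view(1, -1, *([1] * (cache.dim() - 2))).expand_as(cache)
        new_cache = torch.zeros_like(cache)
        new_cache.scatter_(1, index, cache)  # new_cache[:, batch_ids[j], ...] = cache[:, j, ...]
        return new_cache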