
Add multibatch mode

Artem Chumachenko, 3 years ago
commit a3be17a36e

+ 6 - 3
src/client/remote_block.py

@@ -61,12 +61,15 @@ class RemoteTransformerBlockInferenceSession:
 
     @classmethod
     async def _create(
-        cls, remote_module: RemoteTransformerBlock, timeout: Optional[float] = None
+        cls,
+        remote_module: RemoteTransformerBlock,
+        timeout: Optional[float] = None,
     ) -> RemoteTransformerBlockInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
         outputs_stream = await remote_module.stub.rpc_inference(
-            cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout
+            cls._read_inputs_from_queue(inputs_queue, timeout),
+            timeout=timeout,
         )
         return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
 
@@ -97,7 +100,7 @@ class RemoteTransformerBlockInferenceSession:
         )
         outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
         assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
-        return outputs[0]
+        return outputs
 
     async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""

+ 4 - 8
src/client/remote_generation.py

@@ -56,8 +56,6 @@ class RemoteGenerationMixin:
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
 
-        word_embeddings = self.transformer.word_embeddings.weight
-
         if inputs is None:
             assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
             inputs = torch.tensor([[bos_token_id]])
@@ -79,7 +77,7 @@ class RemoteGenerationMixin:
         with self.transformer.h.inference_session() as sess:
             outputs = []
             if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs
-                outputs += [inputs[:, :inputs.size(1) - (inputs == pad_token_id).sum(-1).min()]]
+                outputs += [inputs[:, :inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
             else:
                 outputs += [inputs]
             last_token_id = None
@@ -87,19 +85,17 @@ class RemoteGenerationMixin:
             hypo_ids = torch.arange(outputs[0].size(0))
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
-                print(embs.size())
                 embs = self.transformer.word_embeddings_layernorm(embs)
                 hidden_state = sess.step(embs)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)
-                lm_logits = F.linear(hidden_state, word_embeddings).float()
+                lm_logits = self.lm_head(hidden_state)
 
                 for constraint in constraints:
-                    print(lm_logits.size())
                     lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
                 last_token_id, hypo_ids = decoding_algorithm(lm_logits)
                 if seq_idx < inputs.size(1): # TODO: why is it not a constraint?
-                    pad_token_mask = inputs[:, seq_idx] == pad_token_id
-                    last_token_id = (1 - pad_token_mask) * inputs[:, seq_idx] + pad_token_mask * last_token_id
+                    pad_token_mask = inputs[:, seq_idx:seq_idx + 1] == pad_token_id
+                    last_token_id = (~pad_token_mask) * inputs[:, seq_idx:seq_idx + 1] + pad_token_mask * last_token_id
 
                 if torch.all(last_token_id == eos_token_id):
                     break
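
Note: the generation loop now trims the shared prefix by the largest pad count in the batch and then force-feeds each sequence's remaining prompt tokens, letting only padded positions be replaced by generated tokens. A self-contained toy run of that masking arithmetic, assuming right-padded prompts of unequal length (values are made up; the real loop also runs the remote blocks and decoding constraints):

import torch

pad_token_id = 0
inputs = torch.tensor([
    [5, 6, 7, 0, 0],   # real length 3, two pad tokens
    [8, 9, 1, 2, 0],   # real length 4, one pad token
])

# Feed only the prefix that is real for every sequence (drop the max pad count).
common_len = inputs.size(1) - int((inputs == pad_token_id).sum(-1).max())
outputs = [inputs[:, :common_len]]              # shape [2, 3]

# At each remaining prompt position, keep the original token where it is real
# and substitute the freshly generated token where the prompt was padding.
last_token_id = torch.tensor([[11], [12]])      # pretend these were just generated
seq_idx = common_len
pad_token_mask = inputs[:, seq_idx:seq_idx + 1] == pad_token_id
last_token_id = (~pad_token_mask) * inputs[:, seq_idx:seq_idx + 1] + pad_token_mask * last_token_id
print(last_token_id)                            # tensor([[11], [ 2]]): row 0 generates, row 1 is forced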

+ 1 - 1
src/client/remote_sequential.py

@@ -141,7 +141,7 @@ class RemoteSequentialInferenceSession:
     def step(self, inputs: torch.Tensor):
         assert not self.closed
         for session in self.active_sessions:
-            outputs = session.step(inputs)
+            outputs = session.step(inputs)[0]
             assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
             inputs = outputs
         return inputs
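
Note: since each block session now returns a list, the sequential session takes element 0 (the hidden states) before passing it to the next block. A tiny sketch of that chained step, with mock block sessions standing in for the real remote blocks:

import torch

class FakeBlockSession:
    def step(self, inputs: torch.Tensor):
        # Mirrors the new per-block contract: a list whose first entry is the hidden states.
        return [inputs + 1.0]

active_sessions = [FakeBlockSession(), FakeBlockSession()]
inputs = torch.zeros(2, 1, 8)              # [batch, seq_len, hidden]
for session in active_sessions:
    outputs = session.step(inputs)[0]      # take the hidden states from the list
    assert outputs.shape == inputs.shape
    inputs = outputs
print(inputs[0, 0, 0].item())              # 2.0 after passing through two blocks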

+ 1 - 1
src/server/backend.py

@@ -23,7 +23,7 @@ class TransformerBackend(ModuleBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
-        self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
+        self.inference_pool = TaskPool(self.inference_step, max_batch_size=4096, name=f"{self.name}_inference")
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
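
Note: raising max_batch_size from 1 to 4096 is what lets the inference pool accept multibatch requests at all. The toy check below only illustrates the size arithmetic; hivemind's TaskPool does the actual grouping and scheduling internally, and the limit here is assumed to cap the batch dimension admitted per step.

import torch

MAX_BATCH_SIZE = 4096  # mirrors the new TaskPool limit

def fits_in_inference_pool(hidden_states: torch.Tensor) -> bool:
    # hidden_states is [batch, seq_len, hidden]; with the old limit of 1,
    # a request carrying more than one sequence could not be batched through.
    return hidden_states.shape[0] <= MAX_BATCH_SIZE

print(fits_in_inference_pool(torch.zeros(32, 1, 14336)))   # True: a 32-sequence step fits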

+ 10 - 6
src/server/handler.py

@@ -26,7 +26,9 @@ class TransformerConnectionHandler(ConnectionHandler):
             assert isinstance(module_backend, TransformerBackend)
 
     async def rpc_inference(
-        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+        self,
+        requests: AsyncIterator[runtime_pb2.ExpertRequest],
+        context: P2PContext,
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
         try:
@@ -35,17 +37,19 @@ class TransformerConnectionHandler(ConnectionHandler):
             requested_uids = self._check_header(request)
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-            cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)  # [cache_handle, prefix_length]
+            batch_size = request.tensors[0].size[0] if request.tensors else 1
+
+            cache_metadata = torch.tensor([[-1, -1] for _ in range(batch_size)], dtype=torch.int64)  # [cache_handle, prefix_length]
             prefix_length = 0
 
-            async with self._allocate_caches(requested_backends) as cache_handles:
+            async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
 
                     # run request tensors through all requested modules, update caches
                     for backend, cache_handle in zip(requested_backends, cache_handles):
-                        cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
+                        cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
                         assert (
                             len(hidden_states) == 1 and hidden_states[0].ndim == 3
                         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
@@ -213,7 +217,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         return tuple(uids)
 
     @contextlib.asynccontextmanager
-    async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> Sequence[int]:
+    async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int) -> Sequence[int]:
         """Allocate memory caches for each transformer block, return cache handles"""
         async with contextlib.AsyncExitStack() as stack:
             handles = []
@@ -221,7 +225,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 num_heads = backend.module.self_attention.num_heads
                 head_dim = backend.module.self_attention.head_dim
 
-                cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
+                cache_descriptor = TensorDescriptor(size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
 
                 handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
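
Note: the attention cache descriptor now carries the request's batch size, so server memory per block grows linearly with it. A back-of-the-envelope sketch of that growth under the [key_or_value, batch_size, max_length, num_heads, head_dim] float32 layout shown above (the geometry below is a hypothetical BLOOM-like block, not taken from this repo's config):

def cache_bytes(batch_size: int, max_length: int, num_heads: int, head_dim: int, dtype_bytes: int = 4) -> int:
    # 2 accounts for storing both keys and values; float32 is 4 bytes per element.
    return 2 * batch_size * max_length * num_heads * head_dim * dtype_bytes

print(cache_bytes(batch_size=1, max_length=2048, num_heads=32, head_dim=128) / 2**20, "MiB")   # 64.0 MiB
print(cache_bytes(batch_size=8, max_length=2048, num_heads=32, head_dim=128) / 2**20, "MiB")   # 512.0 MiB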

+ 6 - 5
src/utils/generation_constraints.py

@@ -31,19 +31,20 @@ class MaxNewTokensConstraint(ABCBloomConstraint):
         pad_token_id: The id of the padding token.
         min_logits: The minimum logits that can be generated. Default: -1e6.
     """
-    def __init__(self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e6) -> None:
+    def __init__(self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
         self.max_new_tokens = max_new_tokens
         self.current_generated_tokens = None
         self.eos_token_id = eos_token_id
         self.min_logits = min_logits
 
-        self.current_generated_tokens = -(prefix == pad_token_id).sum(-1)
+        max_pad_size = (prefix == pad_token_id).sum(1).unsqueeze(1).max()
+        self.current_generated_tokens = (prefix == pad_token_id).sum(1).unsqueeze(1) - max_pad_size
 
     def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
         if tokens_id is not None:
             self.current_generated_tokens += 1
 
-        mask = (self.current_generated_tokens > self.max_new_tokens).unsqueeze(1)
+        mask = (self.current_generated_tokens >= self.max_new_tokens)
         logits += self.min_logits * mask
         logits[mask[:, 0], self.eos_token_id] = 0
         return logits
@@ -58,12 +59,12 @@ class EosConstraint(ABCBloomConstraint):
         pad_token_id: The id of the padding token.
         min_logits: The minimum logits that can be generated. Default: -1e6.
     """
-    def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e6) -> None:
+    def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
         self.eos_token_id = eos_token_id
         self.min_logits = min_logits
         self.past_tokens = None
 
-        self.wait_until_starting = (prefix == pad_token_id).sum(-1).unsqueeze(1)
+        self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)
 
     def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
         if self.past_tokens is not None:
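
Note: both constraints now keep one counter per sequence so that batched, right-padded prompts are handled independently. A toy run of the new MaxNewTokensConstraint initialization, with illustrative tensor values:

import torch

pad_token_id = 0
prefix = torch.tensor([
    [5, 6, 7, 0, 0],   # 2 pad tokens (shortest real prompt)
    [8, 9, 1, 2, 0],   # 1 pad token
])

pad_counts = (prefix == pad_token_id).sum(1).unsqueeze(1)   # tensor([[2], [1]])
max_pad_size = pad_counts.max()
current_generated_tokens = pad_counts - max_pad_size        # tensor([[0], [-1]])

# The sequence with the most padding starts counting new tokens immediately,
# while longer prompts get a negative offset so their remaining forced prompt
# tokens are not charged against max_new_tokens.
print(current_generated_tokens)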