Artem Chumachenko 2 years ago
parent
commit
1b28432533

+ 5 - 2
src/client/inference_session.py

@@ -90,8 +90,11 @@ class RemoteTransformerBlockInferenceSession:
             assert prompts.shape[2] <= new_hidden_states.shape[1]
             assert prompts.shape[3] == new_hidden_states.shape[2]
 
-        assert hypo_ids is None, "TODO implement hypo_ids here"
-        hypo_ids = torch.arange(len(new_hidden_states))
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = torch.arange(len(new_hidden_states))
+        else:
+            assert len(hypo_ids) == len(new_hidden_states)
+            assert hypo_ids.dtype == torch.int64
 
         # serialize inputs and put them into the queue
         inputs = (new_hidden_states, prompts, hypo_ids)
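
With this change, step() falls back to the old behaviour (an identity arange over the batch) whenever hypo_ids is absent or a dummy tensor, and otherwise checks that the caller supplied one int64 index per hypothesis. A minimal sketch of how a caller might use this during beam-search-style generation (the session call is commented out and the names are illustrative, not taken from this repository):

    import torch

    beam_scores = torch.tensor([0.1, 2.3, 1.7])   # one score per live hypothesis
    hypo_ids = beam_scores.topk(k=2).indices      # indices of the surviving hypotheses (int64)
    assert hypo_ids.dtype == torch.int64
    # session.step(next_embs, prompts=DUMMY, hypo_ids=hypo_ids)  # servers reorder their caches to match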

+ 1 - 1
src/client/remote_generation.py

@@ -110,7 +110,7 @@ class RemoteGenerationMixin:
                     prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
                     embs = torch.cat([prompts, embs], dim=1)
                 embs = self.transformer.word_embeddings_layernorm(embs)
-                hidden_state = sess.step(embs, prompts=intermediate_prompts)[:, -1]
+                hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)
 

+ 5 - 13
src/server/backend.py

@@ -92,23 +92,15 @@ class TransformerBackend(ModuleBackend):
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
-            (
-                hidden_states,
-                hypo_ids,
-                prompts,
-            ) = inputs  # todo: in future, it would be best to support attention mask here
-            assert (
-                hidden_states.ndim == 3
-            ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+            (hidden_states, hypo_ids) = inputs
+            assert (hidden_states.ndim == 3), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
-                arange = torch.arange(prefix_length)
-                layer_past = past_k, past_v = cache[0, hypo_ids, arange], cache[1, hypo_ids, arange]
+                cache[:, :] = cache[:, hypo_ids]
+                layer_past = past_k, past_v = cache[0], cache[1]
                 print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
-                hidden_states, (new_k, new_v) = self.module.forward(
-                    hidden_states, layer_past=layer_past, use_cache=True, prompts=prompts
-                )
+                hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
 
                 # todo remove these asserts once we pass all tests
                 new_length = new_v.shape[1]
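
Here the server reorders its attention cache in place: indexing the 5-D cache with hypo_ids along the hypothesis (batch) dimension gathers the past keys/values of the surviving hypotheses, and writing the result back keeps the pre-allocated buffer. A self-contained illustration of that gather, with made-up shapes:

    import torch

    # toy cache: [keys/values, batch, seq_len, num_heads, head_dim]
    cache = torch.arange(2 * 3 * 4 * 1 * 1, dtype=torch.float32).reshape(2, 3, 4, 1, 1)
    hypo_ids = torch.tensor([2, 2, 0])    # e.g. hypothesis 2 was duplicated, hypothesis 1 was dropped
    cache[:, :] = cache[:, hypo_ids]      # advanced indexing copies, so the in-place write is safe
    past_k, past_v = cache[0], cache[1]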

+ 13 - 15
src/server/handler.py

@@ -66,12 +66,15 @@ class TransformerConnectionHandler(ConnectionHandler):
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
 
+                    # Cast inputs to backend dtype
+                    hidden_states = hidden_states.to(requested_backends[0].dtype)
+                    assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
+
                     # parse deep prompts (optional argument)
-                    if not prompts or is_dummy(prompts[0]):
+                    if prompts is None or is_dummy(prompts):
                         prompts = [DUMMY] * len(requested_backends)
                     else:
-                        prompts = [prompts[0].to(dtype=requested_backends[0].dtype)]
-                        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+                        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
 
                     if not (len(requested_backends) == len(prompts)):
                         raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
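
The deep prompts now arrive as a single stacked tensor instead of a list; splitting it along the first dimension and squeezing yields one prompt per requested block. A small sketch of that reshaping, assuming the layout [num_blocks, batch, prefix_len, hid_size] implied by the asserts in inference_session.py:

    import torch

    num_blocks, batch, prefix_len, hid_size = 2, 1, 4, 8
    prompts = torch.zeros(num_blocks, batch, prefix_len, hid_size)
    per_block = [p.squeeze(0) for p in prompts.split(1, dim=0)]   # one [batch, prefix_len, hid_size] prompt per block
    assert len(per_block) == num_blocks
    assert per_block[0].shape == (batch, prefix_len, hid_size)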
@@ -83,9 +86,6 @@ class TransformerConnectionHandler(ConnectionHandler):
                             f" exceeds pre-allocated maximum {max_length}"
                         )
 
-                    # Cast inputs to backend dtype
-                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
-
                     # run request tensors through all requested modules, update caches
                     for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
                         if not is_dummy(prompt):
@@ -98,7 +98,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         assert (
                             hidden_states.ndim == 3
                         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-                        (hidden_states,) = await backend.inference_pool.submit_task(cache_metadata, hidden_states)
+                        (hidden_states,) = await backend.inference_pool.submit_task(cache_metadata, hidden_states, hypo_ids)
 
                     # serialize and send last layer outputs
                     yield runtime_pb2.ExpertResponse(
@@ -251,16 +251,15 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
-    hidden_states, *prompts = flat_tensors
+    hidden_states, prompts = flat_tensors
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
     assert hidden_states.ndim == 3
-    if not prompts or is_dummy(prompts[0]):
+    if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
     else:
-        prompts = [prompts[0].to(requested_backends[0].dtype)]
-        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
 
     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
@@ -279,16 +278,15 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
 async def _rpc_backward(
     *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, *prompts = flat_tensors
+    inputs, grad_outputs, prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
 
-    if not prompts or is_dummy(prompts[0]):
+    if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
     else:
-        prompts = [prompts[0].to(requested_backends[0].dtype)]
-        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
 
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
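
_rpc_forward and _rpc_backward now accept the prompts tensor positionally and use the same DUMMY sentinel check as the inference path. As an assumption about what DUMMY / is_dummy look like in this codebase (a common way to implement such a placeholder, not confirmed by this diff):

    import torch

    DUMMY = torch.empty(0)                        # zero-element placeholder that can be serialized like any tensor

    def is_dummy(tensor: torch.Tensor) -> bool:
        return tensor.numel() == 0                # a dummy carries no data, only its role as a "not provided" marker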