Artem Chumachenko 3 years ago
parent
commit
fb142375cb
1 changed file with 26 additions and 16 deletions
  1. 26 16
      src/server/handler.py

+ 26 - 16
src/server/handler.py

@@ -64,8 +64,19 @@ class TransformerConnectionHandler(ConnectionHandler):
             async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
-                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-                    length_increment = hidden_states[0].shape[1]  # how many tokens are added this step (in each seq)
+                    hidden_states, *prompts = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+
+                    # parse deep prompts (optional argument)
+                    if not prompts or is_dummy(prompts[0]):
+                        prompts = [DUMMY] * len(requested_backends)
+                    else:
+                        prompts = [prompts[0].to(dtype=requested_backends[0].dtype)]
+                        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+
+                    if len(requested_backends) != len(prompts):
+                        raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
+
+                    length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
 
                     if prefix_length + length_increment > max_length:
                         raise ValueError(
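
The new block above peels the optional deep prompts off the request tensors and expands them into one prompt per requested backend, with DUMMY standing in when the client sent none. Below is a standalone sketch of that parsing step, assuming DUMMY is an empty placeholder tensor and is_dummy detects it by element count (the real helpers are defined elsewhere in the package):

import torch

DUMMY = torch.empty(0)  # assumption: empty placeholder sent when a client has no prompts

def is_dummy(tensor: torch.Tensor) -> bool:
    # assumption: a dummy prompt is recognised by having zero elements
    return tensor.numel() == 0

def parse_prompts(prompts, requested_backends, dtype=torch.float32):
    # mirror of the parsing above: return exactly one prompt (or DUMMY) per backend
    if not prompts or is_dummy(prompts[0]):
        return [DUMMY] * len(requested_backends)
    stacked = prompts[0].to(dtype=dtype)  # [num_backends, batch, prefix_len, hidden]
    per_backend = [p.squeeze(0) for p in stacked.split(1)]
    if len(per_backend) != len(requested_backends):
        raise ValueError(f"Received {len(per_backend)} prompts for {len(requested_backends)} backends")
    return per_backend
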
@@ -77,15 +88,18 @@ class TransformerConnectionHandler(ConnectionHandler):
                     hidden_states = hidden_states.to(requested_backends[0].dtype)
 
                     # run request tensors through all requested modules, update caches
-                    for backend, cache_handle in zip(requested_backends, cache_handles):
+                    for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
+                        if not is_dummy(prompt):
+                            hidden_states[:, : prompt.shape[1]] += prompt
+
                         cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
+                        assert isinstance(
+                            hidden_states, torch.Tensor
+                        ), f"hidden states must be tensor, got {type(hidden_states)}"
                         assert (
-                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
+                            hidden_states.ndim == 3
                         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-
-                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
-                        assert isinstance(hidden_states, (list, tuple))
-                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+                        (hidden_states,) = await backend.inference_pool.submit_task(cache_metadata, hidden_states)
 
                     # serialize and send last layer outputs
                     yield runtime_pb2.ExpertResponse(
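
The loop above adds each backend's prompt to the first prompt.shape[1] positions of the hidden states before submitting them to that backend's inference pool. A minimal numeric sketch of the slicing semantics (shapes are illustrative assumptions, not taken from the repo):

import torch

batch, seq_len, hidden = 1, 5, 4          # illustrative sizes
hidden_states = torch.zeros(batch, seq_len, hidden)
prompt = torch.ones(batch, 2, hidden)     # trained prefix covering the first 2 positions
hidden_states[:, : prompt.shape[1]] += prompt
# only positions 0 and 1 are shifted; the remaining tokens pass through unchanged
assert hidden_states[:, :2].eq(1).all() and hidden_states[:, 2:].eq(0).all()
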
@@ -245,16 +259,14 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     assert hidden_states.ndim == 3
     if not prompts or is_dummy(prompts[0]):
         prompts = [DUMMY] * len(requested_backends)
-        pre_seq_len = 0
     else:
         prompts = [prompts[0].to(requested_backends[0].dtype)]
         prompts = [p.squeeze(0) for p in prompts[0].split(1)]
-        pre_seq_len = prompts[0].shape[-2]
 
     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
         if not is_dummy(prompt):
-            hidden_states[:, :pre_seq_len] += prompt
+            hidden_states[:, : prompt.shape[1]] += prompt
         (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
         assert isinstance(hidden_states, torch.Tensor)
         assert (
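
_rpc_forward no longer caches a single pre_seq_len taken from the first prompt; each backend now uses its own prompt's length when adding it to the hidden-state prefix. A sketch of the forward chain under the assumption that each backend is a plain callable block (the server instead awaits backend.forward_pool.submit_task):

import torch

def forward_chain(hidden_states: torch.Tensor, blocks, prompts):
    # hypothetical stand-in for _rpc_forward with ordinary callables as blocks
    for block, prompt in zip(blocks, prompts):
        if prompt.numel() > 0:  # i.e. not the DUMMY placeholder
            # add this backend's prompt over its own prefix length, not a shared pre_seq_len
            hidden_states[:, : prompt.shape[1]] += prompt
        hidden_states = block(hidden_states)
        assert hidden_states.ndim == 3, "each block must return a single 3d tensor"
    return hidden_states
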
@@ -275,11 +287,9 @@ async def _rpc_backward(
 
     if not prompts or is_dummy(prompts[0]):
         prompts = [DUMMY] * len(requested_backends)
-        pre_seq_len = 0
     else:
         prompts = [prompts[0].to(requested_backends[0].dtype)]
         prompts = [p.squeeze(0) for p in prompts[0].split(1)]
-        pre_seq_len = prompts[0].shape[-2]
 
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
@@ -287,13 +297,13 @@ async def _rpc_backward(
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
-            inputs[:, :pre_seq_len] += prompt
+            inputs[:, : prompt.shape[1]] += prompt
         inter_inputs.append(inputs)
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
 
     if not is_dummy(prompts[-1]):
-        inputs[:, :pre_seq_len] += prompts[-1]
+        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
     inter_inputs.append(inputs)
 
     assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
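
_rpc_backward first replays the forward chain (skipping the last block's output, which is not needed) so that the prompt-adjusted input of every block is available for backpropagation. A sketch of that collection step under the same plain-callable assumption as above:

import torch

def collect_intermediate_inputs(inputs: torch.Tensor, blocks, prompts):
    # hypothetical stand-in for the replayed forward pass in _rpc_backward:
    # keep the (prompt-adjusted) input of every block; the last block is never run
    inter_inputs = []
    for block, prompt in zip(blocks[:-1], prompts[:-1]):
        if prompt.numel() > 0:  # skip DUMMY placeholders
            inputs[:, : prompt.shape[1]] += prompt
        inter_inputs.append(inputs)
        inputs = block(inputs)
    if prompts[-1].numel() > 0:
        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
    inter_inputs.append(inputs)
    assert len(inter_inputs) == len(prompts) == len(blocks)
    return inter_inputs
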
@@ -303,7 +313,7 @@ async def _rpc_backward(
         (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
-            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
 
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
     return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
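
Each prompt's gradient is simply the slice of the backpropagated gradient that covers that prompt's positions; the slices are gathered while walking the chain in reverse and then stacked back into a single [num_backends, batch, prefix_len, hidden] tensor in the original backend order. A minimal sketch of the stacking step with illustrative (assumed) sizes:

import torch

num_backends, batch, prefix_len, hidden = 3, 2, 4, 8  # illustrative sizes
grad_prompts_reversed = [torch.randn(batch, prefix_len, hidden).unsqueeze(0)
                         for _ in range(num_backends)]  # collected last block first
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0)  # restore original order
print(grad_prompts.shape)  # torch.Size([3, 2, 4, 8]) -- one prompt gradient per backend
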