@@ -64,8 +64,19 @@ class TransformerConnectionHandler(ConnectionHandler):
             async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
-                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-                    length_increment = hidden_states[0].shape[1]  # how many tokens are added this step (in each seq)
+                    hidden_states, *prompts = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+
+                    # parse deep prompts (optional argument)
+                    if not prompts or is_dummy(prompts[0]):
+                        prompts = [DUMMY] * len(requested_backends)
+                    else:
+                        prompts = [prompts[0].to(dtype=requested_backends[0].dtype)]
+                        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+
+                    if len(prompts) != len(requested_backends):
+                        raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
+
+                    length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
 
                     if prefix_length + length_increment > max_length:
                         raise ValueError(
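Note: the prompt parsing above relies on the project's dummy-tensor convention, where a zero-element placeholder means "no prompts" and a real prompt arrives as a single stacked tensor with a leading per-backend dimension. A minimal sketch of that convention is below; the DUMMY and is_dummy definitions are assumptions for illustration, since the patch does not show them.

import torch

DUMMY = torch.empty(0)                      # assumed placeholder for "no tensor supplied"

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0              # assumed check used throughout the handler

# One stacked prompt tensor covers all requested backends: [num_backends, batch, prefix_len, hid]
num_backends, batch, prefix_len, hid = 3, 2, 4, 16
stacked_prompts = torch.randn(num_backends, batch, prefix_len, hid)

# Same parsing as the added lines above: split along dim 0 and drop the singleton dimension,
# yielding one [batch, prefix_len, hid] prompt per backend
prompts = [p.squeeze(0) for p in stacked_prompts.split(1)]
assert len(prompts) == num_backends and prompts[0].shape == (batch, prefix_len, hid)

# With no prompts supplied, every backend gets the placeholder and the addition step is skipped
prompts = [DUMMY] * num_backends
assert all(is_dummy(p) for p in prompts)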
@@ -77,15 +88,18 @@ class TransformerConnectionHandler(ConnectionHandler):
-                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+                    hidden_states = hidden_states.to(requested_backends[0].dtype)
 
                     # run request tensors through all requested modules, update caches
-                    for backend, cache_handle in zip(requested_backends, cache_handles):
+                    for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
+                        if not is_dummy(prompt):
+                            hidden_states[:, : prompt.shape[1]] += prompt
+
                         cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
+                        assert isinstance(
+                            hidden_states, torch.Tensor
+                        ), f"hidden states must be tensor, got {type(hidden_states)}"
                         assert (
-                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
-                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-
-                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
-                        assert isinstance(hidden_states, (list, tuple))
-                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+                            hidden_states.ndim == 3
+                        ), f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+                        (hidden_states,) = await backend.inference_pool.submit_task(cache_metadata, hidden_states)
 
                     # serialize and send last layer outputs
                     yield runtime_pb2.ExpertResponse(
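The loop above adds each backend's prompt only to the first prompt.shape[1] positions of the current step's hidden states, i.e. it shifts the prefix during the prefill step. A small self-contained sketch of that slice arithmetic (shapes are illustrative, and it assumes the current step carries at least prefix_len tokens):

import torch

hidden_states = torch.zeros(2, 8, 16)   # [batch, seq_len, hid] for the current inference step
prompt = torch.ones(2, 4, 16)           # deep prompt for one block: [batch, prefix_len, hid]

# Same in-place update as in the loop above: only the prefix positions are shifted
hidden_states[:, : prompt.shape[1]] += prompt
assert hidden_states[:, :4].eq(1).all() and hidden_states[:, 4:].eq(0).all()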
@@ -245,16 +259,14 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     assert hidden_states.ndim == 3
     if not prompts or is_dummy(prompts[0]):
         prompts = [DUMMY] * len(requested_backends)
-        pre_seq_len = 0
     else:
         prompts = [prompts[0].to(requested_backends[0].dtype)]
         prompts = [p.squeeze(0) for p in prompts[0].split(1)]
-        pre_seq_len = prompts[0].shape[-2]
 
     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
         if not is_dummy(prompt):
-            hidden_states[:, :pre_seq_len] += prompt
+            hidden_states[:, : prompt.shape[1]] += prompt
         (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
         assert isinstance(hidden_states, torch.Tensor)
         assert (
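In _rpc_forward the same prefix addition happens once per requested block, right before that block runs, so every block sees its own prompt. A toy version of that chain, with plain linear layers standing in for the transformer backends (the real backends, pools, and dtype handling are replaced by assumptions here):

import torch

blocks = [torch.nn.Linear(16, 16) for _ in range(3)]   # stand-ins for requested_backends
prompts = [torch.randn(2, 4, 16) for _ in range(3)]    # one [batch, prefix_len, hid] prompt per block
hidden_states = torch.randn(2, 10, 16)                 # [batch, seq_len, hid]

with torch.no_grad():
    for block, prompt in zip(blocks, prompts):
        # mirror the loop above: shift the prefix with this block's prompt, then run the block
        hidden_states[:, : prompt.shape[1]] += prompt
        hidden_states = block(hidden_states)

assert hidden_states.shape == (2, 10, 16)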
@@ -275,11 +287,9 @@ async def _rpc_backward(
 
     if not prompts or is_dummy(prompts[0]):
         prompts = [DUMMY] * len(requested_backends)
-        pre_seq_len = 0
     else:
         prompts = [prompts[0].to(requested_backends[0].dtype)]
         prompts = [p.squeeze(0) for p in prompts[0].split(1)]
-        pre_seq_len = prompts[0].shape[-2]
 
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
@@ -287,13 +297,13 @@ async def _rpc_backward(
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
-            inputs[:, :pre_seq_len] += prompt
+            inputs[:, : prompt.shape[1]] += prompt
         inter_inputs.append(inputs)
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
 
     if not is_dummy(prompts[-1]):
-        inputs[:, :pre_seq_len] += prompts[-1]
+        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
     inter_inputs.append(inputs)
 
     assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
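The forward chain above stores each block's prompt-shifted input in inter_inputs so that the hunk below can replay the chain in reverse and ask every block for the gradient with respect to its saved input. A toy re-materialization of that pattern with linear stand-ins (what backend.backward_pool does internally is assumed, not shown in this patch):

import torch

blocks = [torch.nn.Linear(16, 16) for _ in range(3)]   # stand-ins for requested_backends
inputs = torch.randn(2, 10, 16)

inter_inputs = []
with torch.no_grad():                                  # forward chain only collects activations
    for block in blocks[:-1]:                          # the last block's output is never needed
        inter_inputs.append(inputs)
        inputs = block(inputs)
inter_inputs.append(inputs)                            # input to the last block

# Backward chain: rerun each block on its saved input with grad enabled, walking in reverse
grad_outputs = torch.ones(2, 10, 16)
for block, inp in zip(reversed(blocks), reversed(inter_inputs)):
    inp = inp.detach().requires_grad_(True)
    out = block(inp)
    (grad_outputs,) = torch.autograd.grad(out, inp, grad_outputs)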
@@ -303,7 +313,7 @@ async def _rpc_backward(
         (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
-            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
 
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
     return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
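Since each prompt enters the computation by addition to the input prefix, its gradient is the input gradient restricted to those prefix positions, which is exactly the slice the final hunk appends to grad_prompts_reversed. A short check of that identity, independent of the server code:

import torch

inputs = torch.randn(2, 10, 16, requires_grad=True)   # block input: [batch, seq_len, hid]
prompt = torch.randn(2, 4, 16, requires_grad=True)    # deep prompt: [batch, prefix_len, hid]

shifted = inputs.clone()
shifted[:, : prompt.shape[1]] = shifted[:, : prompt.shape[1]] + prompt
loss = (shifted * torch.randn_like(shifted)).sum()     # stand-in for the rest of the network
grad_inputs, grad_prompt = torch.autograd.grad(loss, (inputs, prompt))

# d(loss)/d(prompt) equals d(loss)/d(inputs) on the prefix slice
assert torch.allclose(grad_prompt, grad_inputs[:, : prompt.shape[1]])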