@@ -66,12 +66,15 @@ class TransformerConnectionHandler(ConnectionHandler):
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]

+                    # Cast inputs to backend dtype
+                    hidden_states = hidden_states.to(requested_backends[0].dtype)
+                    assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
+
                     # parse deep prompts (optional argument)
-                    if not prompts or is_dummy(prompts[0]):
+                    if prompts is None or is_dummy(prompts):
                         prompts = [DUMMY] * len(requested_backends)
                     else:
-                        prompts = [prompts[0].to(dtype=requested_backends[0].dtype)]
-                        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+                        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

                     if not (len(requested_backends) == len(prompts)):
                         raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
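For context, the prompt parsing above relies on an empty-tensor convention for "no prompt supplied". Below is a minimal sketch of that convention and of the per-block split; the exact DUMMY/is_dummy definitions live elsewhere in the repo and are assumed here, as are the example shapes.

import torch

DUMMY = torch.empty(0)  # assumed placeholder meaning "no prompt was supplied"

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0  # assumed check for the placeholder

# With real deep prompts, the client is expected to send one prefix per requested block:
# prompts has shape [num_backends, batch_size, prefix_length, hid_size].
num_backends, batch_size, prefix_length, hid_size = 2, 1, 4, 16
prompts = torch.randn(num_backends, batch_size, prefix_length, hid_size)
per_block = [p.squeeze(0) for p in prompts.to(torch.float32).split(1, dim=0)]
assert len(per_block) == num_backends
assert per_block[0].shape == (batch_size, prefix_length, hid_size)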
@@ -83,9 +86,6 @@ class TransformerConnectionHandler(ConnectionHandler):
                             f" exceeds pre-allocated maximum {max_length}"
                         )

-                    # Cast inputs to backend dtype
-                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
-
                     # run request tensors through all requested modules, update caches
                     for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
                         if not is_dummy(prompt):
@@ -98,7 +98,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         assert (
                             hidden_states.ndim == 3
                         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-                        (hidden_states,) = await backend.inference_pool.submit_task(cache_metadata, hidden_states)
+                        (hidden_states,) = await backend.inference_pool.submit_task(cache_metadata, hidden_states, hypo_ids)

                     # serialize and send last layer outputs
                     yield runtime_pb2.ExpertResponse(
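The new hypo_ids argument is only forwarded to the inference pool in this hunk; how the backend consumes it is not part of this diff. A plausible, purely hypothetical use is reindexing cached attention states when beam search reorders hypotheses, which is also why the int64 assertion added earlier matters.

import torch

# hypothetical cache layout: [batch_size, num_heads, prefix_length, head_dim]
cache = torch.randn(3, 2, 4, 8)

# beam search kept hypotheses 0, 0 and 2; each row continues from its parent
hypo_ids = torch.tensor([0, 0, 2], dtype=torch.int64)
cache = cache[hypo_ids]  # int64 indices select the parent rows in place of the old batch order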
@@ -251,16 +251,15 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
-    hidden_states, *prompts = flat_tensors
+    hidden_states, prompts = flat_tensors
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
     assert hidden_states.ndim == 3
-    if not prompts or is_dummy(prompts[0]):
+    if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
     else:
-        prompts = [prompts[0].to(requested_backends[0].dtype)]
-        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
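One way to read the unpacking change in _rpc_forward (and in _rpc_backward below): the request is now assumed to always carry a prompts tensor, possibly DUMMY, so positional unpacking replaces the old variadic form. A small illustration under that assumption, with hypothetical shapes:

import torch

DUMMY = torch.empty(0)  # assumed "no prompts" placeholder
hidden_states = torch.randn(1, 5, 16)

flat_tensors = (hidden_states, DUMMY)    # the prompts slot is always present
hidden_states, prompts = flat_tensors    # new form: exactly two tensors expected
# the old form `hidden_states, *prompts = flat_tensors` made prompts a (possibly empty) list,
# hence the previous `if not prompts or is_dummy(prompts[0])` check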
@@ -279,16 +278,15 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
 async def _rpc_backward(
     *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, *prompts = flat_tensors
+    inputs, grad_outputs, prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

-    if not prompts or is_dummy(prompts[0]):
+    if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
     else:
-        prompts = [prompts[0].to(requested_backends[0].dtype)]
-        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output