@@ -25,6 +25,7 @@ MAX_TOKENS_IN_BATCH = 1024

 async def sequential_forward(
     inputs: torch.Tensor,
+    attention_mask: torch.Tensor,
     prompts: torch.Tensor,
     sequence_manager: RemoteSequenceManager,
     start_index: int = 0,
@@ -37,10 +38,12 @@ async def sequential_forward(
     """

     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
+    assert isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 2, f"{type(attention_mask)}: {attention_mask.ndim}"

     inputs_device = inputs.device
     inputs_dtype = inputs.dtype
     inputs = inputs.cpu()
+    attention_mask = attention_mask.cpu()
     prompts = prompts.cpu()

     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
@@ -68,7 +71,7 @@ async def sequential_forward(
                 span = sequences.popleft()

                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                inputs_and_prompts = [inputs, prompts[span.start : span.end]]
+                inputs_and_prompts = [inputs, attention_mask, prompts[span.start : span.end]]

                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
@@ -111,6 +114,7 @@ async def sequential_forward(
 async def sequential_backward(
     grad_outputs: Sequence[torch.Tensor],
     intermediate_inputs: List[torch.Tensor],
+    attention_mask: torch.Tensor,
     prompts: torch.Tensor,
     forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
@@ -128,6 +132,7 @@ async def sequential_backward(

     grad_outputs = [tensor.cpu() for tensor in grad_outputs]
     intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]
+    attention_mask = attention_mask.cpu()
     prompts = prompts.cpu()

     grad_prompts_reversed = []
@@ -160,6 +165,7 @@ async def sequential_backward(
                     stub,
                     sequence_manager.rpc_info,
                     inputs,
+                    attention_mask,
                     grad_outputs,
                     prompts[span.start : span.end],
                     timeout=sequence_manager.request_timeout,
@@ -191,25 +197,25 @@ async def sequential_backward(
     return grad_outputs, grad_prompts


-async def _gather_forward(input_batches, prompt_batches, sequence_manager):
+async def _gather_forward(input_batches, attention_mask_batches, prompt_batches, sequence_manager):
     """Wrapper for asyncio.gather to perform parallel sequential forwards"""
     return await asyncio.gather(
         *[
-            sequential_forward(input_batch, prompt_batch, sequence_manager)
-            for input_batch, prompt_batch in zip(input_batches, prompt_batches)
+            sequential_forward(input_batch, attention_mask_batch, prompt_batch, sequence_manager)
+            for input_batch, attention_mask_batch, prompt_batch in zip(input_batches, attention_mask_batches, prompt_batches)
         ]
     )


 async def _gather_backward(
-    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
+    grad_output_batches, intermediate_input_batches, attention_mask_batches, prompt_batches, forward_sequences, sequence_manager
 ):
     """Wrapper for asyncio.gather to perform parallel sequential backwards"""
     return await asyncio.gather(
         *[
-            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
-            for grad_output, input_batch, prompt_batch, spans in zip(
-                grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
+            sequential_backward((grad_output,), input_batch, attention_mask_batch, prompt_batch, spans, sequence_manager)
+            for grad_output, input_batch, attention_mask_batch, prompt_batch, spans in zip(
+                grad_output_batches, intermediate_input_batches, attention_mask_batches, prompt_batches, forward_sequences
             )
         ]
     )
@@ -222,16 +228,17 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
+    def forward(ctx, inputs: torch.Tensor, attention_mask: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
+        attention_mask_batches: Sequence[torch.Tensor] = attention_mask.detach().split(batch_size)
         if is_dummy(prompts):
             prompt_batches = [DUMMY] * len(input_batches)
         else:
             prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)

         sequence_manager.rpc_info  # lazy init
-        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
+        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, attention_mask_batches, prompt_batches, sequence_manager))
         assert len(outputs) == len(input_batches)

         output_batches = [output[0] for output in outputs]
@@ -241,6 +248,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
         ctx.prompt_batches = prompt_batches
         ctx.sequence_manager = sequence_manager
         ctx.intemediate_input_batches = intemediate_input_batches
+        ctx.attention_mask_batches = attention_mask_batches
         ctx.sequences_for_batches = sequences_for_batches
         return torch.cat(output_batches, dim=0)

@@ -258,13 +266,14 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
             _gather_backward(
                 grad_output_batches,
                 intermediate_input_batches,
+                ctx.attention_mask_batches,
                 ctx.prompt_batches,
                 forward_sequences,
                 ctx.sequence_manager,
             )
         )
         grad_input_batches = [output[0][0] for output in outputs]
         grad_prompt_batches = [output[1] for output in outputs]

         grad_inputs = torch.cat(grad_input_batches, dim=0)
         dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
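
Note: forward() now takes attention_mask as its second positional input, so torch.autograd will expect the matching backward() to return one extra gradient slot (None for the non-differentiable mask); that return statement lies outside this excerpt. The new attention_mask_batches mirror the dim-0 split of input_batches, so every sub-batch sent through sequential_forward keeps mask rows aligned with its hidden states. Below is a standalone sketch of just that splitting step; the shapes are illustrative and not taken from this diff:

    import torch

    MAX_TOKENS_IN_BATCH = 1024  # same constant as in the patched module

    inputs = torch.randn(8, 512, 32)     # (batch, seq_len, hidden) -- illustrative sizes only
    attention_mask = torch.ones(8, 512)  # (batch, seq_len), ndim == 2 as the new assert requires

    batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)  # -> 2 for these shapes
    input_batches = inputs.detach().split(batch_size)            # split along dim 0
    attention_mask_batches = attention_mask.detach().split(batch_size)

    # each chunk of hidden states keeps its matching mask rows
    assert len(input_batches) == len(attention_mask_batches)
    for x, m in zip(input_batches, attention_mask_batches):
        assert x.shape[0] == m.shape[0]

Splitting the mask with the same batch_size is what lets zip() in _gather_forward pair every hidden-state chunk with exactly the mask rows it was computed from.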