@@ -25,7 +25,7 @@ MAX_TOKENS_IN_BATCH = 1024

 async def sequential_forward(
     inputs: torch.Tensor,
-    attention_mask: torch.Tensor,
+    attention_masks: torch.Tensor,
     prompts: torch.Tensor,
     sequence_manager: RemoteSequenceManager,
     start_index: int = 0,
@@ -38,12 +38,12 @@ async def sequential_forward(
     """

     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
-    assert isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 2, f"{type(attention_mask)}: {attention_mask.ndim}"
+    assert isinstance(attention_masks, torch.Tensor) and attention_masks.ndim == 2, f"{type(attention_masks)}: {attention_masks.ndim}"

     inputs_device = inputs.device
     inputs_dtype = inputs.dtype
     inputs = inputs.cpu()
-    attention_mask = attention_mask.cpu()
+    attention_masks = attention_masks.cpu()
     prompts = prompts.cpu()

     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
@@ -71,7 +71,7 @@ async def sequential_forward(
             span = sequences.popleft()

             stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-            inputs_and_prompts = [inputs, attention_mask, prompts[span.start : span.end]]
+            inputs_and_prompts = [inputs, attention_masks, prompts[span.start : span.end]]

             span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
             metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
@@ -114,7 +114,7 @@ async def sequential_forward(
 async def sequential_backward(
     grad_outputs: Sequence[torch.Tensor],
     intermediate_inputs: List[torch.Tensor],
-    attention_mask: torch.Tensor,
+    attention_masks: torch.Tensor,
     prompts: torch.Tensor,
     forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
@@ -132,7 +132,7 @@ async def sequential_backward(

     grad_outputs = [tensor.cpu() for tensor in grad_outputs]
     intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]
-    attention_mask = attention_mask.cpu()
+    attention_masks = attention_masks.cpu()
     prompts = prompts.cpu()

     grad_prompts_reversed = []
@@ -144,7 +144,7 @@ async def sequential_backward(
             try:
                 if attempt_no >= 1:
                     _, backup_inputs, backup_sequences = await sequential_forward(
-                        inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
+                        inputs, attention_masks, prompts, sequence_manager, start_index=span.start, end_index=span.end
                     )
                     assert len(backup_inputs) == len(backup_sequences)
                     assert backup_sequences[0].start == span.start
@@ -158,14 +158,14 @@ async def sequential_backward(
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 metadata = sequence_manager.get_request_metadata(
-                    "rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
+                    "rpc_backward", span_uids, *inputs, attention_masks, *grad_outputs, peer_id=span.peer_id
                 )
                 grad_outputs, *span_grad_prompts = await run_remote_backward(
                     span_uids,
                     stub,
                     sequence_manager.rpc_info,
                     inputs,
-                    attention_mask,
+                    attention_masks,
                     grad_outputs,
                     prompts[span.start : span.end],
                     timeout=sequence_manager.request_timeout,
@@ -255,27 +255,28 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_outputs: torch.Tensor):
         intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intemediate_input_batches
+        attention_mask_batches: List[Sequence[torch.Tensor]] = ctx.attention_mask_batches
         forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
         ctx.sequence_manager.rpc_info  # lazy init

         batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
         grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
-        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)
+        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences) == len(attention_mask_batches)

         outputs = RemoteExpertWorker.run_coroutine(
             _gather_backward(
                 grad_output_batches,
                 intermediate_input_batches,
-                ctx.attention_mask_batches,
+                attention_mask_batches,
                 ctx.prompt_batches,
                 forward_sequences,
                 ctx.sequence_manager,
             )
         )
         grad_input_batches = [output[0][0] for output in outputs]
-        grad_prompt_batches = [output[2] for output in outputs]
+        grad_prompt_batches = [output[1] for output in outputs]

         grad_inputs = torch.cat(grad_input_batches, dim=0)
         dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
         grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
-        return (grad_inputs, grad_prompts, None)
+        return (grad_inputs, None, grad_prompts, None)
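
A note on the strengthened assert in backward() above: the incoming gradients are split into micro-batches of at most MAX_TOKENS_IN_BATCH tokens, and every per-batch list (intermediate inputs, prompts, and now the attention masks) has to line up one-to-one with those splits. Below is a small, self-contained sketch of that splitting rule with made-up sizes; it is an illustration, not code from the patch.

import torch

MAX_TOKENS_IN_BATCH = 1024  # same constant as at the top of the patched file

# Made-up sizes for illustration: 7 sequences of length 300 with hidden size 64.
grad_outputs = torch.randn(7, 300, 64)
attention_masks = torch.ones(7, 300)

# Same batching rule as backward(): at most MAX_TOKENS_IN_BATCH tokens per chunk.
batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)  # -> 3 sequences per chunk
grad_output_batches = grad_outputs.split(batch_size)               # chunks of size 3, 3, 1
attention_mask_batches = attention_masks.split(batch_size)         # must split the same way

# This is the alignment the new assert enforces: one mask batch per gradient batch.
assert len(grad_output_batches) == len(attention_mask_batches)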
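
Relatedly, the changed return value (grad_inputs, None, grad_prompts, None) follows the torch.autograd.Function contract: backward() must return one gradient per argument of forward(), and the attention masks are not differentiable, so their slot is None (the trailing None presumably corresponds to the sequence_manager argument, whose forward signature is outside this excerpt). A minimal runnable toy illustrating that contract, not the real remote function:

import torch


class ToyRemoteFunction(torch.autograd.Function):
    """Toy stand-in (not _RemoteSequentialAutogradFunction) showing why backward()
    returns one entry per forward() argument, with None for the attention mask."""

    @staticmethod
    def forward(ctx, inputs, attention_masks, prompts):
        ctx.save_for_backward(attention_masks)
        # Pretend the remote blocks simply gate the inputs and add the prompts.
        return inputs * attention_masks.unsqueeze(-1) + prompts

    @staticmethod
    def backward(ctx, grad_outputs):
        (attention_masks,) = ctx.saved_tensors
        grad_inputs = grad_outputs * attention_masks.unsqueeze(-1)
        grad_prompts = grad_outputs
        # One slot per forward() input (inputs, attention_masks, prompts); the None
        # for attention_masks mirrors the patched return above.
        return grad_inputs, None, grad_prompts


inputs = torch.randn(2, 4, 8, requires_grad=True)
attention_masks = torch.ones(2, 4)
prompts = torch.zeros(2, 4, 8, requires_grad=True)
ToyRemoteFunction.apply(inputs, attention_masks, prompts).sum().backward()
print(inputs.grad.shape, prompts.grad.shape)  # both receive gradients; the mask does not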