@@ -38,7 +38,9 @@ async def sequential_forward(
     """

     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
-    assert isinstance(attention_masks, torch.Tensor) and attention_masks.ndim == 2, f"{type(attention_masks)}: {attention_masks.ndim}"
+    assert (
+        isinstance(attention_masks, torch.Tensor) and attention_masks.ndim == 2
+    ), f"{type(attention_masks)}: {attention_masks.ndim}"

     inputs_device = inputs.device
     inputs_dtype = inputs.dtype
@@ -202,20 +204,33 @@ async def _gather_forward(input_batches, attention_mask_batches, prompt_batches,
     return await asyncio.gather(
         *[
             sequential_forward(input_batch, attention_mask_batch, prompt_batch, sequence_manager)
-            for input_batch, attention_mask_batch, prompt_batch in zip(input_batches, attention_mask_batches, prompt_batches)
+            for input_batch, attention_mask_batch, prompt_batch in zip(
+                input_batches, attention_mask_batches, prompt_batches
+            )
         ]
     )


 async def _gather_backward(
-    grad_output_batches, intermediate_input_batches, attention_mask_batches, prompt_batches, forward_sequences, sequence_manager
+    grad_output_batches,
+    intermediate_input_batches,
+    attention_mask_batches,
+    prompt_batches,
+    forward_sequences,
+    sequence_manager,
 ):
     """Wrapper for asyncio.gather to perform parallel sequential backwards"""
     return await asyncio.gather(
         *[
-            sequential_backward((grad_output,), input_batch, attention_mask_batch, prompt_batch, spans, sequence_manager)
+            sequential_backward(
+                (grad_output,), input_batch, attention_mask_batch, prompt_batch, spans, sequence_manager
+            )
             for grad_output, input_batch, attention_mask_batch, prompt_batch, spans in zip(
-                grad_output_batches, intermediate_input_batches, attention_mask_batches, prompt_batches, forward_sequences
+                grad_output_batches,
+                intermediate_input_batches,
+                attention_mask_batches,
+                prompt_batches,
+                forward_sequences,
             )
         ]
     )
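
Note: both _gather_* wrappers above use the same fan-out pattern: build one coroutine per sub-batch, then await them all concurrently with asyncio.gather. A minimal self-contained sketch of that pattern (process_batch, gather_batches, and the sample data are illustrative stand-ins, not part of this patch):

    import asyncio

    async def process_batch(batch):
        # Stand-in for a per-batch remote call such as sequential_forward;
        # the await yields control so the batches can actually interleave.
        await asyncio.sleep(0)
        return [x * 2 for x in batch]

    async def gather_batches(batches):
        # Mirrors _gather_forward: one coroutine per batch, results returned
        # in input order by asyncio.gather.
        return await asyncio.gather(*[process_batch(b) for b in batches])

    print(asyncio.run(gather_batches([[1, 2], [3, 4]])))  # [[2, 4], [6, 8]]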
@@ -228,7 +243,13 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, inputs: torch.Tensor, attention_mask: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
+    def forward(
+        ctx,
+        inputs: torch.Tensor,
+        attention_mask: torch.Tensor,
+        prompts: torch.Tensor,
+        sequence_manager: RemoteSequenceManager,
+    ):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
         attention_mask_batches: Sequence[torch.Tensor] = attention_mask.detach().split(batch_size)
@@ -238,7 +259,9 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
         prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)

         sequence_manager.rpc_info  # lazy init
-        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, attention_mask_batches, prompt_batches, sequence_manager))
+        outputs = RemoteExpertWorker.run_coroutine(
+            _gather_forward(input_batches, attention_mask_batches, prompt_batches, sequence_manager)
+        )
         assert len(outputs) == len(input_batches)

         output_batches = [output[0] for output in outputs]
@@ -261,7 +284,12 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):

         batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
         grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
-        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences) == len(attention_mask_batches)
+        assert (
+            len(intermediate_input_batches)
+            == len(grad_output_batches)
+            == len(forward_sequences)
+            == len(attention_mask_batches)
+        )

         outputs = RemoteExpertWorker.run_coroutine(
             _gather_backward(