@@ -12,6 +12,7 @@ from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
+from src.utils.misc import DUMMY, is_dummy

 MAX_TOKENS_IN_BATCH = 1024

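For readers following along outside the repository: `DUMMY` and `is_dummy` are the sentinel used throughout this diff to mean "no deep prompts". A minimal sketch of what `src/utils/misc.py` is assumed to provide (names taken from the import above; the exact definition may differ):

    import torch

    # Zero-element sentinel standing in for "no prompts"; requires_grad=True so it
    # can flow through torch.autograd.Function signatures without special-casing.
    DUMMY = torch.empty(0, requires_grad=True)

    def is_dummy(tensor: torch.Tensor) -> bool:
        return tensor.numel() == 0  # any zero-element tensor counts as the sentinel
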
@@ -57,7 +58,7 @@ async def run_expert_backward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
-    intemediate_inputs: List[torch.Tensor],
+    inputs: List[torch.Tensor],
     grad_outputs: List[torch.Tensor],
 ) -> Sequence[torch.Tensor]:
     """
@@ -67,7 +68,7 @@ async def run_expert_backward(
     """

     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
-    inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
+    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu)))
     backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))

     # Asynchronous serialization
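The flattening above is what keeps each tensor aligned one-to-one with an entry of `backward_schema` before serialization. A toy illustration of hivemind's `nested_flatten`, which (as assumed here) yields the leaves of a nested structure depth-first:

    from hivemind.utils.nested import nested_flatten

    # Nested lists/tuples collapse into a flat sequence, preserving left-to-right order
    flat = tuple(nested_flatten(([1, 2], (3,))))
    assert flat == (1, 2, 3)
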
@@ -84,7 +85,11 @@ async def run_expert_backward(


 async def sequential_forward(
-    inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
+    inputs: torch.Tensor,
+    prompts: torch.Tensor,
+    sequence_manager: RemoteSequenceManager,
+    start_index: int = 0,
+    end_index: Optional[int] = None,
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
     Constructs a routing path from <start_index> to <end_index>.
@@ -96,6 +101,8 @@ async def sequential_forward(

     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
+    if not is_dummy(prompts):
+        assert len(prompts) == len(sequence_manager.block_uids)  # one prompt slice per block, indexed by absolute block id

     sequences = sequence_manager.make_sequence(start_index, end_index)
     intermediate_inputs = []
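The assert above follows the prompts layout used throughout this diff: one prompt slice per transformer block, addressed by absolute block index (which is why the original `end_index - start_index + 1` check was inconsistent with the slicing below). A hypothetical shape walk-through, with illustrative sizes not taken from the source:

    import torch

    num_blocks, batch_size, prefix_length, hidden_size = 24, 2, 16, 1024
    prompts = torch.zeros(num_blocks, batch_size, prefix_length, hidden_size)

    # A span covering blocks [4, 8) picks out exactly its own prompt slices,
    # so prompts must span all blocks even when start_index > 0:
    span_start, span_end = 4, 8
    assert prompts[span_start:span_end].shape[0] == span_end - span_start
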
@@ -107,7 +114,9 @@ async def sequential_forward(
                 span = sequences.pop(0)
                 span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
+                inputs_and_prompts = [inputs, prompts[span.start : span.end]]
+
+                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)

                 assert isinstance(outputs, torch.Tensor)
                 assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
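Each span thus receives only the prompt slices for its own blocks. How a server consumes them is outside this diff; for intuition, a hypothetical sketch of the usual deep-prompt convention (adding the prefix to the first positions of the hidden states), not part of the actual handler code:

    import torch

    def apply_deep_prompt(hidden_states: torch.Tensor, prompt: torch.Tensor) -> torch.Tensor:
        # Hypothetical server-side step: the block's prompt ([batch, prefix_length, hidden])
        # is added to the first prefix_length positions of the hidden states.
        prefix_length = prompt.shape[1]
        out = hidden_states.clone()
        out[:, :prefix_length] += prompt
        return out
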
@@ -119,7 +128,7 @@ async def sequential_forward(
                 inputs = outputs
                 break
             except Exception as e:
-                logging.debug(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
+                logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
                 backup_sequences = sequence_manager.make_sequence(span.start)
                 assert backup_sequences[0].start == span.start
                 sequences = backup_sequences
@@ -130,6 +139,7 @@ async def sequential_forward(
 async def sequential_backward(
     grad_outputs: Sequence[torch.Tensor],
     intermediate_inputs: Sequence[torch.Tensor],
+    prompts: Sequence[torch.Tensor],
     forward_sequences: Sequence[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
 ) -> Sequence[torch.Tensor]:
@@ -137,10 +147,9 @@ async def sequential_backward(
     Performs chained backward for each forward subsequence.
     If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
     """
-
     assert len(intermediate_inputs) == len(forward_sequences)
-    # TODO think about grads w.r.t. deep prompts

+    grad_prompts = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
         while True:
             try:
@@ -150,37 +159,50 @@ async def sequential_backward(
                 span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)

-                grad_outputs = await run_expert_backward(
-                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
+                inputs_and_prompts = [inputs, prompts[span.start : span.end]]
+                grad_outputs, span_grad_prompts = await run_expert_backward(
+                    span_uids, stub, sequence_manager.rpc_info, inputs_and_prompts, grad_outputs
                 )
+                grad_outputs = [grad_outputs]  # re-wrap: run_expert_backward expects a sequence of grad tensors
+                grad_prompts.append(span_grad_prompts)
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
                 _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
-                    inputs, sequence_manager, start_index=span.start, end_index=span.end
+                    inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
                 )
-
                 assert len(intermediate_inputs) == len(forward_sequences)
                 assert backup_forward_sequences[0].start == span.start
                 assert backup_forward_sequences[-1].end == span.end

                 forward_sequences.extend(backup_forward_sequences)
                 intermediate_inputs.extend(backup_intermediate_inputs)
-    return grad_outputs
+
+    dummy_grad_prompts = [is_dummy(grad_prompt) for grad_prompt in grad_prompts]
+    # Mixing dummy and real prompt gradients is not supported yet: if any span returned a dummy, skip prompt grads
+    # Spans are processed back-to-front, so reverse before concatenating along the num_layers dimension
+    grad_prompts = torch.cat(grad_prompts[::-1], dim=0) if not any(dummy_grad_prompts) else None
+    return grad_outputs, grad_prompts


-async def _gather_forward(input_batches, sequence_manager):
+async def _gather_forward(input_batches, prompt_batches, sequence_manager):
     """Wrapper for asyncio.gather to perform parallel sequential forwards"""
-    return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
+    return await asyncio.gather(
+        *[
+            sequential_forward(input_batch, prompt_batch, sequence_manager)
+            for input_batch, prompt_batch in zip(input_batches, prompt_batches)
+        ]
+    )


-async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
+async def _gather_backward(
+    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
+):
     """Wrapper for asyncio.gather to perform parallel sequential backwards"""
     return await asyncio.gather(
         *[
-            sequential_backward((grad_output,), input_batch, spans, sequence_manager)
-            for grad_output, input_batch, spans in zip(
-                grad_output_batches, intermediate_input_batches, forward_sequences
+            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
+            for grad_output, input_batch, prompt_batch, spans in zip(
+                grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
             )
         ]
     )
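Note the ordering subtlety handled above: backward pops spans from the tail of `forward_sequences`, so prompt gradients are collected last-span-first and must be reversed before concatenation along the layer dimension. A toy check with hypothetical shapes:

    import torch

    # Three spans covering layers [0,2), [2,4), [4,6); grads arrive in reverse span order
    g_last, g_mid, g_first = (torch.full((2, 1, 4), float(v)) for v in (3, 2, 1))
    grad_prompts = [g_last, g_mid, g_first]

    restored = torch.cat(grad_prompts[::-1], dim=0)
    assert restored.shape[0] == 6 and restored[0, 0, 0].item() == 1.0  # first layer's grads come first
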
@@ -193,18 +215,23 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
+    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
+        if is_dummy(prompts):
+            prompt_batches = [DUMMY] * len(input_batches)
+        else:
+            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)

         sequence_manager.rpc_info  # lazy init
-        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
+        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
         assert len(outputs) == len(input_batches)

         output_batches = [output[0] for output in outputs]
         intemediate_input_batches = [output[1] for output in outputs]
         sequences_for_batches = [output[2] for output in outputs]

+        ctx.prompt_batches = prompt_batches
         ctx.sequence_manager = sequence_manager
         ctx.intemediate_input_batches = intemediate_input_batches
         ctx.sequences_for_batches = sequences_for_batches
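Inputs are split along dim 0 (their batch dimension), while prompts are split with dim=1, since their leading dimension indexes blocks. A quick consistency check under the shapes assumed above:

    import torch

    batch_size = 2
    inputs = torch.zeros(5, 128, 1024)       # [batch, seq_len, hidden]
    prompts = torch.zeros(24, 5, 16, 1024)   # [num_blocks, batch, prefix_len, hidden]

    input_batches = inputs.split(batch_size)           # chunks of the batch dimension
    prompt_batches = prompts.split(batch_size, dim=1)  # batch lives at dim 1 for prompts
    assert len(input_batches) == len(prompt_batches) == 3
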
@@ -220,9 +247,19 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
         grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
         assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)

-        grad_input_batches = RemoteExpertWorker.run_coroutine(
-            _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
+        outputs = RemoteExpertWorker.run_coroutine(
+            _gather_backward(
+                grad_output_batches,
+                intermediate_input_batches,
+                ctx.prompt_batches,
+                forward_sequences,
+                ctx.sequence_manager,
+            )
         )
-        grad_inputs = [grad_input_batch[0] for grad_input_batch in grad_input_batches]
-        grad_inputs = torch.cat(grad_inputs, dim=0)
-        return (grad_inputs, None)
+        grad_input_batches = [output[0][0] for output in outputs]  # first tensor of grad_outputs is the hidden-state grad
+        grad_prompt_batches = [output[1] for output in outputs]
+
+        grad_inputs = torch.cat(grad_input_batches, dim=0)
+        dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
+        grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
+        return (grad_inputs, grad_prompts, None)
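Putting it together, a hedged usage sketch (construction of `sequence_manager` elided; shapes are illustrative):

    import torch

    hidden_states = torch.randn(2, 128, 1024, requires_grad=True)

    # No deep prompts: pass the DUMMY sentinel; backward then returns grad_prompts=None
    outputs = _RemoteSequentialAutogradFunction.apply(hidden_states, DUMMY, sequence_manager)

    # Trainable deep prompts, one prefix per remote block:
    prompts = torch.randn(24, 2, 16, 1024, requires_grad=True)
    outputs = _RemoteSequentialAutogradFunction.apply(hidden_states, prompts, sequence_manager)
    outputs.sum().backward()  # populates hidden_states.grad and prompts.grad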