@@ -16,11 +16,13 @@ from src.server.handler import TransformerConnectionHandler
 MAX_TOKENS_IN_BATCH = 1024
 
 
-async def run_forward(
+async def run_expert_forward(
     uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
 ) -> Tuple[torch.Tensor, ...]:
     """
-    TODO: add description
+    Serializes input tensors and calls "expert_forward".
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
+    but without the RemoteExpertWorker.run_coroutine() call, which leads to a deadlock here.
     """
 
     # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
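
The deadlock mentioned in the new docstrings comes from blocking on a coroutine that is scheduled on the very event loop the caller is already running on. A minimal, self-contained illustration with plain asyncio (run_coroutine_blocking below is a simplified stand-in for RemoteExpertWorker.run_coroutine, not hivemind's actual implementation):

import asyncio

def run_coroutine_blocking(coro, loop):
    # Simplified stand-in for RemoteExpertWorker.run_coroutine: submit to `loop`, block on the result.
    return asyncio.run_coroutine_threadsafe(coro, loop).result()

async def inner():
    return 42

async def outer():
    loop = asyncio.get_running_loop()
    # BAD: we already run on `loop`; blocking here keeps the loop busy, so `inner` is never
    # scheduled and .result() never returns -- this is the deadlock the docstring refers to:
    #     return run_coroutine_blocking(inner(), loop)
    # GOOD: await the coroutine directly, which is what run_expert_forward/backward now do:
    return await inner()

print(asyncio.run(outer()))  # prints 42
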
@@ -47,7 +49,7 @@ async def run_forward(
     return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])
 
 
-async def run_backward(
+async def run_expert_backward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
@@ -55,7 +57,9 @@ async def run_backward(
     grad_outputs: List[torch.Tensor],
 ) -> Sequence[torch.Tensor]:
     """
-    TODO: add description
+    Serializes grad outputs and calls "expert_backward".
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
+    but without the RemoteExpertWorker.run_coroutine() call, which leads to a deadlock here.
     """
 
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
@@ -70,11 +74,13 @@ async def run_backward(
     return deserialized_grad_inputs
 
 
-async def async_forward(
+async def sequential_forward(
     inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
-    TODO: add description
+    Constructs a routing path from <start_index> to <end_index>.
+    Performs a chained forward pass over each subsequence of blocks on the path.
+    If some subsequence fails, reconstructs the remaining path and tries to finish the forward pass.
     """
 
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3
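
For readers skimming the diff, a condensed sketch of the control flow this docstring describes (simplified and hedged; make_sequence and the retry bookkeeping are assumptions based on the surrounding module, not the verbatim implementation):

async def sequential_forward_sketch(inputs, sequence_manager, start_index=0, end_index=None):
    end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
    sequences = sequence_manager.make_sequence(start_index, end_index)  # routing path: list of RemoteSpanInfo
    intermediate_inputs, done_sequences = [], []
    outputs = inputs
    while sequences:
        span = sequences.pop(0)
        try:
            span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
            stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
            (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
            intermediate_inputs.append(inputs)  # keep per-span inputs so the backward can be routed later
            done_sequences.append(span)
            inputs = outputs                    # this span's output feeds the next span
        except Exception:
            # a server in the span failed: rebuild the path for the remaining blocks and retry
            sequences = sequence_manager.make_sequence(span.start, end_index)
    return outputs, intermediate_inputs, done_sequences
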
@@ -92,7 +98,7 @@ async def async_forward(
         span = sequences.pop(0)
         span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
         stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-        (outputs,) = await run_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
+        (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
 
         assert isinstance(outputs, torch.Tensor)
         assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -112,14 +118,15 @@ async def async_forward(
     return outputs, intermediate_inputs, done_sequences
 
 
-async def async_backward(
+async def sequential_backward(
     grad_outputs: Sequence[torch.Tensor],
     intermediate_inputs: Sequence[torch.Tensor],
     forward_sequences: Sequence[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
 ) -> Sequence[torch.Tensor]:
     """
-    TODO: add description
+    Performs a chained backward pass for each forward subsequence.
+    If some subsequence fails, reconstructs the particular sub-path and recovers the backward pass.
     """
 
     assert len(intermediate_inputs) == len(forward_sequences)
@@ -134,11 +141,11 @@ async def async_backward(
             span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
             stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
 
-            grad_outputs = await run_backward(span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs)
+            grad_outputs = await run_expert_backward(span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs)
             break
         except Exception as e:
             logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
-            _, backup_intermediate_inputs, backup_forward_sequences = await async_forward(
+            _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
                 inputs, sequence_manager, start_index=span.start, end_index=span.end
             )
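
As above, a condensed sketch of the recovery flow this hunk implements (simplified and hedged; the list handling is illustrative, not the verbatim implementation):

async def sequential_backward_sketch(grad_outputs, intermediate_inputs, forward_sequences, sequence_manager):
    intermediate_inputs, forward_sequences = list(intermediate_inputs), list(forward_sequences)
    while forward_sequences and intermediate_inputs:
        inputs = intermediate_inputs.pop()  # activations saved for this span during the forward pass
        span = forward_sequences.pop()
        try:
            span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
            stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
            grad_outputs = await run_expert_backward(span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs)
        except Exception:
            # the span's server failed: redo the forward for just this sub-path to get fresh
            # activations (and possibly different servers), then retry the backward over them
            _, backup_inputs, backup_spans = await sequential_forward(
                inputs, sequence_manager, start_index=span.start, end_index=span.end
            )
            intermediate_inputs.extend(backup_inputs)
            forward_sequences.extend(backup_spans)
    return grad_outputs
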
@@ -152,13 +159,15 @@ async def async_backward(
 
 
 async def _gather_forward(input_batches, sequence_manager):
-    return await asyncio.gather(*[async_forward(input_batch, sequence_manager) for input_batch in input_batches])
+    """Wrapper for asyncio.gather to perform parallel sequential forward passes"""
+    return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
 
 
 async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
+    """Wrapper for asyncio.gather to perform parallel sequential backward passes"""
     return await asyncio.gather(
         *[
-            async_backward((grad_output,), input_batch, spans, sequence_manager)
+            sequential_backward((grad_output,), input_batch, spans, sequence_manager)
             for grad_output, input_batch, spans in zip(
                 grad_output_batches, intermediate_input_batches, forward_sequences
             )
@@ -168,8 +177,8 @@ async def _gather_backward(grad_output_batches, intermediate_input_batches, forw
 
 class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     """
-    A pytorch autograd-compatible function that calls a sequence of transformer blocks on remote peers
-    :note: this function splits input data into batches for efficient parallel processing
+    A PyTorch autograd function that provides forward and backward calls for the entire sequence of remote transformer blocks.
+    This function splits the input into batches of <MAX_TOKENS_IN_BATCH> tokens and processes them in parallel for efficiency.
     """
 
     @staticmethod