@@ -1,14 +1,17 @@
+"""
+A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner
+"""
 import asyncio
 import logging
 from typing import List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import serialize_torch_tensor
-from hivemind.moe.client.expert import expert_backward, expert_forward
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
 
+from src.client.remote_forward_backward import remote_backward, remote_forward
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
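For orientation, here is a minimal, self-contained sketch of the torch.autograd.Function pattern that the new module docstring refers to; the class name and the placeholder arithmetic are illustrative only, not part of this patch:

import torch

class _RemoteSequentialDemo(torch.autograd.Function):
    """Illustrative stand-in: a real implementation would await the remote
    sequential forward/backward coroutines instead of computing `x * 2`."""

    @staticmethod
    def forward(ctx, inputs: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(inputs)
        return inputs * 2  # placeholder for the remote forward pass

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # d(2x)/dx = 2, so the incoming gradient is scaled by 2;
        # the real code runs the remote backward chain here instead.
        return grad_output * 2

Callers invoke such a function as `_RemoteSequentialDemo.apply(x)`, which is how the custom backward gets wired into the autograd graph.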
@@ -17,11 +20,11 @@ from src.utils.misc import DUMMY, is_dummy
 MAX_TOKENS_IN_BATCH = 1024
 
 
-async def run_expert_forward(
+async def run_remote_forward(
     uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
 ) -> Tuple[torch.Tensor, ...]:
     """
-    Serializes input tensors and calls "expert_forward".
+    Serializes input tensors and calls "rpc_forward" on a remote server.
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
     but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
@@ -55,12 +58,12 @@ async def run_expert_forward(
         )
     )
 
-    deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
+    deserialized_outputs = await remote_forward(uid, inputs, serialized_tensors, stub)
     flat_outputs = tuple(deserialized_outputs)
     return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])
 
 
-async def run_expert_backward(
+async def run_remote_backward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
@@ -69,7 +72,7 @@ async def run_expert_backward(
     *extra_tensors: torch.Tensor,
 ) -> Sequence[torch.Tensor]:
     """
-    Serializes grad outputs and calls "expert_backward".
+    Serializes grad outputs and calls "rpc_backward" on a remote server.
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
     but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
@@ -93,7 +96,7 @@ async def run_expert_backward(
         )
     )
 
-    deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
+    deserialized_grad_inputs = await remote_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
     return deserialized_grad_inputs
 
 
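Together, the renamed helpers form the per-span RPC pair that sequential_forward and sequential_backward build on. A condensed usage sketch mirroring the call sites in the hunks that follow (span_round_trip is a hypothetical wrapper; uids, stub, and rpc_info come from the sequence manager):

async def span_round_trip(uids, stub, rpc_info, inputs, prompts, grad_outputs):
    # Forward through one span of remote blocks, then send gradients back.
    (outputs,) = await run_remote_forward(uids, stub, rpc_info, inputs, prompts)
    grad_inputs, *grad_prompts = await run_remote_backward(
        uids, stub, rpc_info, inputs, grad_outputs, prompts
    )
    return outputs, grad_inputs, grad_prompts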
@@ -130,7 +133,7 @@ async def sequential_forward(
         stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
         inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
-        (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
+        (outputs,) = await run_remote_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
 
         assert isinstance(outputs, torch.Tensor)
         assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -171,7 +174,7 @@ async def sequential_backward(
         span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
         try:
             stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-            grad_outputs, *span_grad_prompts = await run_expert_backward(
+            grad_outputs, *span_grad_prompts = await run_remote_backward(
                 span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
             )
             grad_outputs = [grad_outputs]