@@ -3,8 +3,8 @@ import logging
 from typing import List, Optional, Sequence, Tuple
 
 import torch
-from hivemind import serialize_torch_tensor
-from hivemind.moe.client.expert import expert_backward, expert_forward, _forward_stream
+from hivemind import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.moe.client.expert import _forward_stream, expert_backward, expert_forward
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
@@ -39,11 +39,14 @@ async def run_expert_forward(
     forward_inputs = nested_flatten(forward_inputs)
     inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
 
+    # Asynchronous serialization
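+    # serialize_torch_tensor is CPU-bound, so it runs in the default thread-pool executor instead of blocking the event loop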
     loop = asyncio.get_running_loop()
-    serialized_tensors = await asyncio.gather(*(
-        loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-        for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
-    ))
+    serialized_tensors = await asyncio.gather(
+        *(
+            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+            for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
+        )
+    )
 
     deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
     flat_outputs = tuple(deserialized_outputs)
@@ -67,12 +70,15 @@ async def run_expert_backward(
     inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
     backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
 
+    # Asynchronous serialization
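+    # each tensor is cast to the dtype and compressed with the codec given by its entry in the combined forward/outputs schema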
     loop = asyncio.get_running_loop()
-    serialized_tensors = await asyncio.gather(*(
-        loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-        for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-    ))
-
+    serialized_tensors = await asyncio.gather(
+        *(
+            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        )
+    )
+
     deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
     return deserialized_grad_inputs
@@ -189,7 +195,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
-        input_batches: Sequence[torch.Tensor] = inputs.split(batch_size)
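+        # note: inputs are detached before splitting, so the batches do not require grad; gradients are produced by backward() below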
+        input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
 
         sequence_manager.rpc_info  # lazy init
         outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
@@ -217,6 +223,11 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
         grad_input_batches = RemoteExpertWorker.run_coroutine(
             _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
         )
+        # grad_input_batches = [sequential_backward((grad_output,), input_batch, spans, ctx.sequence_manager)
+        #                       for grad_output, input_batch, spans in zip(
+        #                           grad_output_batches, intermediate_input_batches, forward_sequences
+        #                       )
+        #                      ]
         grad_inputs = [grad_input_batch[0] for grad_input_batch in grad_input_batches]
         grad_inputs = torch.cat(grad_inputs, dim=0)
         return (grad_inputs, None)