|
@@ -1,55 +1,77 @@
|
|
|
"""
|
|
|
Utility functions that call RPC forward or backward on a single remote server
|
|
|
"""
|
|
|
-from typing import Iterable, List, Sequence
|
|
|
+import asyncio
|
|
|
+from typing import Iterable, List, Sequence, Tuple, Optional
|
|
|
|
|
|
import torch
|
|
|
+from hivemind import nested_compare, nested_flatten, serialize_torch_tensor, nested_pack
|
|
|
from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
|
|
|
+from hivemind.p2p import StubBase
|
|
|
from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
|
|
|
from hivemind.proto import runtime_pb2
|
|
|
from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
|
|
|
from hivemind.utils.streaming import split_for_streaming
|
|
|
|
|
|
-
|
|
|
-async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
|
|
|
- split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
|
|
|
-
|
|
|
- grad_inputs = await stub.rpc_backward_stream(
|
|
|
- amap_in_executor(
|
|
|
- lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
|
|
|
- iter_as_aiter(split),
|
|
|
- ),
|
|
|
+from src.data_structures import ModuleUID, RPCInfo
|
|
|
+
|
|
|
+
|
|
|
+async def run_remote_forward(
|
|
|
+ uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b'', **kwargs
|
|
|
+) -> Tuple[torch.Tensor, ...]:
|
|
|
+ """
|
|
|
+ Serializes input tensors and calls "rpc_forward" on a remote server.
|
|
|
+ Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
|
|
|
+ but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
|
|
|
+ """
|
|
|
+
|
|
|
+ # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
|
|
|
+ # detach to avoid pickling the computation graph
|
|
|
+ assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
|
|
|
+ kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
|
|
|
+
|
|
|
+ # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
|
|
|
+ forward_inputs = (inputs, kwargs)
|
|
|
+
|
|
|
+ # Modify forward_schema to support prompts
|
|
|
+ args_schema, kwargs_schema = rpc_info["forward_schema"]
|
|
|
+ # TODO: rm this assert when support arbitrary number of input tensors
|
|
|
+ assert len(args_schema) == 1 and len(inputs) == 2
|
|
|
+ forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
|
|
|
+
|
|
|
+ if not nested_compare(forward_inputs, forward_schema_with_prompts):
|
|
|
+ raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
|
|
|
+
|
|
|
+ forward_inputs = nested_flatten(forward_inputs)
|
|
|
+ inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
|
|
|
+
|
|
|
+ # Asynchronous serialization
|
|
|
+ loop = asyncio.get_running_loop()
|
|
|
+ serialized_tensors = await asyncio.gather(
|
|
|
+ *(
|
|
|
+ loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
|
|
|
+ for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
|
|
|
+ )
|
|
|
)
|
|
|
- tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
|
|
|
- return await deserialize_tensor_stream(tensors_stream)
|
|
|
|
|
|
+ # call RPC on remote server
|
|
|
+ size = sum(t.element_size() * t.nelement() for t in inputs)
|
|
|
+ if size > MAX_UNARY_PAYLOAD_SIZE:
|
|
|
+ deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, **kwargs)
|
|
|
+ else:
|
|
|
+ deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, **kwargs)
|
|
|
|
|
|
-async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
|
|
|
- grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
|
|
|
- runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
|
|
|
- )
|
|
|
- return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
|
|
|
+ return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
|
|
|
|
|
|
|
|
|
-async def remote_backward(
|
|
|
- uid: str, inputs_and_grads: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
|
|
|
+async def _forward_stream(
|
|
|
+ uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
|
|
|
) -> List[torch.Tensor]:
|
|
|
- """Call rpc_backward (unary or stream) on a single remote server, return grads w.r.t. arguments"""
|
|
|
- size = 0
|
|
|
- for t in inputs_and_grads:
|
|
|
- size += t.element_size() * t.nelement()
|
|
|
- if size > MAX_UNARY_PAYLOAD_SIZE:
|
|
|
- return await _backward_stream(uid, serialized_tensors, stub)
|
|
|
- else:
|
|
|
- return await _backward_unary(uid, serialized_tensors, stub)
|
|
|
-
|
|
|
-
|
|
|
-async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
|
|
|
split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
|
|
|
|
|
|
outputs = await stub.rpc_forward_stream(
|
|
|
amap_in_executor(
|
|
|
- lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
|
|
|
+ lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
|
|
|
iter_as_aiter(split),
|
|
|
),
|
|
|
)
|
|
@@ -58,21 +80,78 @@ async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Ten
|
|
|
return await deserialize_tensor_stream(tensors_stream)
|
|
|
|
|
|
|
|
|
-async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
|
|
|
+async def _forward_unary(
|
|
|
+ uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
|
|
|
+) -> List[torch.Tensor]:
|
|
|
outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
|
|
|
- runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
|
|
|
+ runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
|
|
|
)
|
|
|
return [deserialize_torch_tensor(t) for t in outputs.tensors]
|
|
|
|
|
|
|
|
|
-async def remote_forward(
|
|
|
- uid: str, inputs: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
|
|
|
+
|
|
|
+async def _backward_stream(
|
|
|
+ uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
|
|
|
) -> List[torch.Tensor]:
|
|
|
- """Call rpc_forward (unary or stream) on a single remote server, return block outputs"""
|
|
|
- size = 0
|
|
|
- for t in inputs:
|
|
|
- size += t.element_size() * t.nelement()
|
|
|
- if size > MAX_UNARY_PAYLOAD_SIZE:
|
|
|
- return await _forward_stream(uid, serialized_tensors, stub)
|
|
|
+ split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
|
|
|
+
|
|
|
+ grad_inputs = await stub.rpc_backward_stream(
|
|
|
+ amap_in_executor(
|
|
|
+ lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
|
|
|
+ iter_as_aiter(split),
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
|
|
|
+ return await deserialize_tensor_stream(tensors_stream)
|
|
|
+
|
|
|
+
|
|
|
+async def run_remote_backward(
|
|
|
+ uid: ModuleUID,
|
|
|
+ stub: StubBase,
|
|
|
+ rpc_info: RPCInfo,
|
|
|
+ inputs: torch.Tensor,
|
|
|
+ grad_outputs: List[torch.Tensor],
|
|
|
+ *extra_tensors: torch.Tensor,
|
|
|
+ metadata: bytes = b''
|
|
|
+) -> Sequence[torch.Tensor]:
|
|
|
+ """
|
|
|
+ Serializes grad outputs and calls "rpc_backward" on a remote server.
|
|
|
+ Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
|
|
|
+ but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
|
|
|
+ """
|
|
|
+
|
|
|
+ grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
|
|
|
+ inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
|
|
|
+
|
|
|
+ # Modify forward_schema to support prompts
|
|
|
+ args_schema, kwargs_schema = rpc_info["forward_schema"]
|
|
|
+ assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
|
|
|
+ # TODO generalize this
|
|
|
+ prompts_schema = next(iter(args_schema))
|
|
|
+ backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
|
|
|
+
|
|
|
+ # Asynchronous serialization
|
|
|
+ loop = asyncio.get_running_loop()
|
|
|
+ serialized_tensors = await asyncio.gather(
|
|
|
+ *(
|
|
|
+ loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
|
|
|
+ for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
|
|
|
+ if size > MAX_UNARY_PAYLOAD_SIZE:
|
|
|
+ deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, **kwargs)
|
|
|
else:
|
|
|
- return await _forward_unary(uid, serialized_tensors, stub)
|
|
|
+ deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, **kwargs)
|
|
|
+
|
|
|
+ return deserialized_grad_inputs
|
|
|
+
|
|
|
+
|
|
|
+async def _backward_unary(
|
|
|
+ uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
|
|
|
+) -> List[torch.Tensor]:
|
|
|
+ grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
|
|
|
+ runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
|
|
|
+ )
|
|
|
+ return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
|