@@ -1,43 +1,68 @@
-from typing import Any, Dict, Optional, Tuple
+from __future__ import annotations
+
+from concurrent.futures import Future
+from dataclasses import dataclass
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, MSGPackSerializer, nested_compare, nested_flatten, nested_pack
-from hivemind.utils.grpc import ChannelCache
+from hivemind import moe
+from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import P2P, PeerInfo, StubBase
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.mpfuture import MPFuture
+from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
+from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.streaming import split_for_streaming
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
 
-def _get_expert_stub(endpoint: Endpoint, *extra_options: Tuple[str, Any]):
-    """Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache"""
-    channel_options = (("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)) + extra_options
-    return ChannelCache.get_stub(endpoint, runtime_grpc.ConnectionHandlerStub, aio=False, options=channel_options)
+def get_expert_stub(p2p: P2P, server_peer_info: PeerInfo) -> "ConnectionHandlerStub":
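+    """Create an RPC stub to reach the remote expert via the ConnectionHandler of the given server peer"""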
+    return moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
+
+
+@dataclass(frozen=True)
+class RemoteExpertInfo:
+    """A simple data class containing uid of expert and server PeerInfo"""
+
+    uid: str
+    peer_info: PeerInfo
 
 
 class RemoteExpert(nn.Module):
     """
     A simple module that runs forward/backward of an expert hosted on a remote machine.
     Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)
-
     Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
     Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
 
-    :param uid: unique expert identifier
-    :param endpoint: network endpoint of a server that services that expert, e.g. "201.123.321.99:1337" or "[::]:8080"
+    :param expert_info: RemoteExpertInfo with uid and server PeerInfo
+    :param p2p: P2P instance connected to the running p2pd
     """
 
-    def __init__(self, uid, endpoint: Endpoint):
+    def __init__(self, expert_info: RemoteExpertInfo, p2p: P2P):
         super().__init__()
-        self.uid, self.endpoint = uid, endpoint
-        self._info = None
+        self._info, self.p2p = expert_info, p2p
+        self._rpc_info = None
 
     @property
-    def stub(self):
-        return _get_expert_stub(self.endpoint)
+    def uid(self):
+        return self._info.uid
+
+    @property
+    def server_peer_info(self):
+        return self._info.peer_info
+
+    @property
+    def stub(self) -> StubBase:
+        return get_expert_stub(self.p2p, self.server_peer_info)
 
     def forward(self, *args, **kwargs):
         """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
@@ -52,18 +77,125 @@ class RemoteExpert(nn.Module):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
         flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
+
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
         return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
 
     @property
     def info(self):
-        if self._info is None:
-            outputs = self.stub.info(runtime_pb2.ExpertUID(uid=self.uid))
-            self._info = MSGPackSerializer.loads(outputs.serialized_info)
-        return self._info
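+        # The expert's metadata (e.g. forward/outputs schema) is fetched once via rpc_info and cached afterwards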
+        if self._rpc_info is None:
+            outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
+            self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+        return self._rpc_info
 
     def extra_repr(self):
-        return f"uid={self.uid}, endpoint={self.endpoint}"
+        return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
+
+
+def _create_remote_experts(infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
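+    # Keep positions aligned with infos: entries that are None stay None in the result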
+    experts: List[Optional[RemoteExpert]] = []
+    for info in infos:
+        if info is not None:
+            experts.append(RemoteExpert(info, p2p))
+        else:
+            experts.append(None)
+    return experts
+
+
+def create_remote_experts(
+    infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
+) -> Union[List[Optional[RemoteExpert]], Future]:
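+    # Wrap RemoteExpertInfo entries (or an MPFuture that will resolve to them) into RemoteExpert modules,
+    # reusing the p2p replica of the given DHT; with return_future=True, a Future with the result is returned instead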
+    if return_future:
+
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return _create_remote_experts(await infos_future, p2p)
+
+        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
+
+    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+    return _create_remote_experts(infos, p2p)
+
+
+def batch_create_remote_experts(
+    infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
+    dht: DHT,
+    return_future: bool = False,
+) -> Union[List[List[Optional[RemoteExpert]]], Future]:
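+    # Batched version of create_remote_experts: converts each inner sequence of RemoteExpertInfo separately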
+    if return_future:
+
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return [_create_remote_experts(i, p2p) for i in await infos_future]
+
+        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
+
+    return [create_remote_experts(exps, dht) for exps in infos]
+
+
+async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
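+    # Split each serialized tensor into parts of at most DEFAULT_MAX_MSG_SIZE and stream them to rpc_backward_stream;
+    # the streamed response parts are then reassembled into full gradient tensors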
+    split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
+
+    grad_inputs = await stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
+            iter_as_aiter(split),
+        ),
+    )
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+
+
+async def expert_backward(
+    uid: str, inputs_and_grads: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+) -> List[torch.Tensor]:
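+    # Use the streaming RPC if the raw payload exceeds a single message, otherwise fall back to the unary call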
+    size = 0
+    for t in inputs_and_grads:
+        size += t.element_size() * t.nelement()
+    if size > DEFAULT_MAX_MSG_SIZE:
+        return await _backward_stream(uid, serialized_tensors, stub)
+    else:
+        return await _backward_unary(uid, serialized_tensors, stub)
+
+
+async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
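+    # Same chunk-and-stream pattern as _backward_stream, but for rpc_forward_stream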
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
+
+    outputs = await stub.rpc_forward_stream(
+        amap_in_executor(
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
+            iter_as_aiter(split),
+        ),
+    )
+
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in outputs.tensors]
+
+
+async def expert_forward(
+    uid: str, inputs: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+) -> List[torch.Tensor]:
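+    # Forward counterpart of expert_backward: choose streaming vs unary depending on the total input size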
+    size = 0
+    for t in inputs:
+        size += t.element_size() * t.nelement()
+    if size > DEFAULT_MAX_MSG_SIZE:
+        return await _forward_stream(uid, serialized_tensors, stub)
+    else:
+        return await _forward_unary(uid, serialized_tensors, stub)
 
 
 class _RemoteModuleCall(torch.autograd.Function):
@@ -74,7 +206,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ctx,
         dummy: torch.Tensor,
         uid: str,
-        stub: runtime_grpc.ConnectionHandlerStub,
+        stub: "ConnectionHandlerStub",
         info: Dict[str, Any],
         *inputs: torch.Tensor,
     ) -> Tuple[torch.Tensor, ...]:
@@ -83,15 +215,11 @@ class _RemoteModuleCall(torch.autograd.Function):
         inputs = tuple(tensor.cpu().detach() for tensor in inputs)
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
-
-        serialized_tensors = [
-            serialize_torch_tensor(inp, proto.compression)
-            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
-        ]
-
-        outputs = stub.forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
-
-        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
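+        # Serialize inputs lazily (as a generator) and run the forward RPC through RemoteExpertWorker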
+        serialized_tensors = (
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs, nested_flatten(info["forward_schema"]))
+        )
+        deserialized_outputs = RemoteExpertWorker.run_coroutine(expert_forward(uid, inputs, serialized_tensors, stub))
 
         return tuple(deserialized_outputs)
@@ -101,12 +229,12 @@ class _RemoteModuleCall(torch.autograd.Function):
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-        serialized_tensors = [
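+        # Same approach for backward: lazy serialization and an RPC executed through RemoteExpertWorker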
+        serialized_tensors = (
             serialize_torch_tensor(tensor, proto.compression)
             for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        ]
-
-        grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+        )
+        deserialized_grad_inputs = RemoteExpertWorker.run_coroutine(
+            expert_backward(ctx.uid, inputs_and_grad_outputs, serialized_tensors, ctx.stub)
+        )
 
-        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
         return (DUMMY, None, None, None, *deserialized_grad_inputs)