@@ -7,6 +7,9 @@ from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import anext
+from hivemind.utils.streaming import split_for_streaming
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
+from hivemind.utils import as_aiter

from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import MAX_LENGTH, TransformerBackend
@@ -67,6 +70,140 @@ class TransformerConnectionHandler(ConnectionHandler):
        finally:
            print("CLOSED RPC_INFERENCE")

+    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+        # Parse request and prepare backends
+        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_header(request)
+        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+
+        # Run a chain of requested backends
+        for backend in requested_backends:
+            assert isinstance(hidden_states, (list, tuple))
+            assert (
+                len(hidden_states) == 1 and hidden_states[0].ndim == 3
+            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
+
+        # Serialize the overall output and respond
+        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+        return runtime_pb2.ExpertResponse(tensors=[
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip(
+                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
+            )
+        ])
+
+    async def rpc_forward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+        # Parse requests and prepare backends
+        uids_header, hidden_states = await self._gather_inputs(requests, context)
+        requested_uids = self._check_header_str(uids_header)
+        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+
+        # Run a chain of requested backends
+        for backend in requested_backends:
+            assert isinstance(hidden_states, (list, tuple))
+            assert (
+                len(hidden_states) == 1 and hidden_states[0].ndim == 3
+            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
+
+        # Serialize the overall output
+        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+        serialized_output = [
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip(
+                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
+            )
+        ]
+
+        # Split the serialized_output for streaming and respond
+        output_split = [
+            part
+            for tensor in serialized_output
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+        ]
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part])
+
+    async def rpc_backward(
+        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+    ) -> runtime_pb2.ExpertResponse:
+        # Parse requests and prepare backends
+        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_header(request)
+        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+
+        # Run a forward chain to collect intermediate inputs
+        # Note that we do not forward for the last module since we do not need its output
+        inter_inputs = [inputs]
+        for backend in requested_backends[:-1]:
+            assert inputs.ndim == 3, \
+                f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+            inputs = await backend.forward_pool.submit_task(inputs)
+            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
+            inputs = inputs[0]
+            inter_inputs.append(inputs)
+
+        # Run a backward chain for requested backends
+        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
+            inputs_and_grads = [inp, grads]
+            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
+            assert isinstance(grads, (list, tuple)) and len(grads) == 1
+            grads = grads[0]
+
+        # Serialize the overall grad_input and respond
+        return runtime_pb2.ExpertResponse(tensors=[
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip(
+                [grads], nested_flatten(requested_backends[0].grad_inputs_schema)
+            )
+        ])
+
+    async def rpc_backward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
+        inputs, grads = inputs_and_grads
+        requested_uids = self._check_header_str(uids_header)
+        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+
+        # Run a forward chain to collect intermediate inputs
+        # Note that we do not forward for the last module since we do not need its outputs
+        inter_inputs = [inputs]
+        for backend in requested_backends[:-1]:
+            assert inputs.ndim == 3, \
+                f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+            inputs = await backend.forward_pool.submit_task(inputs)
+            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
+            inputs = inputs[0]
+            inter_inputs.append(inputs)
+
+        # Run a backward chain for requested backends
+        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
+            inputs_and_grads = [inp, grads]
+            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
+            assert isinstance(grads, (list, tuple)) and len(grads) == 1
+            grads = grads[0]
+
+        # Serialize the overall grad_inputs
+        serialized_grad_inputs = [
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip(
+                [grads], nested_flatten(requested_backends[0].grad_inputs_schema)
+            )
+        ]
+        # Split the serialized_grad_inputs for streaming and respond
+        output_split = [
+            part
+            for tensor in serialized_grad_inputs
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+        ]
+
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part])
+
    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
        """Check that the first request to rpc_inference is valid"""
        uids = (request.uid or "").split(CHAIN_DELIMITER)
@@ -77,6 +214,16 @@ class TransformerConnectionHandler(ConnectionHandler):
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)

+    def _check_header_str(self, header) -> Sequence[ModuleUID]:
+        """Check that the uids header (chained module UIDs) is valid"""
+        uids = (header or "").split(CHAIN_DELIMITER)
+        if not uids:
+            raise RuntimeError("User did not provide any uids")
+        for uid in uids:
+            if uid not in self.module_backends:
+                raise RuntimeError(f"Remote peer does not serve {uid}")
+        return tuple(uids)
+
    @contextlib.asynccontextmanager
    async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> Sequence[int]:
        """Allocate memory caches for each transformer block, return cache handles"""
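
For reference, here is a minimal client-side sketch (not part of the patch above) of how the chunked stream produced by rpc_forward_stream could be reassembled. It assumes hivemind's combine_from_streaming utility and an async iterator of ExpertResponse messages obtained from the handler's stub; the hypothetical helper below concatenates the chunks of the single output tensor and deserializes it:

from hivemind import deserialize_torch_tensor
from hivemind.utils.streaming import combine_from_streaming


async def collect_forward_output(stream):
    # Each ExpertResponse carries one chunk of the serialized hidden_states tensor;
    # gather the chunks in order, stitch them back together, and deserialize.
    parts = [part async for response in stream for part in response.tensors]
    return deserialize_torch_tensor(combine_from_streaming(parts))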