@@ -1,8 +1,16 @@
 import contextlib
-from typing import AsyncIterator, Dict, Sequence
+from typing import AsyncIterator, Dict, List, Optional, Sequence, Union

 import torch
-from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
+from hivemind import (
+    DHT,
+    MSGPackSerializer,
+    P2PContext,
+    TensorDescriptor,
+    deserialize_torch_tensor,
+    nested_flatten,
+    serialize_torch_tensor,
+)
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
@@ -12,6 +20,7 @@ from hivemind.utils.streaming import split_for_streaming

 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
+from src.utils.misc import DUMMY, is_dummy


 class TransformerConnectionHandler(ConnectionHandler):
@@ -33,7 +42,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         try:
             print("OPENED RPC_INFERENCE")
             request = await anext(requests)
-            requested_uids = self._check_header(request)
+            requested_uids = self._check_uids(request.uid)
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

             batch_size = request.tensors[0].size[0] if request.tensors else 1
@@ -80,27 +89,18 @@ class TransformerConnectionHandler(ConnectionHandler):

     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends
-        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
+        flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        # Cast inputs to backend dtype
-        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3

-        # Run a chain of requested backends
-        for backend in requested_backends:
-            assert isinstance(hidden_states, (list, tuple))
-            assert (
-                len(hidden_states) == 1 and hidden_states[0].ndim == 3
-            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
-
-        # Serialize the overall output and respond
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+        # Serialize output and respond to client
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
             ]
         )

@@ -108,29 +108,20 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         # Parse requests and prepare backends
-        uids_header, hidden_states = await self._gather_inputs(requests, context)
-        requested_uids = self._check_header_str(uids_header)
+        uid_str, flat_inputs = await self._gather_inputs(requests, context)
+        requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        # Cast inputs to backend dtype
-        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
-
-        # Run a chain of requested backends
-        for backend in requested_backends:
-            assert isinstance(hidden_states, (list, tuple))
-            assert (
-                len(hidden_states) == 1 and hidden_states[0].ndim == 3
-            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3

         # Serialize the overall output
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         serialized_output = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
         ]

-        # Split the serialized_output for streaming and respond
+        # Split the serialized_output for streaming and respond to client
         output_split = [
             part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
@@ -139,36 +130,25 @@ class TransformerConnectionHandler(ConnectionHandler):

     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse requests and prepare backends
-        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
+        flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        # Cast inputs & grad outputs to backend dtype
-        inputs = inputs.to(requested_backends[0].dtype)
-        grads = grads.to(requested_backends[-1].dtype)
-
-        # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its output
-        inter_inputs = [inputs]
-        for backend in requested_backends[:-1]:
-            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-            inputs = await backend.forward_pool.submit_task(inputs)
-            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-            inputs = inputs[0]
-            inter_inputs.append(inputs)
-
-        # Run a chain of requested backends
-        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
-            inputs_and_grads = [inp, grads]
-            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert isinstance(grads, (list, tuple)) and len(grads) == 1
-            grads = grads[0]
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
+
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )  # TODO generalize

         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
             ]
         )

@@ -176,36 +156,23 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:

-        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
-        inputs, grads = inputs_and_grads
-        requested_uids = self._check_header_str(uids_header)
+        uids_header, flat_tensors = await self._gather_inputs(requests, context)
+        requested_uids = self._check_uids(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        # Cast inputs & grad outputs to backend dtype
-        inputs = inputs.to(requested_backends[0].dtype)
-        grads = grads.to(requested_backends[-1].dtype)
-
-        # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its outputs
-        inter_inputs = [inputs]
-        for backend in requested_backends[:-1]:
-            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-            inputs = await backend.forward_pool.submit_task(inputs)
-            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-            inputs = inputs[0]
-            inter_inputs.append(inputs)
-
-        # Run a backward chain for requested backends
-        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
-            inputs_and_grads = [inp, grads]
-            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert isinstance(grads, (list, tuple)) and len(grads) == 1
-            grads = grads[0]
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
+
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )  # TODO generalize

         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
         ]
         # Split the serialized_grad_inputs for streaming and respond
         output_split = [
@@ -215,19 +182,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         async for part in as_aiter(*output_split):
             yield runtime_pb2.ExpertResponse(tensors=[part])

-    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
+    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
         """Check that the first request to rpc_inference is valid"""
-        uids = (request.uid or "").split(CHAIN_DELIMITER)
-        if not uids:
-            raise RuntimeError("User did not provide any uids")
-        for uid in uids:
-            if uid not in self.module_backends:
-                raise RuntimeError(f"Remote peer does not serve {uid}")
-        return tuple(uids)
-
-    def _check_header_str(self, header) -> Sequence[ModuleUID]:
-        """Check that the first request to rpc_inference is valid"""
-        uids = (header or "").split(CHAIN_DELIMITER)
+        uids = (uids or "").split(CHAIN_DELIMITER)
         if not uids:
             raise RuntimeError("User did not provide any uids")
         for uid in uids:
@@ -252,3 +209,92 @@ class TransformerConnectionHandler(ConnectionHandler):
                 handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))

             yield handles
+
+
+async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor:
+    """
+    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
+
+    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
+    """
+    hidden_states, *prompts = flat_tensors
+    dtype = requested_backends[0].dtype
+    # check and parse input tensors, then cast them to the backend dtype
+    hidden_states = hidden_states.to(dtype)
+    assert hidden_states.ndim == 3
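+    # real prompts contain one trainable prefix per requested backend, added to its first pre_seq_len input positions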
+    if not prompts or is_dummy(prompts[0]):
+        prompts = [DUMMY] * len(requested_backends)
+        pre_seq_len = 0
+    else:
+        prompts = [prompts[0].to(requested_backends[0].dtype)]
+        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        pre_seq_len = prompts[0].shape[-2]
+
+    # Run a chain of requested backends
+    for backend, prompt in zip(requested_backends, prompts):
+        if not is_dummy(prompt):
+            hidden_states[:, :pre_seq_len] += prompt
+        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
+        assert isinstance(hidden_states, torch.Tensor)
+        assert (
+            hidden_states.ndim == 3
+        ), f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+
+    # Return hidden states after the last requested backend
+    return hidden_states
+
+
+async def _rpc_backward(
+    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
+) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
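+    """
+    Run backward pass on deserialized inputs and grad outputs, used by rpc_backward and rpc_backward_stream
+
+    :param flat_tensors: a list of tensors that includes inputs, grad_outputs and optional prompts (see is_dummy)
+    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :returns: gradients w.r.t. inputs and, if real prompts were given, gradients w.r.t. prompts
+    """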
+    inputs, grad_outputs, *prompts = flat_tensors
+    # Cast inputs & grad outputs to backend dtype
+    inputs = inputs.to(requested_backends[0].dtype)
+    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+
+    if not prompts or is_dummy(prompts[0]):
+        prompts = [DUMMY] * len(requested_backends)
+        pre_seq_len = 0
+    else:
+        prompts = [prompts[0].to(requested_backends[0].dtype)]
+        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        pre_seq_len = prompts[0].shape[-2]
+
+    # Run a forward chain to collect intermediate inputs
+    # Note that we do not forward for the last module since we do not need its output
+    inter_inputs = []
+    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
+        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+        if not is_dummy(prompt):
+            inputs[:, :pre_seq_len] += prompt
+        inter_inputs.append(inputs)
+        (inputs,) = await backend.forward_pool.submit_task(inputs)
+        assert isinstance(inputs, torch.Tensor)
+
+    if not is_dummy(prompts[-1]):
+        inputs[:, :pre_seq_len] += prompts[-1]
+    inter_inputs.append(inputs)
+
+    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
+    grad_prompts_reversed = []
+    # Run a chain of requested backends
+    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
+        assert isinstance(grad_outputs, torch.Tensor)
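+        # grad w.r.t. a real prompt equals the input gradient over its prefix positions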
+        if not is_dummy(prompt):
+            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
+    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape