@@ -1,8 +1,9 @@
 import contextlib
-from typing import AsyncIterator, Dict, Sequence
+from typing import AsyncIterator, Dict, List, Optional, Sequence

 import torch
-from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
+from hivemind import (DHT, MSGPackSerializer, P2PContext, TensorDescriptor, deserialize_torch_tensor,
+                      nested_flatten, serialize_torch_tensor)
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
@@ -34,7 +35,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         try:
             print("OPENED RPC_INFERENCE")
             request = await anext(requests)
-            requested_uids = self._check_header(request)
+            requested_uids = self._check_uids(request.uid)
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

             batch_size = request.tensors[0].size[0] if request.tensors else 1
@@ -81,18 +82,18 @@ class TransformerConnectionHandler(ConnectionHandler):

     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends
-        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
+        flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        hidden_states = await _rpc_forward(inputs, requested_backends)
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3

-        # Serialize the overall output and respond
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+        # Serialize output and respond to client
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
             ]
         )

@@ -100,20 +101,20 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
-    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         # Parse requests and prepare backends
-        uids_header, inputs = await self._gather_inputs(requests, context)
-        requested_uids = self._check_header_str(uids_header)
+        uid_str, flat_inputs = await self._gather_inputs(requests, context)
+        requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        hidden_states = await _rpc_forward(inputs, requested_backends)
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3

         # Serialize the overall output
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         serialized_output = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
         ]

-        # Split the serialized_output for streaming and respond
+        # Split the serialized_output for streaming and respond to client
         output_split = [
             part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
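For context, a minimal sketch of the round-trip that split_for_streaming sets up here, assuming hivemind's matching combine_from_streaming helper on the client side (the exact import path may vary across hivemind versions):

import torch
from hivemind import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

# serialize one activation tensor, split it into message-sized parts as the
# server does above, then reassemble and deserialize as a client would
serialized = serialize_torch_tensor(torch.randn(1, 2048, 4096))
parts = list(split_for_streaming(serialized, DEFAULT_MAX_MSG_SIZE))
restored = deserialize_torch_tensor(combine_from_streaming(parts))
assert restored.shape == (1, 2048, 4096)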
@@ -122,11 +123,11 @@ class TransformerConnectionHandler(ConnectionHandler):

     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse requests and prepare backends
-        inputs, prompts, grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
+        flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)

         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
@@ -147,11 +148,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:

-        uids_header, (inputs, prompts, grad_outputs) = await self._gather_inputs(requests, context)
-        requested_uids = self._check_header_str(uids_header)
+        uids_header, flat_tensors = await self._gather_inputs(requests, context)
+        requested_uids = self._check_uids(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)

         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
@@ -173,19 +174,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         async for part in as_aiter(*output_split):
             yield runtime_pb2.ExpertResponse(tensors=[part])

-    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
-        """Check that the first request to rpc_inference is valid"""
-        uids = (request.uid or "").split(CHAIN_DELIMITER)
-        if not uids:
-            raise RuntimeError("User did not provide any uids")
-        for uid in uids:
-            if uid not in self.module_backends:
-                raise RuntimeError(f"Remote peer does not serve {uid}")
-        return tuple(uids)
-
-    def _check_header_str(self, header) -> Sequence[ModuleUID]:
-        """Check that the first request to rpc_inference is valid"""
-        uids = (header or "").split(CHAIN_DELIMITER)
+    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
+        """Check that the given uids are valid and served by this handler's backends"""
+        uids = (uids or "").split(CHAIN_DELIMITER)
         if not uids:
             raise RuntimeError("User did not provide any uids")
         for uid in uids:
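One edge case worth noting in the merged helper: str.split never returns an empty list, so the `if not uids` guard above cannot fire; an empty uid header instead fails the per-uid membership check. A two-line illustration:

# "".split(delim) yields [""], which is truthy; the empty string is then
# rejected by the `uid not in self.module_backends` branch instead
assert "".split(" ") == [""]
assert bool([""]) and not bool([])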
@@ -212,32 +203,42 @@ class TransformerConnectionHandler(ConnectionHandler):
             yield handles


-async def _rpc_forward(inputs, requested_backends):
-    # Cast inputs to backend dtype
-    inputs = [tensor.to(requested_backends[0].dtype) for tensor in inputs]
-    assert len(inputs) == 2 and inputs[0].ndim == 3
-    hidden_states, prompts = inputs
-
-    if is_dummy(prompts):
-        prompts = [DUMMY] * len(requested_backends)
-    else:
-        pre_seq_len = prompts.shape[2]
-
-    # Run a chain of requested backends
+async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor:
+    """
+    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
+
+    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
+    """
+    hidden_states, *prompts = flat_tensors
+    dtype = requested_backends[0].dtype
+    # check and cast input tensors to the backend dtype
+    hidden_states = hidden_states.to(dtype)
+    assert hidden_states.ndim == 3
+    assert len(prompts) <= len(requested_backends), f"Expected at most {len(requested_backends)} prompts, one per layer"
+
+    for i in range(len(prompts)):
+        if not is_dummy(prompts[i]):
+            assert prompts[i].ndim == 3, "prompts must have shape [batch or 1, seq_len or prefix, hidden_size]"
+            prompts[i] = prompts[i].to(dtype)
+    prompts.extend(DUMMY for _ in range(len(prompts), len(requested_backends)))  # pad missing prompts with dummies
+
+    seq_length = hidden_states.shape[1]
+
+    # run forward pass for requested backends
     for backend, prompt in zip(requested_backends, prompts):
-        if not is_dummy(prompt):
-            hidden_states[:, :pre_seq_len] += prompt
         (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
+        if not is_dummy(prompt):
+            hidden_states[:, :min(seq_length, prompt.shape[1]), ...] += prompt
         assert isinstance(hidden_states, torch.Tensor)
-        assert (
-            hidden_states.ndim == 3
-        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+        assert hidden_states.ndim == 3, f"{type(backend)} must return a single 3d tensor of hidden states"

-    # Serialize the overall output
-    return [hidden_states]
+    return hidden_states


-async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
+async def _rpc_backward(inputs: torch.Tensor, prompts: torch.Tensor, grad_outputs: torch.Tensor, requested_backends: Sequence[TransformerBackend]):
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     prompts = prompts.to(requested_backends[0].dtype)
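A hedged usage sketch of the new flat-argument convention for _rpc_forward, assuming petals' DUMMY empty-tensor sentinel and illustrative backend and shape values:

import torch

async def forward_example(backends) -> torch.Tensor:
    # hidden_states comes first, followed by at most one prompt per backend;
    # any prompts not supplied are padded with DUMMY inside _rpc_forward
    hidden_states = torch.randn(1, 8, 4096)     # [batch_size, seq_length, hid_size]
    prompts = [DUMMY, torch.zeros(1, 4, 4096)]  # no prompt for block 0, a 4-token prefix for block 1
    return await _rpc_forward(hidden_states, *prompts, requested_backends=backends)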
@@ -255,6 +256,7 @@ async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
+            inputs = inputs.clone()  # TODO: avoid this copy; the in-place add below must not modify the original inputs
             inputs[:, :pre_seq_len] += prompt
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
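On the inputs.clone() added above: PyTorch rejects in-place edits of a leaf tensor that requires grad, which is presumably what the copy works around; clone() is differentiable, so gradients still reach the original tensor. A standalone demonstration:

import torch

x = torch.zeros(1, 4, 8, requires_grad=True)
try:
    x[:, :2] += 1.0  # in-place edit of a leaf that requires grad -> RuntimeError
except RuntimeError as e:
    print(e)

y = x.clone()    # differentiable copy; y is a non-leaf and may be edited in-place
y[:, :2] += 1.0  # safe: autograd tracks the addition through the clone
y.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))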