123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229 |
- # Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
- import contextlib
- from typing import AsyncIterator, Dict, Sequence
- import torch
- from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
- from hivemind.moe.server.connection_handler import ConnectionHandler
- from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
- from hivemind.proto import runtime_pb2
- from hivemind.utils import as_aiter
- from hivemind.utils.asyncio import anext
- from hivemind.utils.streaming import split_for_streaming
- from src.data_structures import CHAIN_DELIMITER, ModuleUID
- from src.server.backend import MAX_LENGTH, TransformerBackend
- class TransformerConnectionHandler(ConnectionHandler):
- """Handles three request types: forward, backward and forward-incremental (inference)"""
- module_backends: Dict[ModuleUID, TransformerBackend]
- def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
- super().__init__(dht, module_backends)
- for module_backend in self.module_backends.values():
- assert isinstance(module_backend, TransformerBackend)
- async def rpc_inference(
- self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
- ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
- """Compute a single step of inference using attention cache; update attention cache accordingly."""
- try:
- print("OPENED RPC_INFERENCE")
- request = await anext(requests)
- requested_uids = self._check_header(request)
- requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
- cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64) # [cache_handle, prefix_length]
- prefix_length = 0
- async with self._allocate_caches(requested_backends) as cache_handles:
- assert len(cache_handles) == len(requested_backends)
- while request.tensors: # iterate while user is willing to supply tensors
- hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
- # run request tensors through all requested modules, update caches
- for backend, cache_handle in zip(requested_backends, cache_handles):
- cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
- assert (
- len(hidden_states) == 1 and hidden_states[0].ndim == 3
- ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
- hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
- assert isinstance(hidden_states, (list, tuple))
- assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
- # serialize and send last layer outputs
- yield runtime_pb2.ExpertResponse(
- tensors=[
- serialize_torch_tensor(result, proto.compression, allow_inplace=True)
- for result, proto in zip(
- hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
- )
- ]
- )
- # prepare for next step
- prefix_length += hidden_states[0].shape[1]
- request = await (anext(requests))
- finally:
- print("CLOSED RPC_INFERENCE")
- async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
- # Parse request and prepare backends
- hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
- requested_uids = self._check_header(request)
- requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
- # Run a chain of requested backends
- for backend in requested_backends:
- assert isinstance(hidden_states, (list, tuple))
- assert (
- len(hidden_states) == 1 and hidden_states[0].ndim == 3
- ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
- hidden_states = await backend.forward_pool.submit_task(*hidden_states)
- # Serialize the overall output and respond
- assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
- return runtime_pb2.ExpertResponse(
- tensors=[
- serialize_torch_tensor(result, proto.compression, allow_inplace=True)
- for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
- ]
- )
- async def rpc_forward_stream(
- self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
- ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
- # Parse requests and prepare backends
- uids_header, hidden_states = await self._gather_inputs(requests, context)
- requested_uids = self._check_header_str(uids_header)
- requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
- # Run a chain of requested backends
- for backend in requested_backends:
- assert isinstance(hidden_states, (list, tuple))
- assert (
- len(hidden_states) == 1 and hidden_states[0].ndim == 3
- ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
- hidden_states = await backend.forward_pool.submit_task(*hidden_states)
- # Serialize the overall output
- assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
- serialized_output = [
- serialize_torch_tensor(result, proto.compression, allow_inplace=True)
- for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
- ]
- # Split the serialized_output for streaming and respond
- output_split = [
- part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
- ]
- async for part in as_aiter(*output_split):
- yield runtime_pb2.ExpertResponse(tensors=[part])
- async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
- # Parse requests and prepare backends
- inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
- requested_uids = self._check_header(request)
- requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
- # Run a forward chain to collect intermediate inputs
- # Note that we do not forward for the last module since we do not need its output
- inter_inputs = [inputs]
- for backend in requested_backends[:-1]:
- assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
- inputs = await backend.forward_pool.submit_task(inputs)
- assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
- inputs = inputs[0]
- inter_inputs.append(inputs)
- # Run a chain of requested backends
- for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
- inputs_and_grads = [inp, grads]
- grads = await backend.backward_pool.submit_task(*inputs_and_grads)
- assert isinstance(grads, (list, tuple)) and len(grads) == 1
- grads = grads[0]
- # Serialize the overall grad_input and respond
- return runtime_pb2.ExpertResponse(
- tensors=[
- serialize_torch_tensor(result, proto.compression, allow_inplace=True)
- for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
- ]
- )
- async def rpc_backward_stream(
- self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
- ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
- uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
- inputs, grads = inputs_and_grads
- requested_uids = self._check_header_str(uids_header)
- requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
- # Run a forward chain to collect intermediate inputs
- # Note that we do not forward for the last module since we do not need its outputs
- inter_inputs = [inputs]
- for backend in requested_backends[:-1]:
- assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
- inputs = await backend.forward_pool.submit_task(inputs)
- assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
- inputs = inputs[0]
- inter_inputs.append(inputs)
- # Run a backward chain for requested backends
- for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
- inputs_and_grads = [inp, grads]
- grads = await backend.backward_pool.submit_task(*inputs_and_grads)
- assert isinstance(grads, (list, tuple)) and len(grads) == 1
- grads = grads[0]
- # Serialize the overall grad_inputs
- serialized_grad_inputs = [
- serialize_torch_tensor(result, proto.compression, allow_inplace=True)
- for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
- ]
- # Split the serialized_grad_inputs for streaming and respond
- output_split = [
- part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
- ]
- async for part in as_aiter(*output_split):
- yield runtime_pb2.ExpertResponse(tensors=[part])
- def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
- """Check that the first request to rpc_inference is valid"""
- uids = (request.uid or "").split(CHAIN_DELIMITER)
- if not uids:
- raise RuntimeError("User did not provide any uids")
- for uid in uids:
- if uid not in self.module_backends:
- raise RuntimeError(f"Remote peer does not serve {uid}")
- return tuple(uids)
- def _check_header_str(self, header) -> Sequence[ModuleUID]:
- """Check that the first request to rpc_inference is valid"""
- uids = (header or "").split(CHAIN_DELIMITER)
- if not uids:
- raise RuntimeError("User did not provide any uids")
- for uid in uids:
- if uid not in self.module_backends:
- raise RuntimeError(f"Remote peer does not serve {uid}")
- return tuple(uids)
- @contextlib.asynccontextmanager
- async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> Sequence[int]:
- """Allocate memory caches for each transformer block, return cache handles"""
- async with contextlib.AsyncExitStack() as stack:
- handles = []
- for backend in backends:
- num_heads = backend.module.self_attention.num_heads
- head_dim = backend.module.self_attention.head_dim
- cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
- # [key_or_value, batch_size, max_length, num_heads, head_dim]
- handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
- yield handles
|