@@ -6,7 +6,7 @@ import multiprocessing as mp
 import sys
 from enum import Enum
 from itertools import chain
-from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple
 
 import torch
 from async_timeout import timeout
@@ -29,12 +29,11 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
 import petals
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID
 from petals.server.backend import TransformerBackend
+from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
 from petals.server.memory_cache import Handle
-from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
-from petals.utils.misc import DUMMY, is_dummy
 
 logger = get_logger(__name__)
 
@@ -147,7 +146,6 @@ class TransformerConnectionHandler(ConnectionHandler):
                 metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
                 requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
                 max_length = metadata.get("max_length")
-                active_adapter = self._get_active_adapter(metadata)
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
                 if not requested_uids:
@@ -163,78 +161,28 @@ class TransformerConnectionHandler(ConnectionHandler):
                         f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}"
                     )
 
-                point_per_piece = points / max_length if max_length > 0 else 0.0
                 batch_size = request.tensors[0].size[0] if request.tensors else 1
-                prefix_length = 0
 
                 async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
-                    assert len(cache_handles) == len(requested_backends)
-                    first_request = request
                     background_tasks = set()
-                    async for request, metadata in self._iterate_inference_steps(
-                        first_request, requests, session_id, requested_uids, context
+                    async for output_tensors, can_push in iterate_rpc_inference(
+                        requested_uids=requested_uids,
+                        requested_backends=requested_backends,
+                        active_adapter=self._get_active_adapter(metadata),
+                        input_iterator=self._iterate_inference_steps(
+                            request, requests, session_id, requested_uids, context
+                        ),
+                        cache_handles=cache_handles,
+                        max_length=max_length,
+                        prioritizer=self._prioritizer,
+                        points=points,
                     ):
-                        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
-
-                        # Cast inputs to backend dtype
-                        hidden_states = hidden_states.to(requested_backends[0].dtype)
-                        assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
-
-                        # parse deep prompts (optional argument)
-                        has_prompts = prompts is not None and not is_dummy(prompts)
-                        if not has_prompts:
-                            prompts = [None] * len(requested_backends)
-                        else:
-                            prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
-                            prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
-
-                        if not (len(requested_backends) == len(prompts)):
-                            raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
-
-                        length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
-                        if prefix_length + length_increment > max_length:
-                            raise ValueError(
-                                f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
-                                f" exceeds pre-allocated maximum {max_length}"
-                            )
-
-                        priority = self._prioritizer.prioritize(
-                            hidden_states,
-                            hypo_ids,
-                            points=point_per_piece,
-                            requested_uids=requested_uids,
-                            type="inference",
-                        )
-
-                        inference_infos = tuple(
-                            InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
-                            for uid, handles in zip(requested_uids, cache_handles)
-                        )
-
-                        if hidden_states.numel() == 0:
-                            pass  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
-                            # when user wants to pre-allocate cache or check that server *can* allocate that cache
-                        else:
-                            assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
-                            (hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task(
-                                hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
-                            )
-
-                        # serialize and send last layer outputs
-                        output_tensors = [
-                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                            for result, proto in zip(
-                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
-                            )
-                        ]
-                        if not has_prompts:
+                        if can_push:
                             task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
                             background_tasks.add(task)  # Keep reference until it is done to save it from GC
                             task.add_done_callback(background_tasks.discard)
                         yield runtime_pb2.ExpertResponse(tensors=output_tensors)
 
-                        # prepare for next step
-                        prefix_length += length_increment
             finally:
                 self._log_request("rpc_inference.close", requested_uids, context)
 
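Note: the refactored loop above delegates all per-step work to iterate_rpc_inference, which this diff imports from the new petals.server.block_functions module but does not show. Below is a hedged sketch of the interface that helper appears to expose, inferred solely from the keyword arguments at the call site and from the inline code being removed; the parameter types and the stub body are assumptions, and block_functions.py remains the source of truth.

# Illustrative stub only: the signature is inferred from the call site in this diff,
# not copied from petals/server/block_functions.py.
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple

from hivemind.proto import runtime_pb2

from petals.data_structures import ModuleUID
from petals.server.backend import TransformerBackend
from petals.server.memory_cache import Handle
from petals.server.task_prioritizer import TaskPrioritizerBase


async def iterate_rpc_inference(
    *,
    requested_uids: Sequence[ModuleUID],
    requested_backends: Sequence[TransformerBackend],
    active_adapter: Optional[str],
    input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, Dict[str, Any]]],
    cache_handles: Sequence[Sequence[Handle]],
    max_length: int,
    prioritizer: TaskPrioritizerBase,
    points: int,
) -> AsyncIterator[Tuple[List[runtime_pb2.Tensor], bool]]:
    """
    Run one inference step per incoming request and yield (output_tensors, can_push).
    Judging by the removed code, can_push plays the role of the old `not has_prompts`
    check that decided whether outputs may be pushed to the next server.
    """
    raise NotImplementedError("illustrative stub; see petals/server/block_functions.py")
    yield  # unreachable; marks this stub as an async generator

The bookkeeping the deleted code did inline (prefix_length tracking, per-step prioritization, InferenceMetadata construction) is presumably handled inside this generator now.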
@@ -408,7 +356,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_forward should have number of points as number or None, got {points}"
 
-            hidden_states = await _rpc_forward(
+            hidden_states = await run_rpc_forward(
                 *flat_inputs,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -435,7 +383,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_forward_stream should have number of points as number or None, got {points}"
 
-            hidden_states = await _rpc_forward(
+            hidden_states = await run_rpc_forward(
                 *flat_inputs,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -486,7 +434,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_backward should have number of points as number or None, got {points}"
 
-            grads = await _rpc_backward(
+            grads = await run_rpc_backward(
                 *flat_tensors,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -511,7 +459,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_backward_stream should have number of points as number or None, got {points}"
 
-            grads = await _rpc_backward(
+            grads = await run_rpc_backward(
                 *flat_tensors,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -621,105 +569,3 @@ class TransformerConnectionHandler(ConnectionHandler):
             result.update(block_info)
 
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
-
-
-async def _rpc_forward(
-    *flat_tensors: torch.Tensor,
-    requested_backends: Sequence[TransformerBackend],
-    active_adapter: str = "",
-    prioritizer: TaskPrioritizerBase,
-    points: int = 0,
-) -> torch.Tensor:
-    """
-    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
-
-    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
-    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
-    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
-    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
-    """
-    hidden_states, prompts = flat_tensors
-    dtype = requested_backends[0].dtype
-    # check parse input tensors and cast dtypes
-    hidden_states = hidden_states.to(dtype)
-    assert hidden_states.ndim == 3
-    if prompts is None or is_dummy(prompts):
-        prompts = [DUMMY] * len(requested_backends)
-    else:
-        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
-
-    # Run a chain of requested backends
-    for backend, prompt in zip(requested_backends, prompts):
-        if not is_dummy(prompt):
-            hidden_states[:, : prompt.shape[1]] += prompt
-
-        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
-        priority = prioritizer.prioritize(
-            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
-        )
-        (hidden_states,) = await backend.forward_pool.submit_task(
-            hidden_states,
-            active_adapter,
-            priority=priority,
-        )
-        assert isinstance(hidden_states, torch.Tensor)
-        assert (
-            hidden_states.ndim == 3
-        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-
-    return hidden_states
-
-
-async def _rpc_backward(
-    *flat_tensors: torch.Tensor,
-    requested_backends: Sequence[TransformerBackend],
-    active_adapter: str = "",
-    prioritizer: TaskPrioritizerBase,
-    points: int = 0,
-) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, prompts = flat_tensors
-    # Cast inputs & grad outputs to backend dtype
-    inputs = inputs.to(requested_backends[0].dtype)
-    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
-
-    if prompts is None or is_dummy(prompts):
-        prompts = [DUMMY] * len(requested_backends)
-    else:
-        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
-
-    # Run a forward chain to collect intermediate inputs
-    # Note that we do not forward for the last module since we do not need its output
-    inter_inputs = []
-    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
-        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-        if not is_dummy(prompt):
-            inputs[:, : prompt.shape[1]] += prompt
-        inter_inputs.append(inputs)
-        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
-        priority = prioritizer.prioritize(
-            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
-        )
-        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
-
-    assert isinstance(inputs, torch.Tensor)
-
-    if not is_dummy(prompts[-1]):
-        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
-    inter_inputs.append(inputs)
-
-    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
-    grad_prompts_reversed = []
-    # Run a chain of requested backends
-    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
-        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
-        priority = prioritizer.prioritize(
-            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
-        )
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
-
-        assert isinstance(grad_outputs, torch.Tensor)
-        if not is_dummy(prompt):
-            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
-
-    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
-    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
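For completeness: per the new import at the top of this diff, the two module-level helpers deleted here now live in petals/server/block_functions.py as run_rpc_forward and run_rpc_backward. Below is a hedged sketch of the signatures they presumably keep, mirroring the removed _rpc_forward/_rpc_backward and the arguments passed at the updated call sites; the bodies are stubs and the real definitions may differ.

# Illustrative stubs only: signatures mirrored from the removed _rpc_forward/_rpc_backward;
# petals/server/block_functions.py holds the actual implementations.
from typing import Sequence, Union

import torch

from petals.server.backend import TransformerBackend
from petals.server.task_prioritizer import TaskPrioritizerBase


async def run_rpc_forward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> torch.Tensor:
    """Run hidden states (plus optional prompts) through a chain of backends, as _rpc_forward did."""
    raise NotImplementedError("illustrative stub")


async def run_rpc_backward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    """Backward over the same chain; returns input grads and, when prompts are used, prompt grads."""
    raise NotImplementedError("illustrative stub")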