|
@@ -1,5 +1,5 @@
|
|
import contextlib
|
|
import contextlib
|
|
-from typing import AsyncIterator, Dict, List, Optional, Sequence, Union
|
|
|
|
|
|
+from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
|
|
|
|
|
|
import torch
|
|
import torch
|
|
from hivemind import (
|
|
from hivemind import (
|
|
@@ -7,6 +7,7 @@ from hivemind import (
|
|
MSGPackSerializer,
|
|
MSGPackSerializer,
|
|
P2PContext,
|
|
P2PContext,
|
|
TensorDescriptor,
|
|
TensorDescriptor,
|
|
|
|
+ deserialize_tensor_stream,
|
|
deserialize_torch_tensor,
|
|
deserialize_torch_tensor,
|
|
nested_flatten,
|
|
nested_flatten,
|
|
serialize_torch_tensor,
|
|
serialize_torch_tensor,
|
|
@@ -14,12 +15,13 @@ from hivemind import (
|
|
from hivemind.moe.server.connection_handler import ConnectionHandler
|
|
from hivemind.moe.server.connection_handler import ConnectionHandler
|
|
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
|
|
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
|
|
from hivemind.proto import runtime_pb2
|
|
from hivemind.proto import runtime_pb2
|
|
-from hivemind.utils import as_aiter
|
|
|
|
-from hivemind.utils.asyncio import anext
|
|
|
|
|
|
+from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
|
|
from hivemind.utils.streaming import split_for_streaming
|
|
from hivemind.utils.streaming import split_for_streaming
|
|
|
|
|
|
from src.data_structures import CHAIN_DELIMITER, ModuleUID
|
|
from src.data_structures import CHAIN_DELIMITER, ModuleUID
|
|
from src.server.backend import TransformerBackend
|
|
from src.server.backend import TransformerBackend
|
|
|
|
+from src.server.task_pool import PrioritizedTaskPool
|
|
|
|
+from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
|
|
from src.utils.misc import DUMMY, is_dummy
|
|
from src.utils.misc import DUMMY, is_dummy
|
|
|
|
|
|
|
|
|
|
@@ -28,11 +30,41 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
|
module_backends: Dict[ModuleUID, TransformerBackend]
|
|
module_backends: Dict[ModuleUID, TransformerBackend]
|
|
|
|
|
|
- def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int):
|
|
|
|
|
|
+ def __init__(
|
|
|
|
+ self,
|
|
|
|
+ dht: DHT,
|
|
|
|
+ module_backends: Dict[str, TransformerBackend],
|
|
|
|
+ inference_max_length: int,
|
|
|
|
+ task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
|
|
|
|
+ ):
|
|
super().__init__(dht, module_backends)
|
|
super().__init__(dht, module_backends)
|
|
for module_backend in self.module_backends.values():
|
|
for module_backend in self.module_backends.values():
|
|
assert isinstance(module_backend, TransformerBackend)
|
|
assert isinstance(module_backend, TransformerBackend)
|
|
self.inference_max_length = inference_max_length
|
|
self.inference_max_length = inference_max_length
|
|
|
|
+ self._prioritizer = task_prioritizer
|
|
|
|
+
|
|
|
|
+ async def _gather_inputs(
|
|
|
|
+ self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
|
|
|
+ ) -> Tuple[str, List[torch.Tensor], Dict]:
|
|
|
|
+ block_uid, metadata = None, None
|
|
|
|
+
|
|
|
|
+ def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
|
|
|
|
+ nonlocal block_uid, metadata
|
|
|
|
+
|
|
|
|
+ if block_uid is None:
|
|
|
|
+ block_uid = req.uid
|
|
|
|
+ elif block_uid != req.uid:
|
|
|
|
+ raise ValueError("Block uids differ in one request")
|
|
|
|
+
|
|
|
|
+ if metadata is None:
|
|
|
|
+ metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {}
|
|
|
|
+
|
|
|
|
+ return req.tensors
|
|
|
|
+
|
|
|
|
+ tensors_stream = amap_in_executor(_unpack, requests)
|
|
|
|
+ inputs = await deserialize_tensor_stream(tensors_stream)
|
|
|
|
+ assert isinstance(block_uid, str) and isinstance(metadata, dict)
|
|
|
|
+ return block_uid, inputs, metadata
|
|
|
|
|
|
async def rpc_inference(
|
|
async def rpc_inference(
|
|
self,
|
|
self,
|
|
@@ -47,13 +79,18 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
max_length = metadata.get("max_length")
|
|
max_length = metadata.get("max_length")
|
|
|
|
+ points = metadata.get("points", 0)
|
|
|
|
|
|
if not requested_uids:
|
|
if not requested_uids:
|
|
raise ValueError("User must specify at least one block for inference, but got none")
|
|
raise ValueError("User must specify at least one block for inference, but got none")
|
|
assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
|
|
assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
|
|
|
|
+ assert isinstance(
|
|
|
|
+ points, (float, int)
|
|
|
|
+ ), f"rpc_inference should have number of points as a number or None, got {points}"
|
|
if not 0 <= max_length <= self.inference_max_length:
|
|
if not 0 <= max_length <= self.inference_max_length:
|
|
raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
|
|
raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
|
|
|
|
|
|
|
|
+ point_per_piece = points / max_length if max_length > 0 else 0.0
|
|
batch_size = request.tensors[0].size[0] if request.tensors else 1
|
|
batch_size = request.tensors[0].size[0] if request.tensors else 1
|
|
|
|
|
|
cache_metadata = torch.tensor(
|
|
cache_metadata = torch.tensor(
|
|
@@ -98,8 +135,19 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
assert (
|
|
assert (
|
|
hidden_states.ndim == 3
|
|
hidden_states.ndim == 3
|
|
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
|
|
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
|
|
|
|
+ assert isinstance(
|
|
|
|
+ backend.inference_pool, PrioritizedTaskPool
|
|
|
|
+ ), "petals support only prioritized pools"
|
|
|
|
+ priority = self._prioritizer.prioritize(
|
|
|
|
+ cache_metadata,
|
|
|
|
+ hidden_states,
|
|
|
|
+ hypo_ids,
|
|
|
|
+ points=point_per_piece / len(requested_backends),
|
|
|
|
+ backend=backend,
|
|
|
|
+ type="inference",
|
|
|
|
+ )
|
|
(hidden_states,) = await backend.inference_pool.submit_task(
|
|
(hidden_states,) = await backend.inference_pool.submit_task(
|
|
- cache_metadata, hidden_states, hypo_ids
|
|
|
|
|
|
+ cache_metadata, hidden_states, hypo_ids, priority=priority
|
|
)
|
|
)
|
|
|
|
|
|
# serialize and send last layer outputs
|
|
# serialize and send last layer outputs
|
|
@@ -123,8 +171,15 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
|
|
flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
|
|
requested_uids = self._check_uids(request.uid)
|
|
requested_uids = self._check_uids(request.uid)
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
-
|
|
|
|
- hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
|
|
|
|
|
|
+ metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
|
+ points = metadata.get("points", 0)
|
|
|
|
+ assert isinstance(
|
|
|
|
+ points, (float, int)
|
|
|
|
+ ), f"rpc_forward should have number of points as number or None, got {points}"
|
|
|
|
+
|
|
|
|
+ hidden_states = await _rpc_forward(
|
|
|
|
+ *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
|
|
|
|
+ )
|
|
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
|
|
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
|
|
|
|
|
|
# Serialize output and respond to client
|
|
# Serialize output and respond to client
|
|
@@ -139,11 +194,17 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
|
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
|
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
|
|
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
|
|
# Parse requests and prepare backends
|
|
# Parse requests and prepare backends
|
|
- uid_str, flat_inputs = await self._gather_inputs(requests, context)
|
|
|
|
|
|
+ uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
|
|
requested_uids = self._check_uids(uid_str)
|
|
requested_uids = self._check_uids(uid_str)
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
|
+ points = metadata.get("points", 0)
|
|
|
|
+ assert isinstance(
|
|
|
|
+ points, (float, int)
|
|
|
|
+ ), f"rpc_forward_stream should have number of points as number or None, got {points}"
|
|
|
|
|
|
- hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
|
|
|
|
|
|
+ hidden_states = await _rpc_forward(
|
|
|
|
+ *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
|
|
|
|
+ )
|
|
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
|
|
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
|
|
|
|
|
|
# Serialize the overall output
|
|
# Serialize the overall output
|
|
@@ -164,8 +225,15 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
|
|
flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
|
|
requested_uids = self._check_uids(request.uid)
|
|
requested_uids = self._check_uids(request.uid)
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
-
|
|
|
|
- grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
|
|
|
|
|
|
+ metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
|
+ points = metadata.get("points", 0)
|
|
|
|
+ assert isinstance(
|
|
|
|
+ points, (float, int)
|
|
|
|
+ ), f"rpc_backward should have number of points as number or None, got {points}"
|
|
|
|
+
|
|
|
|
+ grads = await _rpc_backward(
|
|
|
|
+ *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
|
|
|
|
+ )
|
|
|
|
|
|
# Modify grad_inputs_schema to support grad_prompts
|
|
# Modify grad_inputs_schema to support grad_prompts
|
|
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
|
|
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
|
|
@@ -187,11 +255,17 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
|
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
|
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
|
|
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
|
|
|
|
|
|
- uids_header, flat_tensors = await self._gather_inputs(requests, context)
|
|
|
|
|
|
+ uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
|
|
requested_uids = self._check_uids(uids_header)
|
|
requested_uids = self._check_uids(uids_header)
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
|
+ points = metadata.get("points", 0)
|
|
|
|
+ assert isinstance(
|
|
|
|
+ points, (float, int)
|
|
|
|
+ ), f"rpc_backward_stream should have number of points as number or None, got {points}"
|
|
|
|
|
|
- grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
|
|
|
|
|
|
+ grads = await _rpc_backward(
|
|
|
|
+ *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
|
|
|
|
+ )
|
|
|
|
|
|
# Modify grad_inputs_schema to support grad_prompts
|
|
# Modify grad_inputs_schema to support grad_prompts
|
|
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
|
|
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
|
|
@@ -244,7 +318,12 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
yield handles
|
|
yield handles
|
|
|
|
|
|
|
|
|
|
-async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor:
|
|
|
|
|
|
+async def _rpc_forward(
|
|
|
|
+ *flat_tensors: torch.Tensor,
|
|
|
|
+ requested_backends: Sequence[TransformerBackend],
|
|
|
|
+ prioritizer: TaskPrioritizerBase,
|
|
|
|
+ points: int = 0,
|
|
|
|
+) -> torch.Tensor:
|
|
"""
|
|
"""
|
|
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
|
|
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
|
|
|
|
|
|
@@ -267,7 +346,15 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
|
|
for backend, prompt in zip(requested_backends, prompts):
|
|
for backend, prompt in zip(requested_backends, prompts):
|
|
if not is_dummy(prompt):
|
|
if not is_dummy(prompt):
|
|
hidden_states[:, : prompt.shape[1]] += prompt
|
|
hidden_states[:, : prompt.shape[1]] += prompt
|
|
- (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
|
|
|
|
|
|
+
|
|
|
|
+ assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
|
|
|
|
+ priority = prioritizer.prioritize(
|
|
|
|
+ hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
|
|
|
|
+ )
|
|
|
|
+ (hidden_states,) = await backend.forward_pool.submit_task(
|
|
|
|
+ hidden_states,
|
|
|
|
+ priority=priority,
|
|
|
|
+ )
|
|
assert isinstance(hidden_states, torch.Tensor)
|
|
assert isinstance(hidden_states, torch.Tensor)
|
|
assert (
|
|
assert (
|
|
hidden_states.ndim == 3
|
|
hidden_states.ndim == 3
|
|
@@ -278,7 +365,10 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
|
|
|
|
|
|
|
|
|
|
async def _rpc_backward(
|
|
async def _rpc_backward(
|
|
- *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
|
|
|
|
|
|
+ *flat_tensors: torch.Tensor,
|
|
|
|
+ requested_backends: Sequence[TransformerBackend],
|
|
|
|
+ prioritizer: TaskPrioritizerBase,
|
|
|
|
+ points: int = 0,
|
|
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
|
|
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
|
|
inputs, grad_outputs, prompts = flat_tensors
|
|
inputs, grad_outputs, prompts = flat_tensors
|
|
# Cast inputs & grad outputs to backend dtype
|
|
# Cast inputs & grad outputs to backend dtype
|
|
@@ -298,7 +388,12 @@ async def _rpc_backward(
|
|
if not is_dummy(prompt):
|
|
if not is_dummy(prompt):
|
|
inputs[:, : prompt.shape[1]] += prompt
|
|
inputs[:, : prompt.shape[1]] += prompt
|
|
inter_inputs.append(inputs)
|
|
inter_inputs.append(inputs)
|
|
- (inputs,) = await backend.forward_pool.submit_task(inputs)
|
|
|
|
|
|
+ assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
|
|
|
|
+ priority = prioritizer.prioritize(
|
|
|
|
+ inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
|
|
|
|
+ )
|
|
|
|
+ (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
|
|
|
|
+
|
|
assert isinstance(inputs, torch.Tensor)
|
|
assert isinstance(inputs, torch.Tensor)
|
|
|
|
|
|
if not is_dummy(prompts[-1]):
|
|
if not is_dummy(prompts[-1]):
|
|
@@ -309,7 +404,12 @@ async def _rpc_backward(
|
|
grad_prompts_reversed = []
|
|
grad_prompts_reversed = []
|
|
# Run a chain of requested backends
|
|
# Run a chain of requested backends
|
|
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
|
|
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
|
|
- (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
|
|
|
|
|
|
+ assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
|
|
|
|
+ priority = prioritizer.prioritize(
|
|
|
|
+ inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
|
|
|
|
+ )
|
|
|
|
+ (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
|
|
|
|
+
|
|
assert isinstance(grad_outputs, torch.Tensor)
|
|
assert isinstance(grad_outputs, torch.Tensor)
|
|
if not is_dummy(prompt):
|
|
if not is_dummy(prompt):
|
|
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
|
|
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
|