@@ -7,15 +7,13 @@ import uuid
 from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple
 
 import torch
-from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2P
 from hivemind.proto import runtime_pb2
-from hivemind.utils.tensor_descr import BatchTensorDescriptor
+from hivemind.utils import MSGPackSerializer, anext, get_logger, nested_flatten
 
-from petals.client.config import ClientConfig
 from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
-from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
+from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
 from petals.utils.packaging import pack_args_kwargs
@@ -32,23 +30,21 @@ class _ServerInferenceSession:
 
     def __init__(
         self,
-        config: ClientConfig,
+        sequence_manager: RemoteSequenceManager,
         span: RemoteSpanInfo,
        span_uids: Sequence[ModuleUID],
-        rpc_info: RPCInfo,
         inputs_queue: asyncio.Queue,
-        outputs_aiter: AsyncIterator,
-        *,
+        outputs_stream: AsyncIterator,
+        *block_kwargs,
         max_length: int,
-        **metadata,
     ):
-        self.config = config
-        self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
+        self.sequence_manager = sequence_manager
+        self.span, self.span_uids = span, span_uids
         self.num_blocks = len(span_uids)
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
-        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_stream
         self.session_id = str(uuid.uuid4())
-        self.session_metadata = dict(max_length=max_length, **metadata)
+        self.max_length = max_length
         self.stepped = False
         self.closed = False
 
@@ -56,24 +52,26 @@ class _ServerInferenceSession:
         self.history = None  # Used in case of server failures to regenerate attention caches on new servers
         self.next_session = None
 
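+        # optional per-block kwargs: either one dict per block in this span or nothing at all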
+        self.block_kwargs = block_kwargs
+        assert len(self.block_kwargs) in (0, self.num_blocks)
+
     @classmethod
     async def create(
         cls,
-        config: ClientConfig,
-        p2p: P2P,
+        sequence_manager: RemoteSequenceManager,
         span: RemoteSpanInfo,
-        span_uids: Sequence[RemoteSpanInfo],
-        rpc_info: RPCInfo,
-        **metadata,
+        span_uids: Sequence[ModuleUID],
+        *block_kwargs: Dict[str, Any],
+        **kwargs,
     ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
-        stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
+        stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
         inputs_queue = asyncio.Queue()
         outputs_stream = await asyncio.wait_for(
             stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
-            config.connect_timeout,
+            sequence_manager.config.connect_timeout,
         )
-        return cls(config, span, span_uids, rpc_info, inputs_queue, outputs_stream, **metadata)
+        return cls(sequence_manager, span, span_uids, inputs_queue, outputs_stream, *block_kwargs, **kwargs)
 
     @staticmethod
     async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@@ -87,7 +85,7 @@ class _ServerInferenceSession:
         self,
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
-        *block_kwargs: Dict[str, Any],
+        *,
         hypo_ids: Optional[torch.Tensor] = None,
         step_id: str,
     ) -> torch.Tensor:
@@ -96,7 +94,6 @@ class _ServerInferenceSession:
         :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
             if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
         """
-        # TODO record previous kwargs in case of server failure!!!
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
@@ -112,10 +109,11 @@ class _ServerInferenceSession:
 
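+        # per-block kwargs are sent to the server only with the first request of a session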
         if not self.stepped:
             inputs = self.history  # Pass full inputs including prefix
+            block_kwargs = self.block_kwargs
         else:
             inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
+            block_kwargs = []
 
-        assert len(block_kwargs) in (0, self.span.length)
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
@@ -131,39 +129,50 @@ class _ServerInferenceSession:
             assert len(hypo_ids) == len(inputs)
             assert hypo_ids.dtype == torch.int64
 
-        # serialize inputs and put them into the queue
-        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
-
-        request_metadata = dict(session_id=self.session_id, step_id=step_id)
-        if not self.stepped:
-            request_metadata.update(self.session_metadata)
-        elif self.config.use_server_to_server:
+        metadata = dict(session_id=self.session_id, step_id=step_id, max_length=self.max_length)
+        metadata.update(
+            self.sequence_manager.get_request_metadata(
+                self.span.peer_id,
+                "rpc_inference",
+                self.span_uids,
+                inputs,
+                prompts,
+                *block_kwargs,
+                max_length=self.max_length,
+                session_id=self.session_id,
+                step_id=step_id,
+            )
+        )
+        if self.stepped and self.sequence_manager.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
-                request_metadata["next_servers"] = next_servers
+                metadata["next_servers"] = next_servers
 
-        args_structure = request_metadata.setdefault("args_structure", args_structure)
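+        # ask the sequence manager which compression to use for each tensor (None = no preference)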
+        codecs = self.sequence_manager.get_compression_codecs(
+            self.span.peer_id, "rpc_inference", self.span_uids, inputs, prompts, *block_kwargs
+        )
 
-        # TODO YOZH FIX THIS BEFORE THE END OF THIS PR
-        # TODO: make possible to use different compression method for different tensors
-        server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
-        compression = server_side_inference_schema[0].compression
-        inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)
+        # serialize inputs and put them into the queue
+        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
+        args_structure = metadata.setdefault("args_structure", args_structure)
 
-        # TODO: create more explicit way to check servers schema and client's structure
-        assert len(input_tensors) >= len(
-            server_side_inference_schema
-        ), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step"
+        if codecs is None:
+            codecs = [runtime_pb2.CompressionType.NONE] * len(input_tensors)
+        else:
+            codecs = list(nested_flatten(codecs))
+            assert len(codecs) == len(
+                input_tensors
+            ), f"got {len(input_tensors)} tensors but {len(codecs)} compression codecs"
 
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=CHAIN_DELIMITER.join(self.span_uids),
                     tensors=[
-                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(input_tensors, inference_schema)
+                        serialize_torch_tensor(tensor, compression)
+                        for tensor, compression in zip(input_tensors, codecs)
                     ],
-                    metadata=MSGPackSerializer.dumps(request_metadata),
+                    metadata=MSGPackSerializer.dumps(metadata),
                 )
             )
         )
@@ -190,7 +199,7 @@ class _ServerInferenceSession:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
         await self._inputs_queue.put(inputs_serialized)
         self.stepped = True
-        return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout)
+        return await asyncio.wait_for(anext(self._outputs_stream), self.sequence_manager.config.request_timeout)
 
     def close(self):
         """Finish a given inference session, close the underlying connection"""
@@ -227,7 +236,7 @@ class InferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
+    def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int, *block_kwargs: Dict[str, Any]):
         self._sequence_manager = sequence_manager
         self._closed = False
         self._server_sessions = []
@@ -235,6 +244,12 @@ class InferenceSession:
         self._max_length = max_length
         self.output_ids = None
 
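+        # accept no per-block kwargs, a single dict shared by all blocks, or one dict per block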
+        num_blocks = len(self._sequence_manager)
+        if len(block_kwargs) == 1:
+            block_kwargs = block_kwargs * num_blocks
+        assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
+        self.block_kwargs = block_kwargs
+
     @property
     def num_blocks(self) -> int:
         return len(self._sequence_manager)
@@ -247,17 +262,13 @@ class InferenceSession:
         server_sessions = []
         try:
             for span in chosen_spans:
-                span_uids = self._sequence_manager.block_uids[span.start : span.end]
-                metadata = self._sequence_manager.get_request_metadata(span.peer_id, "rpc_inference", span_uids)
                 session = RemoteExpertWorker.run_coroutine(
                     _ServerInferenceSession.create(
-                        self._sequence_manager.config,
-                        self._sequence_manager.state.p2p,
+                        self._sequence_manager,
                         span,
-                        span_uids,
-                        rpc_info=self._sequence_manager.rpc_info,  # TODO not actually needed
+                        self._sequence_manager.block_uids[span.start : span.end],
+                        *self.block_kwargs[span.start : span.end],
                         max_length=self._max_length,
-                        **metadata,
                     )
                 )
                 server_sessions.append(session)
@@ -282,18 +293,13 @@ class InferenceSession:
         self,
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
-        *block_kwargs: Sequence[Dict[str, torch.Tensor]],
         hypo_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
 
-        num_blocks = len(self._sequence_manager)
-        if len(block_kwargs) == 1:
-            block_kwargs = block_kwargs * num_blocks
-        assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
-
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
@@ -326,9 +332,8 @@ class InferenceSession:
                     inputs = server_session.step(
                         inputs,
                         prompts[server_session.span.start : server_session.span.end],
-                        *block_kwargs[server_session.span.start : server_session.span.end],
-                        step_id=step_id,
                         hypo_ids=hypo_ids,
+                        step_id=step_id,
                     )
 
                     server_idx += 1
@@ -354,7 +359,7 @@ class InferenceSession:
         outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
         return outputs
 
-    def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:
+    def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int):
         # If there is a failed server session, this code closes it
         self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
 
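For orientation, here is a rough usage sketch of the reworked client-side API. It is not part of the diff: `sequence_manager`, `hidden_states`, and `next_hidden_states` are assumed to already exist, the kwarg name inside the per-block dicts is purely hypothetical, and it assumes `InferenceSession` keeps its context-manager interface.

```python
import torch

from petals.client.inference_session import InferenceSession

# Assumed to exist: sequence_manager (a RemoteSequenceManager), hidden_states, next_hidden_states.
# One kwargs dict per remote block; a single dict would be broadcast to every block instead.
block_kwargs = [{"example_kwarg": torch.ones(1)} for _ in range(len(sequence_manager))]

with InferenceSession(sequence_manager, 128, *block_kwargs) as sess:  # 128 == max_length
    outputs = sess.step(hidden_states)        # per-block kwargs travel with this first request
    outputs = sess.step(next_hidden_states)   # later steps send only the new tokens, no kwargs
```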