{ "cells": [ { "cell_type": "code", "execution_count": 9, "id": "21e78d30", "metadata": {}, "outputs": [], "source": [ "import asyncio\n", "from typing import Sequence, Tuple, Iterable, List\n", "from tqdm.auto import trange\n", "\n", "import torch\n", "import hivemind\n", "import petals\n", "\n", "from petals.server.handler import TransformerConnectionHandler, split_for_streaming\n", "from petals.client import RemoteSequenceManager, ClientConfig\n", "from petals.client.remote_forward_backward import DEFAULT_MAX_MSG_SIZE, iter_as_aiter, aiter_with_timeout, deserialize_tensor_stream\n", "from petals.data_structures import ModuleUID, PeerID, CHAIN_DELIMITER, UID_DELIMITER\n", "from petals.utils.packaging import pack_args_kwargs\n", "\n", "from hivemind.compression import serialize_torch_tensor\n", "from hivemind.utils import MSGPackSerializer\n", "from hivemind.proto import runtime_pb2\n", "\n", "_END_OF_STREAM_KEY = \"_EOS\"\n", "\n", "\n", "async def _run_forward_part(\n", " uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, metadata: dict\n", ") -> List[torch.Tensor]:\n", " \"\"\"Send (serialized) inputs to run forward pass as per rpc_forward_backward; return model outputs\"\"\"\n", " assert _END_OF_STREAM_KEY not in metadata\n", " parts = [tensor_part for tensor in serialized_tensors\n", " for tensor_part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)]\n", " if len(parts) > 1:\n", " serialized_metadata = MSGPackSerializer.dumps(metadata)\n", " serialized_metadata_last_piece = MSGPackSerializer.dumps(dict(metadata, **{_END_OF_STREAM_KEY: True}))\n", " print(MSGPackSerializer.loads(serialized_metadata_last_piece))\n", " \n", " requests = [\n", " runtime_pb2.ExpertRequest(\n", " uid=uid, tensors=[tensor_part], \n", " metadata=serialized_metadata if i != len(parts) - 1 else serialized_metadata_last_piece)\n", " for tensor_part in parts\n", " ]\n", " \n", " print(\"CALLING stub.rpc_forward_stream on serialized inputs\")\n", " outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(requests)), config.connect_timeout)\n", " outputs = aiter_with_timeout(outputs, config.request_timeout)\n", " return await deserialize_tensor_stream(msg.tensors async for msg in outputs)\n", "\n", "\n", "\n", "async def run_remote_forward_backward(\n", " sequence_manager: RemoteSequenceManager,\n", " peer_id: PeerID,\n", " span_uids: Sequence[ModuleUID],\n", " *args: torch.Tensor,\n", " **kwargs: torch.Tensor,\n", ") -> Tuple[torch.Tensor, ...]:\n", " \"\"\"\n", " Serializes input tensors and calls \"rpc_forward_backward\" on a remote server.\n", " Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198\n", " but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.\n", " \"\"\"\n", " merged_uid = CHAIN_DELIMITER.join(span_uids)\n", " stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)\n", " metadata = sequence_manager.get_request_metadata(peer_id, \"rpc_forward\", span_uids, *args, **kwargs)\n", " codecs = sequence_manager.get_compression_codecs(peer_id, \"rpc_forward\", span_uids, *args, **kwargs)\n", " flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)\n", " flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors)\n", " args_structure = metadata.setdefault(\"args_structure\", args_structure)\n", " if codecs is None:\n", " codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)\n", " else:\n", " codecs = list(nested_flatten(codecs))\n", " assert len(codecs) == len(flat_tensors), f\"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs\"\n", "\n", " # Asynchronous serialization\n", " loop = asyncio.get_running_loop()\n", " serialized_tensors = await asyncio.gather(\n", " *(\n", " loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)\n", " for tensor, compression in zip(flat_tensors, codecs)\n", " )\n", " )\n", "\n", " # call RPC on remote server\n", " size = sum(t.element_size() * t.nelement() for t in flat_tensors)\n", " # Hotfix: we use \"// 2\" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR\n", " output_tensors = await _run_forward_part(\n", " merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=metadata\n", " )\n", " # backward compatibility: ensure requires_grad; remove after https://github.com/learning-at-home/hivemind/pull/591\n", " requires_grad = any(tensor.requires_grad for tensor in flat_tensors)\n", " output_tensors = [tensor.requires_grad_(requires_grad) for tensor in output_tensors]\n", " return output_tensors\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "1c47c89a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Mar 09 21:02:25.899 [\u001b[1m\u001b[34mINFO\u001b[0m] Make sure you follow the LLaMA's terms of use: https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1\n", "Mar 09 21:02:25.903 [\u001b[1m\u001b[34mINFO\u001b[0m] Using DHT prefix: TinyLLama-v0-hf\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "43eaa415b3fa4592a7200277ec4c1f47", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10 [00:00