@@ -0,0 +1,233 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "21e78d30",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import asyncio\n",
+    "from typing import Sequence, Tuple, Iterable, List\n",
+    "from tqdm.auto import trange\n",
+    "\n",
+    "import torch\n",
+    "import hivemind\n",
+    "import petals\n",
+    "\n",
+    "from petals.server.handler import TransformerConnectionHandler, split_for_streaming\n",
+    "from petals.client import RemoteSequenceManager, ClientConfig\n",
+    "from petals.client.remote_forward_backward import DEFAULT_MAX_MSG_SIZE, iter_as_aiter, aiter_with_timeout, deserialize_tensor_stream\n",
+    "from petals.data_structures import ModuleUID, PeerID, CHAIN_DELIMITER, UID_DELIMITER\n",
+    "from petals.utils.packaging import pack_args_kwargs\n",
+    "\n",
+    "from hivemind.compression import serialize_torch_tensor\n",
+    "from hivemind.utils import MSGPackSerializer, nested_flatten\n",
+    "from hivemind.proto import runtime_pb2\n",
+    "\n",
+    "_END_OF_STREAM_KEY = \"_EOS\"\n",
+    "\n",
+    "\n",
+    "async def _run_forward_part(\n",
+    "    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, metadata: dict\n",
+    ") -> List[torch.Tensor]:\n",
+    "    \"\"\"Stream serialized inputs to the server via rpc_forward_stream; return deserialized model outputs\"\"\"\n",
+    "    assert _END_OF_STREAM_KEY not in metadata\n",
+    "    parts = [tensor_part for tensor in serialized_tensors\n",
+    "             for tensor_part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)]\n",
+    "    # every part carries the same metadata; the last part is additionally marked with _EOS\n",
+    "    # so the server knows the input stream is complete\n",
+    "    serialized_metadata = MSGPackSerializer.dumps(metadata)\n",
+    "    serialized_metadata_last_piece = MSGPackSerializer.dumps(dict(metadata, **{_END_OF_STREAM_KEY: True}))\n",
+    "    print(MSGPackSerializer.loads(serialized_metadata_last_piece))\n",
+    "\n",
+    "    requests = [\n",
+    "        runtime_pb2.ExpertRequest(\n",
+    "            uid=uid, tensors=[tensor_part],\n",
+    "            metadata=serialized_metadata if i != len(parts) - 1 else serialized_metadata_last_piece)\n",
+    "        for i, tensor_part in enumerate(parts)\n",
+    "    ]\n",
+    "\n",
+    "    print(\"CALLING stub.rpc_forward_stream on serialized inputs\")\n",
+    "    outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(requests)), config.connect_timeout)\n",
+    "    outputs = aiter_with_timeout(outputs, config.request_timeout)\n",
+    "    return await deserialize_tensor_stream(msg.tensors async for msg in outputs)\n",
+    "\n",
+    "\n",
+    "async def run_remote_forward_backward(\n",
+    "    sequence_manager: RemoteSequenceManager,\n",
+    "    peer_id: PeerID,\n",
+    "    span_uids: Sequence[ModuleUID],\n",
+    "    *args: torch.Tensor,\n",
+    "    **kwargs: torch.Tensor,\n",
+    ") -> Tuple[torch.Tensor, ...]:\n",
+    "    \"\"\"\n",
+    "    Serializes input tensors and streams them to a remote server via rpc_forward_stream.\n",
+    "    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198\n",
+    "    but without the RemoteExpertWorker.run_coroutine() call, which would deadlock here.\n",
+    "    \"\"\"\n",
+    "    merged_uid = CHAIN_DELIMITER.join(span_uids)\n",
+    "    stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)\n",
+    "    metadata = sequence_manager.get_request_metadata(peer_id, \"rpc_forward\", span_uids, *args, **kwargs)\n",
+    "    codecs = sequence_manager.get_compression_codecs(peer_id, \"rpc_forward\", span_uids, *args, **kwargs)\n",
+    "    flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)\n",
+    "    flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors)\n",
+    "    args_structure = metadata.setdefault(\"args_structure\", args_structure)\n",
+    "    if codecs is None:\n",
+    "        codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)\n",
+    "    else:\n",
+    "        codecs = list(nested_flatten(codecs))\n",
+    "        assert len(codecs) == len(flat_tensors), f\"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs\"\n",
+    "\n",
+    "    # serialize tensors in a background executor so the event loop is not blocked\n",
+    "    loop = asyncio.get_running_loop()\n",
+    "    serialized_tensors = await asyncio.gather(\n",
+    "        *(\n",
+    "            loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)\n",
+    "            for tensor, compression in zip(flat_tensors, codecs)\n",
+    "        )\n",
+    "    )\n",
+    "\n",
+    "    # call the RPC on the remote server\n",
+    "    output_tensors = await _run_forward_part(\n",
+    "        merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=metadata\n",
+    "    )\n",
+    "    # backward compatibility: ensure requires_grad; remove after https://github.com/learning-at-home/hivemind/pull/591\n",
+    "    requires_grad = any(tensor.requires_grad for tensor in flat_tensors)\n",
+    "    output_tensors = [tensor.requires_grad_(requires_grad) for tensor in output_tensors]\n",
+    "    return output_tensors\n"
+   ]
+  },
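+  {
+   "cell_type": "markdown",
+   "id": "a9f3c2d1",
+   "metadata": {},
+   "source": [
+    "The helper above splits every serialized tensor into chunks of at most `DEFAULT_MAX_MSG_SIZE` bytes before streaming them to the server. Below is a minimal, standalone sketch of that round trip (not used by the demo itself), assuming hivemind's `combine_from_streaming` and `deserialize_torch_tensor` as the receiving side:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b7e4d8f2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: round-trip a tensor through split_for_streaming, the same chunking\n",
+    "# that _run_forward_part applies before calling rpc_forward_stream.\n",
+    "from hivemind.compression import deserialize_torch_tensor\n",
+    "from hivemind.utils.streaming import combine_from_streaming\n",
+    "\n",
+    "tensor = torch.randn(64, 64)\n",
+    "serialized = serialize_torch_tensor(tensor, runtime_pb2.CompressionType.NONE)\n",
+    "\n",
+    "# use a deliberately tiny chunk size (2 KiB) so splitting actually happens\n",
+    "chunks = list(split_for_streaming(serialized, 2048))\n",
+    "print(f\"split into {len(chunks)} chunks\")\n",
+    "\n",
+    "# the receiver's counterpart: reassemble the chunks, then deserialize\n",
+    "restored = deserialize_torch_tensor(combine_from_streaming(chunks))\n",
+    "assert torch.allclose(tensor, restored)\n"
+   ]
+  },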
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "1c47c89a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Mar 09 21:02:25.899 [\u001b[1m\u001b[34mINFO\u001b[0m] Make sure you follow the LLaMA's terms of use: https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1\n",
+      "Mar 09 21:02:25.903 [\u001b[1m\u001b[34mINFO\u001b[0m] Using DHT prefix: TinyLLama-v0-hf\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "43eaa415b3fa4592a7200277ec4c1f47",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/10 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'points': 0.0, 'active_adapter': None, 'args_structure': ((b'__T0', b'__T1'), {}), '_EOS': True}\n",
+      "CALLING stub.rpc_forward_stream on serialized inputs\n",
+      "outputs: tensor([[[ 0.2471, 0.2695, -0.0234, ..., 0.3867 ...\n",
+      "{'points': 0.0, 'active_adapter': None, 'args_structure': ((b'__T0', b'__T1'), {}), '_EOS': True}\n",
+      "CALLING stub.rpc_forward_stream on serialized inputs\n",
+      "outputs: tensor([[[ 0.2471, 0.2695, -0.0234, ..., 0.3867 ...\n",
+      "{'points': 0.0, 'active_adapter': None, 'args_structure': ((b'__T0', b'__T1'), {}), '_EOS': True}\n",
+      "CALLING stub.rpc_forward_stream on serialized inputs\n",
+      "outputs: tensor([[[ 0.2471, 0.2695, -0.0234, ..., 0.3867 ...\n",
+      "{'points': 0.0, 'active_adapter': None, 'args_structure': ((b'__T0', b'__T1'), {}), '_EOS': True}\n",
+      "CALLING stub.rpc_forward_stream on serialized inputs\n",
+      "outputs: tensor([[[ 0.2471, 0.2695, -0.0234, ..., 0.3867 ...\n",
+      "{'points': 0.0, 'active_adapter': None, 'args_structure': ((b'__T0', b'__T1'), {}), '_EOS': True}\n",
+      "CALLING stub.rpc_forward_stream on serialized inputs\n",
+      "outputs: tensor([[[ 0.2471, 0.2695, -0.0234, ..., 0.3867 ...\n",
+      "{'points': 0.0, 'active_adapter': None, 'args_structure': ((b'__T0', b'__T1'), {}), '_EOS': True}\n",
+      "CALLING stub.rpc_forward_stream on serialized inputs\n",
+      "outputs: tensor([[[ 0.2471, 0.2695, -0.0234, ..., 0.3867 ...\n",
+      "{'points': 0.0, 'active_adapter': None, 'args_structure': ((b'__T0', b'__T1'), {}), '_EOS': True}\n",
+      "CALLING stub.rpc_forward_stream on serialized inputs\n",
+      "outputs: tensor([[[ 0.2471, 0.2695, -0.0234, ..., 0.3867 ...\n",
+      "{'points': 0.0, 'active_adapter': None, 'args_structure': ((b'__T0', b'__T1'), {}), '_EOS': True}\n",
+      "CALLING stub.rpc_forward_stream on serialized inputs\n",
+      "outputs: tensor([[[ 0.2471, 0.2695, -0.0234, ..., 0.3867 ...\n",
+      "{'points': 0.0, 'active_adapter': None, 'args_structure': ((b'__T0', b'__T1'), {}), '_EOS': True}\n",
+      "CALLING stub.rpc_forward_stream on serialized inputs\n",
+      "outputs: tensor([[[ 0.2471, 0.2695, -0.0234, ..., 0.3867 ...\n",
+      "{'points': 0.0, 'active_adapter': None, 'args_structure': ((b'__T0', b'__T1'), {}), '_EOS': True}\n",
+      "CALLING stub.rpc_forward_stream on serialized inputs\n",
+      "outputs: tensor([[[ 0.2471, 0.2695, -0.0234, ..., 0.3867 ...\n",
+      "It works!\n",
+      "shutting down\n"
+     ]
+    }
+   ],
+   "source": [
+    "INITIAL_PEERS = ['/ip4/172.28.126.229/tcp/33387/p2p/12D3KooWEBR2e8fGd7d8mpnt8Yc7gY4cpxfr3jM1Rvm8yMbR4rVd']\n",
+    "peer_id_string = \"12D3KooWEBR2e8fGd7d8mpnt8Yc7gY4cpxfr3jM1Rvm8yMbR4rVd\"\n",
+    "model_name = \"Maykeye/TinyLLama-v0\"\n",
+    "\n",
+    "model_config = petals.DistributedLlamaConfig.from_pretrained(model_name)\n",
+    "block_uids = [\n",
+    "    f\"{model_config.dht_prefix}{UID_DELIMITER}{i}\"\n",
+    "    for i in range(model_config.num_hidden_layers)\n",
+    "]\n",
+    "\n",
+    "blocks_in_use = block_uids[:2]\n",
+    "\n",
+    "try:\n",
+    "    dht = hivemind.DHT(start=True, client_mode=True, initial_peers=INITIAL_PEERS)\n",
+    "    sequence_manager = petals.RemoteSequenceManager(model_config, block_uids, dht=dht)\n",
+    "    p2p = await dht.replicate_p2p()\n",
+    "\n",
+    "    dummy_inputs = [torch.rand(1, 128, model_config.hidden_size, dtype=model_config.torch_dtype),\n",
+    "                    torch.empty(0, dtype=model_config.torch_dtype)]\n",
+    "    peer_id = hivemind.PeerID.from_base58(peer_id_string)\n",
+    "    for _ in trange(10):\n",
+    "        (outputs,) = await run_remote_forward_backward(sequence_manager, peer_id, blocks_in_use, *dummy_inputs)\n",
+    "        print('outputs:', repr(outputs)[:50], '...')\n",
+    "    print(\"It works!\")\n",
+    "\n",
+    "finally:\n",
+    "    print(\"shutting down\")\n",
+    "    await p2p.shutdown()\n",
+    "    dht.shutdown()  # it is okay to remove this line, but you will be summoning a horde of daemons as you debug"
+   ]
+  },
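+  {
+   "cell_type": "markdown",
+   "id": "c3a1e9b4",
+   "metadata": {},
+   "source": [
+    "`run_remote_forward_backward` flattens its `*args, **kwargs` with `pack_args_kwargs` and records the resulting structure in the request metadata; that is where the `((b'__T0', b'__T1'), {})` entry printed above comes from. A minimal sketch of the packing round trip, assuming `unpack_args_kwargs` from `petals.utils.packaging` is the server-side inverse:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d8f2a6c0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: how the demo's dummy inputs become a flat tensor list plus the\n",
+    "# \"args_structure\" entry seen in the printed metadata. unpack_args_kwargs is\n",
+    "# assumed to be the inverse that the server uses to restore the call structure.\n",
+    "from petals.utils.packaging import unpack_args_kwargs\n",
+    "\n",
+    "hidden = torch.rand(1, 128, 64)\n",
+    "prompts = torch.empty(0)\n",
+    "\n",
+    "flat_tensors, structure = pack_args_kwargs(hidden, prompts)\n",
+    "print(\"structure:\", structure)  # placeholders like b'__T0' stand in for the tensors\n",
+    "\n",
+    "args, kwargs = unpack_args_kwargs(flat_tensors, structure)\n",
+    "assert torch.equal(args[0], hidden) and torch.equal(args[1], prompts) and not kwargs\n"
+   ]
+  }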
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}