{ "cells": [ { "cell_type": "markdown", "id": "a07e0f5e", "metadata": {}, "source": [ "
\n", " \n", "
\n", "\n", "# Distributed Bloom for Text Generation using Prompt Tuning\n", "\n", "In this example, we showcase how the a test 6B version of BLOOM model can be efficiently adapted in a decentralized fashion using Petals. In particular, servers maintain the Bloom transformer, which is kept unchanged during adaptation, and learn only a few prefix tokens.\n", "\n", "This example will train the BLOOM model for the chatbot task. In a given dialogue context the model has to provide a relevant answer. [Link to dataset](https://huggingface.co/datasets/bavard/personachat_truecased)." ] }, { "cell_type": "markdown", "id": "a3f8526f", "metadata": {}, "source": [ "First, we have to prepare all dependencies." ] }, { "cell_type": "code", "execution_count": null, "id": "b4ab6ca7", "metadata": {}, "outputs": [], "source": [ "%env OMP_NUM_THREADS=24\n", "\n", "import os\n", "import sys\n", "import inspect\n", "sys.path.insert(0, \"..\")\n", "\n", "# General \n", "import torch\n", "import pandas as pd\n", "from tqdm import tqdm\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader\n", "\n", "# Distributed\n", "import hivemind\n", "from src.client.remote_model import DistributedBloomForCausalLM\n", "\n", "# HF imports\n", "import transformers\n", "import wandb\n", "from datasets import load_dataset\n", "from transformers import get_scheduler\n", "\n", "# Visualization dependencies\n", "from IPython.display import clear_output" ] }, { "cell_type": "markdown", "id": "1bf07b5d", "metadata": {}, "source": [ "Set some hyperparameters for training." ] }, { "cell_type": "code", "execution_count": null, "id": "f04ba4d2", "metadata": {}, "outputs": [], "source": [ "MODEL_NAME = ... # select model you like\n", "INITIAL_PEERS = [...] # add your peers adresses here, like \"/ip4/192.168.1.2/tcp/31000/p2p/Qma....\"\n", "NUM_PREFIX_TOKENS = 16\n", "DEVICE ='cpu'\n", "BATCH_SIZE = 4\n", "LR = 1e-2\n", "WEIGHT_DECAY = 0.0\n", "NUM_SAMPLES = 1000\n", "SEED = 42\n", "MODEL_MAX_LENGTH = 256\n", "TUNING_MODE = 'ptune' # choose between ['ptune', 'deep_ptune'] " ] }, { "cell_type": "markdown", "id": "d38316bd", "metadata": {}, "source": [ "Prepare tokenizer and distributed model, connect it to servers." ] }, { "cell_type": "code", "execution_count": null, "id": "03c6e53e", "metadata": {}, "outputs": [], "source": [ "tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n", "tokenizer.padding_side = 'right'\n", "tokenizer.model_max_length = MODEL_MAX_LENGTH\n", "model = DistributedBloomForCausalLM.from_pretrained(\n", " MODEL_NAME, \n", " initial_peers=INITIAL_PEERS, \n", " pre_seq_len=NUM_PREFIX_TOKENS, \n", " tuning_mode=TUNING_MODE\n", ").to(DEVICE)" ] }, { "cell_type": "markdown", "id": "042e3786", "metadata": {}, "source": [ "Prepare personachat dataset. We need two mapping functions, one to concatenate history and candidate answers, and another for tokenization." 
] }, { "cell_type": "code", "execution_count": null, "id": "9c44d516", "metadata": {}, "outputs": [], "source": [ "dataset = load_dataset(\"bavard/personachat_truecased\")\n", "\n", "\n", "def chunking(examples):\n", " inputs = [\n", " \"\\n-----\\n\".join(history) + \"\\n-----\\n\" + candidate\n", " for history, candidates in zip(examples[\"history\"], examples[\"candidates\"])\n", " for candidate in candidates\n", " ]\n", " return {\"chunks\": inputs}\n", "\n", "\n", "def tokenize(examples):\n", " outputs = {\n", " \"input_ids\": tokenizer(examples[\"chunks\"], padding='max_length', truncation=True)[\"input_ids\"]\n", " }\n", " outputs[\"labels\"] = outputs[\"input_ids\"]\n", " return outputs\n", "\n", "\n", "tokenized_datasets = (\n", " dataset\n", " .map(chunking, batched=True, remove_columns=dataset[\"train\"].column_names)\n", " .map(tokenize, batched=True, remove_columns=[\"chunks\"])\n", ")\n", "\n", "\n", "tokenized_datasets.set_format(\"torch\")\n", "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n", "train_dataloader = DataLoader(\n", " train_dataset.select(list(range(NUM_SAMPLES))),\n", " shuffle=True,\n", " batch_size=BATCH_SIZE,\n", " drop_last=True,\n", ")" ] }, { "cell_type": "markdown", "id": "ef4323fd", "metadata": {}, "source": [ "Before setting up optimizers, check the model parameters that will be trained." ] }, { "cell_type": "code", "execution_count": null, "id": "9cc0ba34", "metadata": {}, "outputs": [], "source": [ "for n, p in model.named_parameters():\n", " if p.requires_grad:\n", " print(n, p.requires_grad, p.device)" ] }, { "cell_type": "markdown", "id": "59cffce7", "metadata": {}, "source": [ "The optimizer will only work on **prompts**, they are only trainable parameters. So initialize optimizer and learning rate scheduler." ] }, { "cell_type": "code", "execution_count": null, "id": "ef9bf344", "metadata": {}, "outputs": [], "source": [ "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", "\n", "lr_scheduler = get_scheduler(\n", " name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n", ")" ] }, { "cell_type": "markdown", "id": "423c56d5", "metadata": {}, "source": [ "Let's initialize wandb for logging and start the training loop!" ] }, { "cell_type": "code", "execution_count": null, "id": "d9e46807", "metadata": {}, "outputs": [], "source": [ "wandb.init(\n", " project=\"bloom-personachat\",\n", " config={\n", " \"num_samples\": NUM_SAMPLES,\n", " \"batch_size\": BATCH_SIZE,\n", " \"learning_rate\": LR,\n", " \"weight_decay\": WEIGHT_DECAY,\n", " \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n", " \"model_name\": MODEL_NAME,\n", " \"seed\": SEED,\n", " }\n", ")\n", "\n", "for batch in tqdm(train_dataloader):\n", " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", "\n", " model.train()\n", " outputs = model(**batch)\n", " loss = outputs.loss\n", " loss.backward()\n", "\n", " optimizer.step()\n", " lr_scheduler.step()\n", " optimizer.zero_grad()\n", "\n", " wandb.log({\"Train Loss\": loss})" ] }, { "cell_type": "markdown", "id": "0f36cb80", "metadata": {}, "source": [ "Try to talk with the trained model! Submit an empty input to stop the execution.\n", "\n", "\n", "In this example we have to pass the whole dialogue. In the future, we will support a much faster interactive dialogue mode." 
] }, { "cell_type": "code", "execution_count": null, "id": "720181b7", "metadata": {}, "outputs": [], "source": [ "MAX_TOKENS = 16\n", "TOP_K = 100\n", "TEMPERATURE = 0.6\n", "dialog = \"\"\n", "\n", "while True:\n", " user_phrase = input()\n", " if len(user_phrase) == 0:\n", " break\n", " dialog += f\"{user_phrase}\\n-----\\n\"\n", " inputs = tokenizer([dialog], return_tensors='pt')['input_ids']\n", " outputs = model.generate(\n", " inputs,\n", " temperature=TEMPERATURE,\n", " do_sample=True,\n", " top_k=TOP_K,\n", " eos_token_id=tokenizer.eos_token_id,\n", " max_new_tokens=MAX_TOKENS,\n", " )\n", " bloom_answer = tokenizer.batch_decode(outputs)[0]\n", " bloom_answer = bloom_answer[len(dialog):].split(\"\\n\")[0]\n", " print(bloom_answer)\n", " dialog += f\"{bloom_answer}\\n-----\\n\"" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }