Browse Source

Add prompt tuning example on Personachat dataset (#69)

Artem Chumachenko 2 năm trước cách đây
mục cha
commit
1046911dea
1 tập tin đã thay đổi với 307 bổ sung0 xóa
  1. 307 0
      examples/prompt-tuning-personachat.ipynb

+ 307 - 0
examples/prompt-tuning-personachat.ipynb

@@ -0,0 +1,307 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "a07e0f5e",
+   "metadata": {},
+   "source": [
+    "<div>\n",
+    "<img src=\"https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67\" width=\"40%\">  \n",
+    "</div>\n",
+    "\n",
+    "# Distributed Bloom for Text Generation using Prompt Tuning\n",
+    "\n",
+    "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n",
+    "\n",
+    "We will adapt the BLOOM model for the chatbot task using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a3f8526f",
+   "metadata": {},
+   "source": [
+    "First, we have to prepare all dependencies."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b4ab6ca7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "sys.path.insert(0, \"..\")\n",
+    " \n",
+    "import torch\n",
+    "import transformers\n",
+    "import wandb\n",
+    "from datasets import load_dataset\n",
+    "from tqdm import tqdm\n",
+    "from torch.optim import AdamW\n",
+    "from torch.utils.data import DataLoader\n",
+    "from transformers import get_scheduler\n",
+    "\n",
+    "# Import a Petals model\n",
+    "from src.client.remote_model import DistributedBloomForCausalLM"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1bf07b5d",
+   "metadata": {},
+   "source": [
+    "Let's set some hyperparameters for training:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f04ba4d2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "MODEL_NAME = ... # select model you like\n",
+    "INITIAL_PEERS = [...] # add your peers adresses here, like \"/ip4/192.168.1.2/tcp/31000/p2p/Qma....\"\n",
+    "NUM_PREFIX_TOKENS = 16\n",
+    "DEVICE = 'cpu'\n",
+    "BATCH_SIZE = 4\n",
+    "LR = 1e-2\n",
+    "WEIGHT_DECAY = 0.0\n",
+    "NUM_SAMPLES = 1000\n",
+    "SEED = 42\n",
+    "MODEL_MAX_LENGTH = 256\n",
+    "TUNING_MODE = 'ptune' # choose between ['ptune', 'deep_ptune'] "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d38316bd",
+   "metadata": {},
+   "source": [
+    "Prepare tokenizer and distributed model, connect it to servers."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "03c6e53e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
+    "tokenizer.padding_side = 'right'\n",
+    "tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
+    "model = DistributedBloomForCausalLM.from_pretrained(\n",
+    "    MODEL_NAME, \n",
+    "    initial_peers=INITIAL_PEERS, \n",
+    "    pre_seq_len=NUM_PREFIX_TOKENS, \n",
+    "    tuning_mode=TUNING_MODE\n",
+    ").to(DEVICE)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "042e3786",
+   "metadata": {},
+   "source": [
+    "Let's prepare the Personachat dataset. We need two mapping functions, one to concatenate history and candidate answers, and another for tokenization."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9c44d516",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = load_dataset(\"bavard/personachat_truecased\")\n",
+    "\n",
+    "\n",
+    "def chunking(examples):\n",
+    "    inputs = [\n",
+    "        \"\\n-----\\n\".join(history) + \"\\n-----\\n\" + candidate\n",
+    "        for history, candidates in zip(examples[\"history\"], examples[\"candidates\"])\n",
+    "        for candidate in candidates\n",
+    "    ]\n",
+    "    return {\"chunks\": inputs}\n",
+    "\n",
+    "\n",
+    "def tokenize(examples):\n",
+    "    outputs = {\n",
+    "        \"input_ids\": tokenizer(examples[\"chunks\"], padding='max_length', truncation=True)[\"input_ids\"]\n",
+    "    }\n",
+    "    outputs[\"labels\"] = outputs[\"input_ids\"]\n",
+    "    return outputs\n",
+    "\n",
+    "\n",
+    "tokenized_datasets = (\n",
+    "    dataset\n",
+    "        .map(chunking, batched=True, remove_columns=dataset[\"train\"].column_names)\n",
+    "        .map(tokenize, batched=True, remove_columns=[\"chunks\"])\n",
+    ")\n",
+    "\n",
+    "\n",
+    "tokenized_datasets.set_format(\"torch\")\n",
+    "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n",
+    "train_dataloader = DataLoader(\n",
+    "    train_dataset.select(list(range(NUM_SAMPLES))),\n",
+    "    shuffle=True,\n",
+    "    batch_size=BATCH_SIZE,\n",
+    "    drop_last=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ef4323fd",
+   "metadata": {},
+   "source": [
+    "Before setting up optimizers, check the model parameters that will be trained."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9cc0ba34",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for n, p in model.named_parameters():\n",
+    "    if p.requires_grad:\n",
+    "        print(n, p.requires_grad, p.device)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "59cffce7",
+   "metadata": {},
+   "source": [
+    "The optimizer will only work on **prompts**, they are only trainable parameters. Let's initialize optimizer and learning rate scheduler."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ef9bf344",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
+    "\n",
+    "lr_scheduler = get_scheduler(\n",
+    "    name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "423c56d5",
+   "metadata": {},
+   "source": [
+    "Let's initialize wandb for logging and start the training loop!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d9e46807",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "wandb.init(\n",
+    "    project=\"bloom-personachat\",\n",
+    "    config={\n",
+    "        \"num_samples\": NUM_SAMPLES,\n",
+    "        \"batch_size\": BATCH_SIZE,\n",
+    "        \"learning_rate\": LR,\n",
+    "        \"weight_decay\": WEIGHT_DECAY,\n",
+    "        \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n",
+    "        \"model_name\": MODEL_NAME,\n",
+    "        \"seed\": SEED,\n",
+    "    }\n",
+    ")\n",
+    "\n",
+    "for batch in tqdm(train_dataloader):\n",
+    "    batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
+    "\n",
+    "    model.train()\n",
+    "    outputs = model(**batch)\n",
+    "    loss = outputs.loss\n",
+    "    loss.backward()\n",
+    "\n",
+    "    optimizer.step()\n",
+    "    lr_scheduler.step()\n",
+    "    optimizer.zero_grad()\n",
+    "\n",
+    "    wandb.log({\"Train Loss\": loss})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0f36cb80",
+   "metadata": {},
+   "source": [
+    "Try to talk with the trained model! Submit an empty input to stop the execution.\n",
+    "\n",
+    "\n",
+    "__Note__: In this example, we the whole dialogue as a prefix when generating each new replica. In the future, we will support a faster \"interactive\" dialogue mode, so generating a new replica will be able to reuse inference caches from the previous replica."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "720181b7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "MAX_TOKENS = 16\n",
+    "TOP_K = 100\n",
+    "TEMPERATURE = 0.6\n",
+    "dialog = \"\"\n",
+    "\n",
+    "while True:\n",
+    "    user_phrase = input()\n",
+    "    if len(user_phrase) == 0:\n",
+    "        break\n",
+    "    dialog += f\"{user_phrase}\\n-----\\n\"\n",
+    "    inputs = tokenizer([dialog], return_tensors='pt')['input_ids']\n",
+    "    outputs = model.generate(\n",
+    "        inputs,\n",
+    "        temperature=TEMPERATURE,\n",
+    "        do_sample=True,\n",
+    "        top_k=TOP_K,\n",
+    "        eos_token_id=tokenizer.eos_token_id,\n",
+    "        max_new_tokens=MAX_TOKENS,\n",
+    "    )\n",
+    "    bloom_answer = tokenizer.batch_decode(outputs)[0]\n",
+    "    bloom_answer = bloom_answer[len(dialog):].split(\"\\n\")[0]\n",
+    "    print(bloom_answer)\n",
+    "    dialog += f\"{bloom_answer}\\n-----\\n\""
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}