소스 검색

Add sst-2 ipynb example (#86)

- Add sst-2 example of a prompt-based training
- Have some enhancement in the persona-chat example
Artem Chumachenko 2 년 전
부모
커밋
0d9c7de0bd
3개의 변경된 파일340개의 추가작업 그리고 10개의 파일을 삭제
  1. 3 1
      README.md
  2. 11 9
      examples/prompt-tuning-personachat.ipynb
  3. 326 0
      examples/prompt-tuning-sst2.ipynb

+ 3 - 1
README.md

@@ -155,7 +155,9 @@ loss.backward()
 print("Gradients (norm):", model.transformer.word_embeddings.weight.grad.norm())
 ```
 
-Of course, this is a simplified code snippet. For actual training, see our example on "deep" prompt-tuning here: [examples/prompt-tuning-personachat.ipynb](./examples/prompt-tuning-personachat.ipynb).
+Of course, this is a simplified code snippet. For actual training, see the example notebooks with "deep" prompt-tuning:
+- Simple text semantic classification: [examples/prompt-tuning-sst2.ipynb](./examples/prompt-tuning-sst2.ipynb).
+- A personified chatbot: [examples/prompt-tuning-personachat.ipynb](./examples/prompt-tuning-personachat.ipynb).
 
 Here's a [more advanced tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) that covers 8-bit quantization and best practices for running Petals.
 

+ 11 - 9
examples/prompt-tuning-personachat.ipynb

@@ -33,7 +33,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# This block is only need for colab users. It will change nothing if you are running this notebook locally.\n",
     "import subprocess\n",
     "import sys\n",
     "\n",
@@ -41,14 +40,18 @@
     "IN_COLAB = 'google.colab' in sys.modules\n",
     "\n",
     "if IN_COLAB:\n",
-    "    subprocess.run(['git', 'clone', 'https://github.com/bigscience-workshop/petals'])\n",
-    "    subprocess.run(['pip', 'install', '-r', 'petals/requirements.txt'])\n",
-    "    subprocess.run(['pip', 'install', 'datasets', 'lib64'])\n",
+    "    subprocess.run(\"git clone https://github.com/bigscience-workshop/petals\", shell=True)\n",
+    "    subprocess.run(\"pip install -r petals/requirements.txt\", shell=True)\n",
+    "    subprocess.run(\"pip install datasets wandb\", shell=True)\n",
     "\n",
     "    try:\n",
     "        subprocess.check_output([\"nvidia-smi\", \"-L\"])\n",
     "    except subprocess.CalledProcessError as e:\n",
-    "        subprocess.run(['rm', '-r', '/usr/local/cuda/lib64'])"
+    "        subprocess.run(\"rm -r /usr/local/cuda/lib64\", shell=True)\n",
+    "\n",
+    "    sys.path.insert(0, './petals/')\n",
+    "else:\n",
+    "    sys.path.insert(0, \"..\")"
    ]
   },
   {
@@ -60,7 +63,6 @@
    "source": [
     "import os\n",
     "import sys\n",
-    "sys.path.insert(0, \"..\") # for colab change to sys.path.insert(0, './petals/')\n",
     " \n",
     "import torch\n",
     "import transformers\n",
@@ -312,7 +314,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.8.10 64-bit",
+   "display_name": "Python 3.8.0 ('petals')",
    "language": "python",
    "name": "python3"
   },
@@ -326,11 +328,11 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.9"
+   "version": "3.8.0"
   },
   "vscode": {
    "interpreter": {
-    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
+    "hash": "a303c9f329a09f921588ea6ef03898c90b4a8e255a47e0bd6e36f6331488f609"
    }
   }
  },

+ 326 - 0
examples/prompt-tuning-sst2.ipynb

@@ -0,0 +1,326 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "a07e0f5e",
+   "metadata": {},
+   "source": [
+    "<div>\n",
+    "<img src=\"https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67\" width=\"40%\">  \n",
+    "</div>\n",
+    "\n",
+    "# Distributed Bloom for Text Classification using Prompt Tuning\n",
+    "\n",
+    "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n",
+    "\n",
+    "We will adapt the BLOOM model for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n",
+    "\n",
+    "To open this notebook in colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a3f8526f",
+   "metadata": {},
+   "source": [
+    "First, we have to prepare all dependencies."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "73bbc648",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import subprocess\n",
+    "import sys\n",
+    "\n",
+    "\n",
+    "IN_COLAB = 'google.colab' in sys.modules\n",
+    "\n",
+    "if IN_COLAB:\n",
+    "    subprocess.run(\"git clone https://github.com/bigscience-workshop/petals\", shell=True)\n",
+    "    subprocess.run(\"pip install -r petals/requirements.txt\", shell=True)\n",
+    "    subprocess.run(\"pip install datasets wandb\", shell=True)\n",
+    "\n",
+    "    try:\n",
+    "        subprocess.check_output([\"nvidia-smi\", \"-L\"])\n",
+    "    except subprocess.CalledProcessError as e:\n",
+    "        subprocess.run(\"rm -r /usr/local/cuda/lib64\", shell=True)\n",
+    "\n",
+    "    sys.path.insert(0, './petals/')\n",
+    "else:\n",
+    "    sys.path.insert(0, \"..\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b4ab6ca7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import sys\n",
+    " \n",
+    "import torch\n",
+    "import transformers\n",
+    "import wandb\n",
+    "from datasets import load_dataset, load_metric\n",
+    "from tqdm import tqdm\n",
+    "from torch.optim import AdamW\n",
+    "from torch.utils.data import DataLoader\n",
+    "from transformers import get_scheduler\n",
+    "\n",
+    "# Import a Petals model\n",
+    "from src.client.remote_model import DistributedBloomForSequenceClassification"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1bf07b5d",
+   "metadata": {},
+   "source": [
+    "Let's set some hyperparameters for training:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f04ba4d2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "MODEL_NAME = ... # select model you like\n",
+    "INITIAL_PEERS = [...] # add your peers adresses here, like \"/ip4/192.168.1.2/tcp/31000/p2p/Qma....\"\n",
+    "NUM_PREFIX_TOKENS = 16\n",
+    "DEVICE = 'cpu'\n",
+    "BATCH_SIZE = 4\n",
+    "LR = 1e-2\n",
+    "WEIGHT_DECAY = 0.0\n",
+    "NUM_SAMPLES = 1000\n",
+    "NUM_EPOCHS = 3\n",
+    "SEED = 42\n",
+    "MODEL_MAX_LENGTH = 64\n",
+    "TUNING_MODE = 'ptune' # choose between ['ptune', 'deep_ptune'] "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d38316bd",
+   "metadata": {},
+   "source": [
+    "Prepare tokenizer and distributed model, connect it to servers."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "03c6e53e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
+    "tokenizer.padding_side = 'right'\n",
+    "tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
+    "model = DistributedBloomForSequenceClassification.from_pretrained(\n",
+    "    MODEL_NAME, \n",
+    "    initial_peers=INITIAL_PEERS, \n",
+    "    pre_seq_len=NUM_PREFIX_TOKENS, \n",
+    "    tuning_mode=TUNING_MODE\n",
+    ").to(DEVICE)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "042e3786",
+   "metadata": {},
+   "source": [
+    "Let's prepare the SST-2 dataset. We need just one preprocessing function to tokenize the dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9c44d516",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "task = 'sst2'\n",
+    "\n",
+    "dataset = load_dataset(\"glue\", task)\n",
+    "\n",
+    "def preprocess_function(examples):\n",
+    "    return tokenizer(examples[\"sentence\"], padding='max_length', truncation=True)\n",
+    "\n",
+    "tokenized_datasets = dataset.map(preprocess_function, batched=True)\n",
+    "tokenized_datasets = tokenized_datasets.remove_columns([\"sentence\", \"idx\", \"attention_mask\"])\n",
+    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
+    "tokenized_datasets.set_format(\"torch\")\n",
+    "\n",
+    "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n",
+    "valid_dataset = tokenized_datasets[\"validation\"].shuffle(seed=SEED)\n",
+    "\n",
+    "train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, drop_last=True)\n",
+    "valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2a3f3590",
+   "metadata": {},
+   "source": [
+    "To check training, we need a metric function. For SST-2 task is accuracy. We will load it from the datasets library."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1e1812be",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metric = load_metric('glue', task)\n",
+    "\n",
+    "def eval_metrics(model, dataloader, device='cpu'):\n",
+    "    model.eval()\n",
+    "    for batch in dataloader:\n",
+    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
+    "        \n",
+    "        with torch.no_grad():\n",
+    "            outputs = model(**batch)\n",
+    "\n",
+    "        logits = outputs.logits\n",
+    "        predictions = torch.argmax(logits, dim=-1)\n",
+    "        metric.add_batch(predictions=predictions, references=batch[\"labels\"])\n",
+    "    model.train()\n",
+    "    return metric.compute()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ef4323fd",
+   "metadata": {},
+   "source": [
+    "Before setting up optimizers, check the model parameters that will be trained."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9cc0ba34",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for n, p in model.named_parameters():\n",
+    "    if p.requires_grad:\n",
+    "        print(n, p.requires_grad, p.device)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "59cffce7",
+   "metadata": {},
+   "source": [
+    "The optimizer will only work on **prompts**, they are only trainable parameters. Let's initialize optimizer and learning rate scheduler."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ef9bf344",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
+    "\n",
+    "lr_scheduler = get_scheduler(\n",
+    "    name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "423c56d5",
+   "metadata": {},
+   "source": [
+    "Let's initialize wandb for logging and start the training loop!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d9e46807",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "wandb.init(\n",
+    "    project=\"bloom-sst-2\",\n",
+    "    config={\n",
+    "        \"num_epochs\": NUM_EPOCHS,\n",
+    "        \"num_samples\": NUM_SAMPLES,\n",
+    "        \"batch_size\": BATCH_SIZE,\n",
+    "        \"learning_rate\": LR,\n",
+    "        \"weight_decay\": WEIGHT_DECAY,\n",
+    "        \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n",
+    "        \"model_name\": MODEL_NAME,\n",
+    "        \"seed\": SEED,\n",
+    "    }\n",
+    ")\n",
+    "\n",
+    "for epoch in range(NUM_EPOCHS):\n",
+    "    for batch in tqdm(train_dataloader):\n",
+    "        batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
+    "\n",
+    "        model.train()\n",
+    "        outputs = model(**batch)\n",
+    "        loss = outputs.loss\n",
+    "        loss.backward()\n",
+    "\n",
+    "        optimizer.step()\n",
+    "        lr_scheduler.step()\n",
+    "        optimizer.zero_grad()\n",
+    "\n",
+    "        wandb.log({\"Train Loss\": loss})\n",
+    "\n",
+    "    accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
+    "    wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "51770911",
+   "metadata": {},
+   "source": [
+    "Our model have been trained!"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.8.10 64-bit",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.9"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}