소스 검색

Fix convergence issues and switch to LLaMA in the SST-2 example (#343)

* Fix convergence issues and switch to LLaMA in the SST-2 example
Max Ryabinin 2 년 전
부모
커밋
13f4e3a88a
2개의 변경된 파일105개의 추가작업 그리고 212개의 파일을 삭제
  1. 2 0
      .gitignore
  2. 103 212
      examples/prompt-tuning-sst2.ipynb

+ 2 - 0
.gitignore

@@ -126,3 +126,5 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+.idea/

+ 103 - 212
examples/prompt-tuning-sst2.ipynb

@@ -3,17 +3,19 @@
   {
    "cell_type": "markdown",
    "id": "a07e0f5e",
-   "metadata": {},
+   "metadata": {
+    "id": "a07e0f5e"
+   },
    "source": [
     "<div>\n",
     "<img src=\"https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67\" width=\"40%\">  \n",
     "</div>\n",
     "\n",
-    "# Distributed Bloom for Text Classification using Prompt Tuning\n",
+    "# Distributed LLaMA for Text Classification using Prompt Tuning\n",
     "\n",
-    "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n",
+    "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [LLaMA](https://github.com/facebookresearch/llama) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the LLaMA blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n",
     "\n",
-    "We will adapt BLOOM for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n",
+    "We will adapt LLaMA for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n",
     "\n",
     "To use this notebook in Colab:\n",
     "\n",
@@ -24,7 +26,9 @@
   {
    "cell_type": "markdown",
    "id": "a3f8526f",
-   "metadata": {},
+   "metadata": {
+    "id": "a3f8526f"
+   },
    "source": [
     "First, we have to prepare all dependencies."
    ]
@@ -33,17 +37,22 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "73bbc648",
-   "metadata": {},
+   "metadata": {
+    "id": "73bbc648"
+   },
    "outputs": [],
    "source": [
-    "%pip install -q petals datasets wandb scikit-learn"
+    "%pip install -q datasets wandb scikit-learn\n",
+    "%pip install -q git+https://github.com/bigscience-workshop/petals@main"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "b4ab6ca7",
-   "metadata": {},
+   "metadata": {
+    "id": "b4ab6ca7"
+   },
    "outputs": [],
    "source": [
     "import os\n",
@@ -57,15 +66,19 @@
     "from tqdm import tqdm\n",
     "from torch.optim import AdamW\n",
     "from torch.utils.data import DataLoader\n",
-    "from transformers import BloomTokenizerFast, get_scheduler\n",
+    "from transformers import LlamaTokenizer, get_scheduler, set_seed\n",
     "\n",
-    "from petals import DistributedBloomForSequenceClassification"
+    "from petals import DistributedLlamaForSequenceClassification\n",
+    "\n",
+    "set_seed(0)"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "1bf07b5d",
-   "metadata": {},
+   "metadata": {
+    "id": "1bf07b5d"
+   },
    "source": [
     "Let's set some hyperparameters for training:"
    ]
@@ -74,14 +87,15 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "f04ba4d2",
-   "metadata": {},
+   "metadata": {
+    "id": "f04ba4d2"
+   },
    "outputs": [],
    "source": [
     "# Choose a model you'd like to prompt-tune. We recommend starting with\n",
-    "# the smaller 7.1B version of BLOOM (bigscience/bloom-7b1-petals) for faster prototyping.\n",
-    "# Once your code is ready, you can switch to full-scale\n",
-    "# 176B-parameter BLOOM (bigscience/bloom-petals) or BLOOMZ (bigscience/bloomz-petals).\n",
-    "MODEL_NAME = \"bigscience/bloom-7b1-petals\"\n",
+    "# a smaller model (bigscience/bloom-7b1-petals) for faster prototyping.\n",
+    "# The code below uses LLaMA-65B.\n",
+    "MODEL_NAME = \"enoch/llama-65b-hf\"\n",
     "\n",
     "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n",
     "# The latter fine-tunes separate prefixes for each transformer block,\n",
@@ -89,9 +103,9 @@
     "# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\n",
     "TUNING_MODE = 'ptune'\n",
     "\n",
-    "NUM_PREFIX_TOKENS = 16\n",
+    "NUM_PREFIX_TOKENS = 8\n",
     "DEVICE = 'cuda'\n",
-    "BATCH_SIZE = 16\n",
+    "BATCH_SIZE = 32\n",
     "LR = 1e-2\n",
     "WEIGHT_DECAY = 0.0\n",
     "NUM_EPOCHS = 3\n",
@@ -102,32 +116,40 @@
   {
    "cell_type": "markdown",
    "id": "d38316bd",
-   "metadata": {},
+   "metadata": {
+    "id": "d38316bd"
+   },
    "source": [
-    "Prepare tokenizer and distributed model, connect it to servers."
+    "Here, we prepare tokenizer and distributed model and connect it to the public swarm."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "03c6e53e",
-   "metadata": {},
+   "metadata": {
+    "id": "03c6e53e"
+   },
    "outputs": [],
    "source": [
-    "tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
+    "tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)\n",
     "tokenizer.padding_side = 'right'\n",
     "tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
-    "model = DistributedBloomForSequenceClassification.from_pretrained(\n",
+    "tokenizer.pad_token = tokenizer.unk_token\n",
+    "model = DistributedLlamaForSequenceClassification.from_pretrained(\n",
     "    MODEL_NAME,\n",
     "    pre_seq_len=NUM_PREFIX_TOKENS,\n",
     "    tuning_mode=TUNING_MODE\n",
-    ").to(DEVICE)"
+    ").float().to(DEVICE)\n",
+    "model.config.pad_token_id = tokenizer.pad_token_id"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "042e3786",
-   "metadata": {},
+   "metadata": {
+    "id": "042e3786"
+   },
    "source": [
     "Let's prepare the SST-2 dataset. We need just one preprocessing function to tokenize the dataset."
    ]
@@ -136,7 +158,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "9c44d516",
-   "metadata": {},
+   "metadata": {
+    "id": "9c44d516"
+   },
    "outputs": [],
    "source": [
     "task = 'sst2'\n",
@@ -144,7 +168,7 @@
     "dataset = load_dataset(\"glue\", task)\n",
     "\n",
     "def preprocess_function(examples):\n",
-    "    return tokenizer(examples[\"sentence\"], padding='max_length', truncation=True)\n",
+    "    return tokenizer(examples[\"sentence\"], padding='max_length', truncation=True, return_token_type_ids=False)\n",
     "\n",
     "tokenized_datasets = dataset.map(preprocess_function, batched=True)\n",
     "tokenized_datasets = tokenized_datasets.remove_columns([\"sentence\", \"idx\", \"attention_mask\"])\n",
@@ -161,16 +185,20 @@
   {
    "cell_type": "markdown",
    "id": "2a3f3590",
-   "metadata": {},
+   "metadata": {
+    "id": "2a3f3590"
+   },
    "source": [
-    "To check training, we need a metric function. For SST-2 task is accuracy. We will load it from the datasets library."
+    "To monitor training, we need the metric function. For SST-2, the target metric is accuracy. We will load it from the datasets library."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "1e1812be",
-   "metadata": {},
+   "metadata": {
+    "id": "1e1812be"
+   },
    "outputs": [],
    "source": [
     "metric = load_metric('glue', task)\n",
@@ -179,7 +207,7 @@
     "    model.eval()\n",
     "    for batch in dataloader:\n",
     "        batch = {k: v.to(device) for k, v in batch.items()}\n",
-    "        \n",
+    "\n",
     "        with torch.no_grad():\n",
     "            outputs = model(**batch)\n",
     "\n",
@@ -193,16 +221,20 @@
   {
    "cell_type": "markdown",
    "id": "ef4323fd",
-   "metadata": {},
+   "metadata": {
+    "id": "ef4323fd"
+   },
    "source": [
-    "Before setting up optimizers, check the model parameters that will be trained."
+    "Before setting up optimizers, let's check the model parameters that will be trained."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "9cc0ba34",
-   "metadata": {},
+   "metadata": {
+    "id": "9cc0ba34"
+   },
    "outputs": [],
    "source": [
     "for n, p in model.named_parameters():\n",
@@ -213,29 +245,35 @@
   {
    "cell_type": "markdown",
    "id": "59cffce7",
-   "metadata": {},
+   "metadata": {
+    "id": "59cffce7"
+   },
    "source": [
-    "The optimizer will only work on **prompts**, they are only trainable parameters. Let's initialize optimizer and learning rate scheduler."
+    "The optimizer will only work on **prompts and classifier head**: they are only trainable parameters. Let's initialize the optimizer and the learning rate scheduler."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "ef9bf344",
-   "metadata": {},
+   "metadata": {
+    "id": "ef9bf344"
+   },
    "outputs": [],
    "source": [
     "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
     "\n",
     "lr_scheduler = get_scheduler(\n",
-    "    name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
+    "    name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS\n",
     ")"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "423c56d5",
-   "metadata": {},
+   "metadata": {
+    "id": "423c56d5"
+   },
    "source": [
     "Let's initialize wandb for logging and start the training loop!"
    ]
@@ -244,7 +282,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "d9e46807",
-   "metadata": {},
+   "metadata": {
+    "id": "d9e46807"
+   },
    "outputs": [],
    "source": [
     "wandb.init(\n",
@@ -260,20 +300,24 @@
     "    }\n",
     ")\n",
     "\n",
+    "scaler = torch.cuda.amp.GradScaler()\n",
+    "\n",
     "for epoch in range(NUM_EPOCHS):\n",
+    "    model.train()\n",
     "    for batch in tqdm(train_dataloader):\n",
     "        batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
     "\n",
-    "        model.train()\n",
-    "        outputs = model(**batch)\n",
+    "        with torch.autocast(device_type=DEVICE, dtype=torch.float16):\n",
+    "          outputs = model(**batch)\n",
     "        loss = outputs.loss\n",
-    "        loss.backward()\n",
+    "        scaler.scale(loss).backward()\n",
     "\n",
-    "        optimizer.step()\n",
+    "        scaler.step(optimizer)\n",
+    "        scaler.update()\n",
     "        lr_scheduler.step()\n",
     "        optimizer.zero_grad()\n",
     "\n",
-    "        wandb.log({\"Train Loss\": loss})\n",
+    "        wandb.log({\"Train Loss\": loss.detach()})\n",
     "\n",
     "    accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
     "    wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
@@ -282,184 +326,26 @@
   {
    "cell_type": "markdown",
    "id": "51770911",
-   "metadata": {},
-   "source": [
-    "Our model have been trained!"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "1bbf014f",
-   "metadata": {},
-   "source": [
-    "## Beyond soft-prompt tuning\n",
-    "\n",
-    "Let's try to tune model using adapters in the middle of the model."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "3bea4391",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "class BloomBasedClassifier(nn.Module):\n",
-    "  def __init__(\n",
-    "      self,\n",
-    "      model,\n",
-    "      intermediate_size: int = 32,\n",
-    "      num_classes: int = 2,\n",
-    "      adapter_layer_position: int = 6,\n",
-    "      head_layer_position: int = 10\n",
-    "    ):\n",
-    "    super().__init__()\n",
-    "    self.distributed_layers = model.transformer.h\n",
-    "\n",
-    "    self.hidden_size = model.config.hidden_size\n",
-    "    self.dtype = model.config.torch_dtype\n",
-    "    self.intermediate_size = intermediate_size\n",
-    "    self.num_classes = num_classes\n",
-    "    self.adapter_layer_position = adapter_layer_position\n",
-    "    self.head_layer_position = head_layer_position\n",
-    "    \n",
-    "    self.word_embeddings = model.transformer.word_embeddings\n",
-    "    self.adapter = nn.Sequential(\n",
-    "        nn.Linear(self.hidden_size, self.intermediate_size),\n",
-    "        nn.Linear(self.intermediate_size, self.hidden_size),\n",
-    "    ).to(self.dtype)\n",
-    "    self.head = nn.Sequential(\n",
-    "        nn.LayerNorm(self.hidden_size),\n",
-    "        nn.Linear(self.hidden_size, self.num_classes),\n",
-    "    ).to(self.dtype)\n",
-    "  \n",
-    "  def forward(self, embeddings):\n",
-    "    before_layers = self.distributed_layers[0:self.adapter_layer_position]\n",
-    "    after_layers = self.distributed_layers[self.adapter_layer_position:self.head_layer_position]\n",
-    "    \n",
-    "    hidden_states = before_layers(embeddings)\n",
-    "    hidden_states = self.adapter(hidden_states)\n",
-    "    hidden_states = after_layers(hidden_states)\n",
-    "    pooled_states = torch.mean(hidden_states, dim=1)\n",
-    "    return self.head(pooled_states)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "15299620",
-   "metadata": {},
-   "source": [
-    "Clear model and device memory."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "aa27b168",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "del model, optimizer, lr_scheduler\n",
-    "torch.cuda.empty_cache()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "5406390f",
-   "metadata": {},
-   "source": [
-    "Create new model with adapters."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "a251db80",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "INTERMEDIATE_SIZE = 32\n",
-    "ADAPTER_LAYER_POSITION = 6\n",
-    "HEAD_LAYER_POSITION = 10"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "3578df3a",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "cls_model = BloomBasedClassifier(\n",
-    "    DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME),\n",
-    "    intermediate_size=INTERMEDIATE_SIZE,\n",
-    "    adapter_layer_position=ADAPTER_LAYER_POSITION,\n",
-    "    head_layer_position=HEAD_LAYER_POSITION,\n",
-    ").to(DEVICE)\n",
-    "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
-    "cls_criterion = nn.CrossEntropyLoss()\n",
-    "\n",
-    "lr_scheduler = get_scheduler(\n",
-    "    name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "a40468b9",
-   "metadata": {},
+   "metadata": {
+    "id": "51770911"
+   },
    "source": [
-    "And start training our new adapted model."
+    "Our model has been trained! You can now upload it to the Hub for later use, try out different models [served in the public swarm](http://health.petals.ml/), or [join Petals with your own GPU](https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity)!"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ed051a5d",
-   "metadata": {},
    "outputs": [],
-   "source": [
-    "wandb.init(\n",
-    "    project=\"bloom_based_cls-sst-2\",\n",
-    "    config={\n",
-    "        \"num_epochs\": NUM_EPOCHS,\n",
-    "        \"batch_size\": BATCH_SIZE,\n",
-    "        \"learning_rate\": LR,\n",
-    "        \"weight_decay\": WEIGHT_DECAY,\n",
-    "        \"model_name\": MODEL_NAME,\n",
-    "        \"seed\": SEED,\n",
-    "        \"intermediate_size\": INTERMEDIATE_SIZE,\n",
-    "        \"adapter_layer_position\": ADAPTER_LAYER_POSITION,\n",
-    "        \"head_layer_position\": HEAD_LAYER_POSITION,\n",
-    "    }\n",
-    ")\n",
-    "\n",
-    "for epoch in range(NUM_EPOCHS):\n",
-    "    for batch in tqdm(train_dataloader):\n",
-    "        batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
-    "\n",
-    "        cls_model.train()\n",
-    "        with torch.no_grad():\n",
-    "            embeddings_output = cls_model.word_embeddings(batch[\"input_ids\"])\n",
-    "        outputs = cls_model(embeddings_output)\n",
-    "        loss = cls_criterion(outputs, batch[\"labels\"])\n",
-    "        loss.backward()\n",
-    "\n",
-    "        cls_optimizer.step()\n",
-    "        lr_scheduler.step()\n",
-    "        cls_optimizer.zero_grad()\n",
-    "\n",
-    "        wandb.log({\"Train Loss\": loss})\n",
-    "\n",
-    "    accuracy = eval_metrics(cls_model, valid_dataloader, device=DEVICE)\n",
-    "    wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
-   ]
+   "source": [],
+   "metadata": {
+    "collapsed": false
+   }
   }
  ],
  "metadata": {
   "kernelspec": {
    "display_name": "Python 3",
-   "language": "python",
    "name": "python3"
   },
   "language_info": {
@@ -478,7 +364,12 @@
    "interpreter": {
     "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
    }
-  }
+  },
+  "colab": {
+   "provenance": [],
+   "gpuType": "T4"
+  },
+  "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 5