瀏覽代碼

Update advanced notebooks (#148)

Update examples

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Artem Chumachenko 2 年之前
父節點
當前提交
7911c2641d
共有 2 個文件被更改,包括 200 次插入30 次删除
  1. 25 25
      examples/prompt-tuning-personachat.ipynb
  2. 175 5
      examples/prompt-tuning-sst2.ipynb

+ 25 - 25
examples/prompt-tuning-personachat.ipynb

@@ -36,8 +36,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install git+https://github.com/bigscience-workshop/petals\n",
-    "!pip install datasets wandb"
+    "!pip install -q git+https://github.com/bigscience-workshop/petals\n",
+    "!pip install -q datasets wandb"
    ]
   },
   {
@@ -269,35 +269,35 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "MAX_TOKENS = 16\n",
     "TOP_K = 100\n",
     "TEMPERATURE = 0.6\n",
-    "dialog = \"\"\n",
     "\n",
-    "while True:\n",
-    "    user_phrase = input()\n",
-    "    if len(user_phrase) == 0:\n",
-    "        break\n",
-    "    dialog += f\"{user_phrase}\\n-----\\n\"\n",
-    "    inputs = tokenizer([dialog], return_tensors='pt')['input_ids']\n",
-    "    outputs = model.generate(\n",
-    "        inputs,\n",
-    "        temperature=TEMPERATURE,\n",
-    "        do_sample=True,\n",
-    "        top_k=TOP_K,\n",
-    "        eos_token_id=tokenizer.eos_token_id,\n",
-    "        max_new_tokens=MAX_TOKENS,\n",
-    "    )\n",
-    "    bloom_answer = tokenizer.batch_decode(outputs)[0]\n",
-    "    bloom_answer = bloom_answer[len(dialog):].split(\"\\n\")[0]\n",
-    "    print(bloom_answer)\n",
-    "    dialog += f\"{bloom_answer}\\n-----\\n\""
+    "with model.inference_session(max_length=512) as sess:\n",
+    "    while True:\n",
+    "        user_phrase = input()\n",
+    "        if len(user_phrase) == 0:\n",
+    "            break\n",
+    "        inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids']\n",
+    "        while True:\n",
+    "            outputs = model.generate(\n",
+    "                inputs,\n",
+    "                temperature=TEMPERATURE,\n",
+    "                do_sample=True,\n",
+    "                top_k=TOP_K,\n",
+    "                max_new_tokens=1,\n",
+    "                session=sess,\n",
+    "            )\n",
+    "            bloom_answer_token = tokenizer.decode(outputs[0, -1:])\n",
+    "            print(bloom_answer_token, end=\"\", flush=True)\n",
+    "            if bloom_answer_token == \"\\n\":\n",
+    "                break\n",
+    "            inputs = None"
    ]
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.8.12 ('bloom-demo')",
+   "display_name": "Python 3.8.9 64-bit",
    "language": "python",
    "name": "python3"
   },
@@ -311,11 +311,11 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.12"
+   "version": "3.8.9"
   },
   "vscode": {
    "interpreter": {
-    "hash": "175c31e15dd38a7dfc9eb4117a9e428ffb6063af97d545b6bfba4d874ecc4bb8"
+    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
    }
   }
  },

+ 175 - 5
examples/prompt-tuning-sst2.ipynb

@@ -36,8 +36,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install git+https://github.com/bigscience-workshop/petals\n",
-    "!pip install datasets wandb"
+    "!pip install -q git+https://github.com/bigscience-workshop/petals\n",
+    "!pip install -q datasets wandb"
    ]
   },
   {
@@ -52,6 +52,10 @@
     "import torch\n",
     "import transformers\n",
     "import wandb\n",
+    "\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "\n",
     "from datasets import load_dataset, load_metric\n",
     "from tqdm import tqdm\n",
     "from torch.optim import AdamW\n",
@@ -276,11 +280,177 @@
    "source": [
     "Our model have been trained!"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1bbf014f",
+   "metadata": {},
+   "source": [
+    "## Beyond soft-propmt tuning\n",
+    "\n",
+    "Let's try to tune model using adapters in the middle of the model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3bea4391",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class BloomBasedClassifier(nn.Module):\n",
+    "  def __init__(\n",
+    "      self,\n",
+    "      model,\n",
+    "      intermediate_size: int = 32,\n",
+    "      num_classes: int = 2,\n",
+    "      adapter_layer_position: int = 6,\n",
+    "      head_layer_position: int = 10\n",
+    "    ):\n",
+    "    super().__init__()\n",
+    "    self.distributed_layers = model.transformer.h\n",
+    "\n",
+    "    self.hidden_size = model.config.hidden_size\n",
+    "    self.intermediate_size = intermediate_size\n",
+    "    self.num_classes = num_classes\n",
+    "    self.adapter_layer_position = adapter_layer_position\n",
+    "    self.head_layer_position = head_layer_position\n",
+    "    \n",
+    "    self.adapter = nn.Sequential(\n",
+    "        nn.Linear(self.hidden_size, self.intermediate_size),\n",
+    "        nn.Linear(self.intermediate_size, self.hidden_size),\n",
+    "    )\n",
+    "    self.head = nn.Sequential(\n",
+    "        nn.LayerNorm(self.hidden_size),\n",
+    "        nn.Linear(self.hidden_size, self.num_classes),\n",
+    "    )\n",
+    "  \n",
+    "  def forward(self, embeddings):\n",
+    "    before_layers = self.distributed_layers[0:self.adapter_layer_position]\n",
+    "    after_layers = self.distributed_layers[self.adapter_layer_position:self.head_layer_position]\n",
+    "    \n",
+    "    hidden_states = before_layers(embeddings)\n",
+    "    hidden_states = self.adapter(hidden_states)\n",
+    "    hidden_states = after_layers(hidden_states)\n",
+    "    pooled_states = torch.mean(hidden_states, dim=1)\n",
+    "    return self.head(pooled_states)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "15299620",
+   "metadata": {},
+   "source": [
+    "Clear model and device memory."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "aa27b168",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "del model, optimizer, lr_scheduler\n",
+    "torch.cuda.empty_cache()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5406390f",
+   "metadata": {},
+   "source": [
+    "Create new model with adapters."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a251db80",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "INTERMEDIATE_SIZE = 32\n",
+    "ADAPTER_LAYER_POSITION = 6\n",
+    "HEAD_LAYER_POSITION = 10"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3578df3a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)\n",
+    "\n",
+    "cls_model = BloomBasedClassifier(\n",
+    "    model,\n",
+    "    intermediate_size=INTERMEDIATE_SIZE,\n",
+    "    adapter_layer_position=ADAPTER_LAYER_POSITION,\n",
+    "    head_layer_position=HEAD_LAYER_POSITION,\n",
+    ")\n",
+    "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
+    "\n",
+    "lr_scheduler = get_scheduler(\n",
+    "    name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a40468b9",
+   "metadata": {},
+   "source": [
+    "And start training our new adapted model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ed051a5d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "wandb.init(\n",
+    "    project=\"bloom_based_cls-sst-2\",\n",
+    "    config={\n",
+    "        \"num_epochs\": NUM_EPOCHS,\n",
+    "        \"batch_size\": BATCH_SIZE,\n",
+    "        \"learning_rate\": LR,\n",
+    "        \"weight_decay\": WEIGHT_DECAY,\n",
+    "        \"model_name\": MODEL_NAME,\n",
+    "        \"seed\": SEED,\n",
+    "        \"intermediate_size\": INTERMEDIATE_SIZE,\n",
+    "        \"adapter_layer_position\": ADAPTER_LAYER_POSITION,\n",
+    "        \"head_layer_position\": HEAD_LAYER_POSITION,\n",
+    "    }\n",
+    ")\n",
+    "\n",
+    "for epoch in range(NUM_EPOCHS):\n",
+    "    for batch in tqdm(train_dataloader):\n",
+    "        batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
+    "\n",
+    "        cls_model.train()\n",
+    "        with torch.no_grad():\n",
+    "            embeddings_output = model.transformers.word_embeddings(batch[\"input_ids\"])\n",
+    "        outputs = cls_model(embeddings_output)\n",
+    "        loss.backward()\n",
+    "\n",
+    "        cls_optimizer.step()\n",
+    "        lr_scheduler.step()\n",
+    "        cls_optimizer.zero_grad()\n",
+    "\n",
+    "        wandb.log({\"Train Loss\": loss})\n",
+    "\n",
+    "    accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
+    "    wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
+   ]
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.8.12 ('bloom-demo')",
+   "display_name": "Python 3.8.9 64-bit",
    "language": "python",
    "name": "python3"
   },
@@ -294,11 +464,11 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.12"
+   "version": "3.8.9"
   },
   "vscode": {
    "interpreter": {
-    "hash": "175c31e15dd38a7dfc9eb4117a9e428ffb6063af97d545b6bfba4d874ecc4bb8"
+    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
    }
   }
  },