
Minor changes to examples/prompt-tuning notebooks (#247)

Minor code changes required to run the notebooks in a clean Python environment
justheuristic 2 years ago
Parent
Commit
8766a14d28
2 changed files with 8 additions and 10 deletions
  1. examples/prompt-tuning-personachat.ipynb (+2 −2)
  2. examples/prompt-tuning-sst2.ipynb (+6 −8)

+ 2 - 2
examples/prompt-tuning-personachat.ipynb

@@ -36,7 +36,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%pip install -q petals datasets wandb"
+    "%pip install -q petals datasets wandb scikit-learn"
    ]
   },
   {
@@ -285,7 +285,7 @@
     "        user_phrase = input()\n",
     "        if len(user_phrase) == 0:\n",
     "            break\n",
-    "        inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids']\n",
+    "        inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids'].to(DEVICE)\n",
     "        while True:\n",
     "            outputs = model.generate(\n",
     "                inputs,\n",

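The first hunk above adds scikit-learn to the install cell, presumably because a metrics helper in these notebooks imports sklearn and a clean environment does not pull it in transitively. The second hunk fixes a device mismatch: the tokenizer returns CPU tensors, so generation fails whenever the model runs on a GPU. A minimal sketch of that pattern, with DEVICE, tokenizer, and the function name standing in for the notebook's own objects:

import torch

# Hypothetical stand-in for the notebook's DEVICE global.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def encode_user_phrase(tokenizer, user_phrase: str) -> torch.Tensor:
    """Tokenize one chat turn and move it to the model's device.

    Without .to(DEVICE), input_ids stay on the CPU and generate()
    raises a device-mismatch error when the model lives on a GPU.
    """
    inputs = tokenizer([f"{user_phrase}\n-----\n"], return_tensors="pt")["input_ids"]
    return inputs.to(DEVICE)
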
+ 6 - 8
examples/prompt-tuning-sst2.ipynb

@@ -36,7 +36,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%pip install -q petals datasets wandb"
+    "%pip install -q petals datasets wandb scikit-learn"
    ]
   },
   {
@@ -390,16 +390,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model = DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)\n",
-    "\n",
     "cls_model = BloomBasedClassifier(\n",
-    "    model,\n",
+    "    DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME),\n",
     "    intermediate_size=INTERMEDIATE_SIZE,\n",
     "    adapter_layer_position=ADAPTER_LAYER_POSITION,\n",
     "    head_layer_position=HEAD_LAYER_POSITION,\n",
-    ")\n",
+    ").to(DEVICE)\n",
     "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
-    "cls_criterion = nn.CrossEntropyCriterion()\n",
+    "cls_criterion = nn.CrossEntropyLoss()\n",
     "\n",
     "lr_scheduler = get_scheduler(\n",
     "    name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
@@ -442,7 +440,7 @@
     "\n",
     "        cls_model.train()\n",
     "        with torch.no_grad():\n",
-    "            embeddings_output = model.transformers.word_embeddings(batch[\"input_ids\"])\n",
+    "            embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n",
     "        outputs = cls_model(embeddings_output)\n",
     "        loss = cls_criterion(outputs, batch[\"labels\"])\n",
     "        loss.backward()\n",
@@ -453,7 +451,7 @@
     "\n",
     "        wandb.log({\"Train Loss\": loss})\n",
     "\n",
-    "    accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
+    "    accuracy = eval_metrics(cls_model, valid_dataloader, device=DEVICE)\n",
     "    wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
    ]
   }
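
The SST-2 hunks fix three bugs: nn.CrossEntropyCriterion does not exist in torch.nn (the loss class is nn.CrossEntropyLoss); the BLOOM backbone exposes its embedding table as model.transformer.word_embeddings (singular "transformer", not "transformers"); and validation should score cls_model, the classifier actually being trained. Folding the backbone into BloomBasedClassifier and calling .to(DEVICE) on the composed module also ensures the adapter and head land on the same device as the inputs. A hedged sketch of the corrected training step, with model, cls_model, optimizer, and batch as placeholders for the notebook's objects:

import torch
import torch.nn as nn

# The loss class that actually exists in torch.nn.
criterion = nn.CrossEntropyLoss()

def train_step(model, cls_model, optimizer, batch):
    cls_model.train()
    with torch.no_grad():
        # BLOOM models expose the embedding table as
        # model.transformer.word_embeddings (singular "transformer").
        embeddings = model.transformer.word_embeddings(batch["input_ids"])
    logits = cls_model(embeddings)  # adapter layers + classification head
    loss = criterion(logits, batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()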