Explorar o código

Fix examples/sst, add cls_model embeddings (#248)

justheuristic %!s(int64=2) %!d(string=hai) anos
pai
achega
b8a6788490
Modificáronse 1 ficheiros con 4 adicións e 4 borrados
  1. 4 4
      examples/prompt-tuning-sst2.ipynb

+ 4 - 4
examples/prompt-tuning-sst2.ipynb

@@ -288,7 +288,6 @@
    ]
   },
   {
-   "attachments": {},
    "cell_type": "markdown",
    "id": "1bbf014f",
    "metadata": {},
@@ -324,6 +323,7 @@
     "    self.adapter_layer_position = adapter_layer_position\n",
     "    self.head_layer_position = head_layer_position\n",
     "    \n",
+    "    self.word_embeddings = model.transformer.word_embeddings\n",
     "    self.adapter = nn.Sequential(\n",
     "        nn.Linear(self.hidden_size, self.intermediate_size),\n",
     "        nn.Linear(self.intermediate_size, self.hidden_size),\n",
@@ -440,7 +440,7 @@
     "\n",
     "        cls_model.train()\n",
     "        with torch.no_grad():\n",
-    "            embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n",
+    "            embeddings_output = cls_model.word_embeddings(batch[\"input_ids\"])\n",
     "        outputs = cls_model(embeddings_output)\n",
     "        loss = cls_criterion(outputs, batch[\"labels\"])\n",
     "        loss.backward()\n",
@@ -458,7 +458,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.8.9 64-bit",
+   "display_name": "Python 3",
    "language": "python",
    "name": "python3"
   },
@@ -472,7 +472,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]"
+   "version": "3.8.8"
   },
   "vscode": {
    "interpreter": {