소스 검색

Fix dtype error in fine-tuning notebooks (#231)

Artem Chumachenko 2 년 전
부모
커밋
d4c687daca
2개의 변경된 파일8개의 추가작업 그리고 5개의 파일을 삭제
  1. 7 4
      examples/prompt-tuning-sst2.ipynb
  2. 1 1
      src/petals/client/remote_model.py

+ 7 - 4
examples/prompt-tuning-sst2.ipynb

@@ -308,6 +308,7 @@
     "    self.distributed_layers = model.transformer.h\n",
     "\n",
     "    self.hidden_size = model.config.hidden_size\n",
+    "    self.dtype = model.config.torch_dtype\n",
     "    self.intermediate_size = intermediate_size\n",
     "    self.num_classes = num_classes\n",
     "    self.adapter_layer_position = adapter_layer_position\n",
@@ -316,11 +317,11 @@
     "    self.adapter = nn.Sequential(\n",
     "        nn.Linear(self.hidden_size, self.intermediate_size),\n",
     "        nn.Linear(self.intermediate_size, self.hidden_size),\n",
-    "    )\n",
+    "    ).to(self.dtype)\n",
     "    self.head = nn.Sequential(\n",
     "        nn.LayerNorm(self.hidden_size),\n",
     "        nn.Linear(self.hidden_size, self.num_classes),\n",
-    "    )\n",
+    "    ).to(self.dtype)\n",
     "  \n",
     "  def forward(self, embeddings):\n",
     "    before_layers = self.distributed_layers[0:self.adapter_layer_position]\n",
@@ -388,9 +389,10 @@
     "    head_layer_position=HEAD_LAYER_POSITION,\n",
     ")\n",
     "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
+    "cls_criterion = nn.CrossEntoryCriterion()\n",
     "\n",
     "lr_scheduler = get_scheduler(\n",
-    "    name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
+    "    name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
     ")"
    ]
   },
@@ -432,6 +434,7 @@
     "        with torch.no_grad():\n",
     "            embeddings_output = model.transformers.word_embeddings(batch[\"input_ids\"])\n",
     "        outputs = cls_model(embeddings_output)\n",
+    "        loss = cls_criterion(outputs, batch[\"labels\"])\n",
     "        loss.backward()\n",
     "\n",
     "        cls_optimizer.step()\n",
@@ -461,7 +464,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.9 (default, Apr 13 2022, 08:48:07) \n[Clang 13.1.6 (clang-1316.0.21.2.5)]"
+   "version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]"
   },
   "vscode": {
    "interpreter": {

+ 1 - 1
src/petals/client/remote_model.py

@@ -265,7 +265,7 @@ class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequ
         self.num_labels = config.num_labels
 
         self.transformer = DistributedBloomModel(config)
-        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)
 
         # Initialize weights and apply final processing
         self.post_init()