@@ -36,8 +36,8 @@
"metadata": {},
"outputs": [],
"source": [
- "!pip install git+https://github.com/bigscience-workshop/petals\n",
- "!pip install datasets wandb"
+ "!pip install -q git+https://github.com/bigscience-workshop/petals\n",
+ "!pip install -q datasets wandb"
]
},
{
@@ -52,6 +52,10 @@
"import torch\n",
"import transformers\n",
"import wandb\n",
+ "\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
"from datasets import load_dataset, load_metric\n",
"from tqdm import tqdm\n",
"from torch.optim import AdamW\n",
@@ -276,11 +280,177 @@
"source": [
|
|
|
"Our model have been trained!"
|
|
|
]
|
|
|
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1bbf014f",
+ "metadata": {},
+ "source": [
+ "## Beyond soft-prompt tuning\n",
+ "\n",
+ "Let's now tune the model using adapters inserted in the middle of the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3bea4391",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class BloomBasedClassifier(nn.Module):\n",
+ "  def __init__(\n",
+ "    self,\n",
+ "    model,\n",
+ "    intermediate_size: int = 32,\n",
+ "    num_classes: int = 2,\n",
+ "    adapter_layer_position: int = 6,\n",
+ "    head_layer_position: int = 10\n",
+ "  ):\n",
+ "    super().__init__()\n",
+ "    self.distributed_layers = model.transformer.h\n",
+ "\n",
+ "    self.hidden_size = model.config.hidden_size\n",
+ "    self.intermediate_size = intermediate_size\n",
+ "    self.num_classes = num_classes\n",
+ "    self.adapter_layer_position = adapter_layer_position\n",
+ "    self.head_layer_position = head_layer_position\n",
+ "\n",
+ "    # bottleneck adapter: project down to intermediate_size, then back up\n",
+ "    self.adapter = nn.Sequential(\n",
+ "      nn.Linear(self.hidden_size, self.intermediate_size),\n",
+ "      nn.Linear(self.intermediate_size, self.hidden_size),\n",
+ "    )\n",
+ "    self.head = nn.Sequential(\n",
+ "      nn.LayerNorm(self.hidden_size),\n",
+ "      nn.Linear(self.hidden_size, self.num_classes),\n",
+ "    )\n",
+ "\n",
+ "  def forward(self, embeddings):\n",
+ "    # split the remote BLOOM blocks around the adapter insertion point\n",
+ "    before_layers = self.distributed_layers[0:self.adapter_layer_position]\n",
+ "    after_layers = self.distributed_layers[self.adapter_layer_position:self.head_layer_position]\n",
+ "\n",
+ "    hidden_states = before_layers(embeddings)\n",
+ "    hidden_states = self.adapter(hidden_states)\n",
+ "    hidden_states = after_layers(hidden_states)\n",
+ "    # mean-pool over the sequence dimension before the classification head\n",
+ "    pooled_states = torch.mean(hidden_states, dim=1)\n",
+ "    return self.head(pooled_states)"
+ ]
+ },
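+ {
+ "cell_type": "markdown",
+ "id": "7f3c2a1e",
+ "metadata": {},
+ "source": [
+ "Note that the adapter is just a bottleneck of two linear layers, so it preserves the hidden-state shape and can be dropped between any two transformer blocks. A tiny local check with hypothetical sizes (this cell is an illustration, not part of the original recipe):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2b9d4c01",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hypothetical sizes, for illustration only: hidden_size=64, bottleneck=8\n",
+ "dummy_adapter = nn.Sequential(nn.Linear(64, 8), nn.Linear(8, 64))\n",
+ "dummy_states = torch.randn(2, 16, 64)  # (batch, seq_len, hidden_size)\n",
+ "assert dummy_adapter(dummy_states).shape == dummy_states.shape"
+ ]
+ },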
+ {
+ "cell_type": "markdown",
+ "id": "15299620",
+ "metadata": {},
+ "source": [
+ "Clear the previous model and free device memory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "aa27b168",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "del model, optimizer, lr_scheduler\n",
+ "torch.cuda.empty_cache()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5406390f",
+ "metadata": {},
+ "source": [
+ "Create a new model with adapters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a251db80",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "INTERMEDIATE_SIZE = 32\n",
+ "ADAPTER_LAYER_POSITION = 6\n",
+ "HEAD_LAYER_POSITION = 10"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3578df3a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)\n",
+ "\n",
+ "cls_model = BloomBasedClassifier(\n",
+ "    model,\n",
+ "    intermediate_size=INTERMEDIATE_SIZE,\n",
+ "    adapter_layer_position=ADAPTER_LAYER_POSITION,\n",
+ "    head_layer_position=HEAD_LAYER_POSITION,\n",
+ ").to(DEVICE)\n",
+ "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
+ "\n",
+ "lr_scheduler = get_scheduler(\n",
+ "    name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
+ ")"
+ ]
+ },
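+ {
+ "cell_type": "markdown",
+ "id": "5d8e1f22",
+ "metadata": {},
+ "source": [
+ "As a quick sanity check (this cell is an addition to the recipe above), let's count the parameters we actually train locally: just the adapter and the head, while the BLOOM blocks themselves run on remote servers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6a7b3c44",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Count only the adapter and head parameters -- the modules we actually train\n",
+ "n_trainable = sum(\n",
+ "    p.numel() for module in (cls_model.adapter, cls_model.head) for p in module.parameters()\n",
+ ")\n",
+ "print(f\"Locally trainable parameters: {n_trainable}\")"
+ ]
+ },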
+ {
+ "cell_type": "markdown",
+ "id": "a40468b9",
+ "metadata": {},
+ "source": [
+ "And start training our new adapter-based model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ed051a5d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.init(\n",
+ "    project=\"bloom_based_cls-sst-2\",\n",
+ "    config={\n",
+ "        \"num_epochs\": NUM_EPOCHS,\n",
+ "        \"batch_size\": BATCH_SIZE,\n",
+ "        \"learning_rate\": LR,\n",
+ "        \"weight_decay\": WEIGHT_DECAY,\n",
+ "        \"model_name\": MODEL_NAME,\n",
+ "        \"seed\": SEED,\n",
+ "        \"intermediate_size\": INTERMEDIATE_SIZE,\n",
+ "        \"adapter_layer_position\": ADAPTER_LAYER_POSITION,\n",
+ "        \"head_layer_position\": HEAD_LAYER_POSITION,\n",
+ "    }\n",
+ ")\n",
+ "\n",
+ "for epoch in range(NUM_EPOCHS):\n",
+ "    for batch in tqdm(train_dataloader):\n",
+ "        batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
+ "\n",
+ "        cls_model.train()\n",
+ "        # the embedding layer stays frozen, so no gradients are needed here\n",
+ "        with torch.no_grad():\n",
+ "            embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n",
+ "        outputs = cls_model(embeddings_output)\n",
+ "        loss = F.cross_entropy(outputs, batch[\"labels\"])\n",
+ "        loss.backward()\n",
+ "\n",
+ "        cls_optimizer.step()\n",
+ "        lr_scheduler.step()\n",
+ "        cls_optimizer.zero_grad()\n",
+ "\n",
+ "        wandb.log({\"Train Loss\": loss.item()})\n",
+ "\n",
+ "    accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
+ "    wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
+ ]
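+ },
+ {
+ "cell_type": "markdown",
+ "id": "0e9f8d21",
+ "metadata": {},
+ "source": [
+ "To try the trained classifier on a single example, we can reuse the frozen word embeddings together with the adapter and head. This is a minimal sketch added for illustration, assuming the `tokenizer` created earlier in this notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4c1a2b33",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal inference sketch (illustrative): embed one sentence, run the adapter model\n",
+ "cls_model.eval()\n",
+ "with torch.no_grad():\n",
+ "    input_ids = tokenizer(\"A deeply moving film.\", return_tensors=\"pt\")[\"input_ids\"].to(DEVICE)\n",
+ "    embeddings = model.transformer.word_embeddings(input_ids)\n",
+ "    logits = cls_model(embeddings)\n",
+ "print(\"Predicted class:\", logits.argmax(dim=-1).item())"
+ ]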
}
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3.8.12 ('bloom-demo')",
+ "display_name": "Python 3.8.9 64-bit",
"language": "python",
"name": "python3"
},
@@ -294,11 +464,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.12"
+ "version": "3.8.9"
},
"vscode": {
"interpreter": {
- "hash": "175c31e15dd38a7dfc9eb4117a9e428ffb6063af97d545b6bfba4d874ecc4bb8"
+ "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},