{ "cells": [ { "cell_type": "markdown", "id": "d3014ad2", "metadata": {}, "source": [ "# Our Bloom for Sequence Classification using Prompt Tuning\n", "\n", "In this example, we showcase how the Bloom model can be efficiently adapted in a decentralized fashion. In particular, servers maintain the Bloom transformer, which is kept unchanged during adaptation, and learn only a few prefix tokens and a classification head. " ] }, { "cell_type": "markdown", "id": "243a8971", "metadata": {}, "source": [ "### Import all dependencies and prepare the environment" ] }, { "cell_type": "code", "execution_count": 1, "id": "7e975fd7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "env: CUDA_VISIBLE_DEVICES=6\n", "env: OMP_NUM_THREADS=1\n" ] } ], "source": [ "%env CUDA_VISIBLE_DEVICES=6\n", "%env OMP_NUM_THREADS=1\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import sys\n", "sys.path.append('..')\n", "\n", "# General \n", "import torch\n", "import pandas as pd\n", "from tqdm.auto import tqdm\n", "from torch.utils.data import DataLoader\n", "\n", "# Distributed\n", "from src.bloom.model import BloomForSequenceClassification\n", "\n", "# HF imports\n", "import transformers\n", "import datasets\n", "from datasets import load_dataset, load_metric\n", "\n", "# Visualization dependencies\n", "from IPython.display import clear_output\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "plt.style.use('seaborn-whitegrid')\n", "plt.rcParams['pdf.fonttype'] = 42\n", "plt.rcParams['ps.fonttype'] = 42\n", "\n", "def print_params(model):\n", " for n, p in model.named_parameters():\n", " print(n, p.requires_grad, p.device)\n", " if p.requires_grad:\n", " print(p)" ] }, { "cell_type": "markdown", "id": "fa04b0b4", "metadata": {}, "source": [ "### Config" ] }, { "cell_type": "code", "execution_count": 2, "id": "291f6021", "metadata": {}, "outputs": [], "source": [ "MODEL_NAME='bigscience/bloom-6b3'\n", "\n", "PROMPT_TUNING_TYPE='deep'\n", 
"NUM_PREFIX_TOKENS=16\n", "NUM_LABELS=2\n", "DEVICE='cuda'\n", "BATCH_SIZE=32" ] }, { "cell_type": "markdown", "id": "006c934d", "metadata": {}, "source": [ "### Prepare the distributed Bloom model" ] }, { "cell_type": "code", "execution_count": 3, "id": "bff5d710", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BloomForSequenceClassification were not initialized from the model checkpoint at bigscience/bloom-6b3 and are newly initialized: ['intermediate_prompt_embeddings.weight', 'score.weight', 'prompt_embeddings.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "transformer.word_embeddings.weight False cuda:0\n", "transformer.word_embeddings_layernorm.weight False cuda:0\n", "transformer.word_embeddings_layernorm.bias False cuda:0\n", "transformer.h.0.input_layernorm.weight False cuda:0\n", "transformer.h.0.input_layernorm.bias False cuda:0\n", "transformer.h.0.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.0.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.0.self_attention.dense.weight False cuda:0\n", "transformer.h.0.self_attention.dense.bias False cuda:0\n", "transformer.h.0.post_attention_layernorm.weight False cuda:0\n", "transformer.h.0.post_attention_layernorm.bias False cuda:0\n", "transformer.h.0.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.0.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.0.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.0.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.1.input_layernorm.weight False cuda:0\n", "transformer.h.1.input_layernorm.bias False cuda:0\n", "transformer.h.1.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.1.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.1.self_attention.dense.weight False cuda:0\n", 
"transformer.h.1.self_attention.dense.bias False cuda:0\n", "transformer.h.1.post_attention_layernorm.weight False cuda:0\n", "transformer.h.1.post_attention_layernorm.bias False cuda:0\n", "transformer.h.1.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.1.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.1.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.1.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.2.input_layernorm.weight False cuda:0\n", "transformer.h.2.input_layernorm.bias False cuda:0\n", "transformer.h.2.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.2.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.2.self_attention.dense.weight False cuda:0\n", "transformer.h.2.self_attention.dense.bias False cuda:0\n", "transformer.h.2.post_attention_layernorm.weight False cuda:0\n", "transformer.h.2.post_attention_layernorm.bias False cuda:0\n", "transformer.h.2.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.2.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.2.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.2.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.3.input_layernorm.weight False cuda:0\n", "transformer.h.3.input_layernorm.bias False cuda:0\n", "transformer.h.3.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.3.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.3.self_attention.dense.weight False cuda:0\n", "transformer.h.3.self_attention.dense.bias False cuda:0\n", "transformer.h.3.post_attention_layernorm.weight False cuda:0\n", "transformer.h.3.post_attention_layernorm.bias False cuda:0\n", "transformer.h.3.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.3.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.3.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.3.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.4.input_layernorm.weight False cuda:0\n", "transformer.h.4.input_layernorm.bias False 
cuda:0\n", "transformer.h.4.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.4.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.4.self_attention.dense.weight False cuda:0\n", "transformer.h.4.self_attention.dense.bias False cuda:0\n", "transformer.h.4.post_attention_layernorm.weight False cuda:0\n", "transformer.h.4.post_attention_layernorm.bias False cuda:0\n", "transformer.h.4.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.4.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.4.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.4.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.5.input_layernorm.weight False cuda:0\n", "transformer.h.5.input_layernorm.bias False cuda:0\n", "transformer.h.5.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.5.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.5.self_attention.dense.weight False cuda:0\n", "transformer.h.5.self_attention.dense.bias False cuda:0\n", "transformer.h.5.post_attention_layernorm.weight False cuda:0\n", "transformer.h.5.post_attention_layernorm.bias False cuda:0\n", "transformer.h.5.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.5.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.5.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.5.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.6.input_layernorm.weight False cuda:0\n", "transformer.h.6.input_layernorm.bias False cuda:0\n", "transformer.h.6.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.6.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.6.self_attention.dense.weight False cuda:0\n", "transformer.h.6.self_attention.dense.bias False cuda:0\n", "transformer.h.6.post_attention_layernorm.weight False cuda:0\n", "transformer.h.6.post_attention_layernorm.bias False cuda:0\n", "transformer.h.6.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.6.mlp.dense_h_to_4h.bias False cuda:0\n", 
"transformer.h.6.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.6.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.7.input_layernorm.weight False cuda:0\n", "transformer.h.7.input_layernorm.bias False cuda:0\n", "transformer.h.7.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.7.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.7.self_attention.dense.weight False cuda:0\n", "transformer.h.7.self_attention.dense.bias False cuda:0\n", "transformer.h.7.post_attention_layernorm.weight False cuda:0\n", "transformer.h.7.post_attention_layernorm.bias False cuda:0\n", "transformer.h.7.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.7.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.7.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.7.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.8.input_layernorm.weight False cuda:0\n", "transformer.h.8.input_layernorm.bias False cuda:0\n", "transformer.h.8.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.8.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.8.self_attention.dense.weight False cuda:0\n", "transformer.h.8.self_attention.dense.bias False cuda:0\n", "transformer.h.8.post_attention_layernorm.weight False cuda:0\n", "transformer.h.8.post_attention_layernorm.bias False cuda:0\n", "transformer.h.8.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.8.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.8.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.8.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.9.input_layernorm.weight False cuda:0\n", "transformer.h.9.input_layernorm.bias False cuda:0\n", "transformer.h.9.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.9.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.9.self_attention.dense.weight False cuda:0\n", "transformer.h.9.self_attention.dense.bias False cuda:0\n", 
"transformer.h.9.post_attention_layernorm.weight False cuda:0\n", "transformer.h.9.post_attention_layernorm.bias False cuda:0\n", "transformer.h.9.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.9.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.9.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.9.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.10.input_layernorm.weight False cuda:0\n", "transformer.h.10.input_layernorm.bias False cuda:0\n", "transformer.h.10.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.10.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.10.self_attention.dense.weight False cuda:0\n", "transformer.h.10.self_attention.dense.bias False cuda:0\n", "transformer.h.10.post_attention_layernorm.weight False cuda:0\n", "transformer.h.10.post_attention_layernorm.bias False cuda:0\n", "transformer.h.10.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.10.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.10.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.10.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.11.input_layernorm.weight False cuda:0\n", "transformer.h.11.input_layernorm.bias False cuda:0\n", "transformer.h.11.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.11.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.11.self_attention.dense.weight False cuda:0\n", "transformer.h.11.self_attention.dense.bias False cuda:0\n", "transformer.h.11.post_attention_layernorm.weight False cuda:0\n", "transformer.h.11.post_attention_layernorm.bias False cuda:0\n", "transformer.h.11.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.11.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.11.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.11.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.12.input_layernorm.weight False cuda:0\n", "transformer.h.12.input_layernorm.bias False cuda:0\n", 
"transformer.h.12.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.12.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.12.self_attention.dense.weight False cuda:0\n", "transformer.h.12.self_attention.dense.bias False cuda:0\n", "transformer.h.12.post_attention_layernorm.weight False cuda:0\n", "transformer.h.12.post_attention_layernorm.bias False cuda:0\n", "transformer.h.12.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.12.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.12.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.12.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.13.input_layernorm.weight False cuda:0\n", "transformer.h.13.input_layernorm.bias False cuda:0\n", "transformer.h.13.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.13.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.13.self_attention.dense.weight False cuda:0\n", "transformer.h.13.self_attention.dense.bias False cuda:0\n", "transformer.h.13.post_attention_layernorm.weight False cuda:0\n", "transformer.h.13.post_attention_layernorm.bias False cuda:0\n", "transformer.h.13.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.13.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.13.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.13.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.14.input_layernorm.weight False cuda:0\n", "transformer.h.14.input_layernorm.bias False cuda:0\n", "transformer.h.14.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.14.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.14.self_attention.dense.weight False cuda:0\n", "transformer.h.14.self_attention.dense.bias False cuda:0\n", "transformer.h.14.post_attention_layernorm.weight False cuda:0\n", "transformer.h.14.post_attention_layernorm.bias False cuda:0\n", "transformer.h.14.mlp.dense_h_to_4h.weight False cuda:0\n", 
"transformer.h.14.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.14.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.14.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.15.input_layernorm.weight False cuda:0\n", "transformer.h.15.input_layernorm.bias False cuda:0\n", "transformer.h.15.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.15.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.15.self_attention.dense.weight False cuda:0\n", "transformer.h.15.self_attention.dense.bias False cuda:0\n", "transformer.h.15.post_attention_layernorm.weight False cuda:0\n", "transformer.h.15.post_attention_layernorm.bias False cuda:0\n", "transformer.h.15.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.15.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.15.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.15.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.16.input_layernorm.weight False cuda:0\n", "transformer.h.16.input_layernorm.bias False cuda:0\n", "transformer.h.16.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.16.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.16.self_attention.dense.weight False cuda:0\n", "transformer.h.16.self_attention.dense.bias False cuda:0\n", "transformer.h.16.post_attention_layernorm.weight False cuda:0\n", "transformer.h.16.post_attention_layernorm.bias False cuda:0\n", "transformer.h.16.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.16.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.16.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.16.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.17.input_layernorm.weight False cuda:0\n", "transformer.h.17.input_layernorm.bias False cuda:0\n", "transformer.h.17.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.17.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.17.self_attention.dense.weight False cuda:0\n", 
"transformer.h.17.self_attention.dense.bias False cuda:0\n", "transformer.h.17.post_attention_layernorm.weight False cuda:0\n", "transformer.h.17.post_attention_layernorm.bias False cuda:0\n", "transformer.h.17.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.17.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.17.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.17.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.18.input_layernorm.weight False cuda:0\n", "transformer.h.18.input_layernorm.bias False cuda:0\n", "transformer.h.18.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.18.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.18.self_attention.dense.weight False cuda:0\n", "transformer.h.18.self_attention.dense.bias False cuda:0\n", "transformer.h.18.post_attention_layernorm.weight False cuda:0\n", "transformer.h.18.post_attention_layernorm.bias False cuda:0\n", "transformer.h.18.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.18.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.18.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.18.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.19.input_layernorm.weight False cuda:0\n", "transformer.h.19.input_layernorm.bias False cuda:0\n", "transformer.h.19.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.19.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.19.self_attention.dense.weight False cuda:0\n", "transformer.h.19.self_attention.dense.bias False cuda:0\n", "transformer.h.19.post_attention_layernorm.weight False cuda:0\n", "transformer.h.19.post_attention_layernorm.bias False cuda:0\n", "transformer.h.19.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.19.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.19.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.19.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.20.input_layernorm.weight False cuda:0\n", 
"transformer.h.20.input_layernorm.bias False cuda:0\n", "transformer.h.20.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.20.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.20.self_attention.dense.weight False cuda:0\n", "transformer.h.20.self_attention.dense.bias False cuda:0\n", "transformer.h.20.post_attention_layernorm.weight False cuda:0\n", "transformer.h.20.post_attention_layernorm.bias False cuda:0\n", "transformer.h.20.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.20.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.20.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.20.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.21.input_layernorm.weight False cuda:0\n", "transformer.h.21.input_layernorm.bias False cuda:0\n", "transformer.h.21.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.21.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.21.self_attention.dense.weight False cuda:0\n", "transformer.h.21.self_attention.dense.bias False cuda:0\n", "transformer.h.21.post_attention_layernorm.weight False cuda:0\n", "transformer.h.21.post_attention_layernorm.bias False cuda:0\n", "transformer.h.21.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.21.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.21.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.21.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.22.input_layernorm.weight False cuda:0\n", "transformer.h.22.input_layernorm.bias False cuda:0\n", "transformer.h.22.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.22.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.22.self_attention.dense.weight False cuda:0\n", "transformer.h.22.self_attention.dense.bias False cuda:0\n", "transformer.h.22.post_attention_layernorm.weight False cuda:0\n", "transformer.h.22.post_attention_layernorm.bias False cuda:0\n", "transformer.h.22.mlp.dense_h_to_4h.weight 
False cuda:0\n", "transformer.h.22.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.22.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.22.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.23.input_layernorm.weight False cuda:0\n", "transformer.h.23.input_layernorm.bias False cuda:0\n", "transformer.h.23.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.23.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.23.self_attention.dense.weight False cuda:0\n", "transformer.h.23.self_attention.dense.bias False cuda:0\n", "transformer.h.23.post_attention_layernorm.weight False cuda:0\n", "transformer.h.23.post_attention_layernorm.bias False cuda:0\n", "transformer.h.23.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.23.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.23.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.23.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.24.input_layernorm.weight False cuda:0\n", "transformer.h.24.input_layernorm.bias False cuda:0\n", "transformer.h.24.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.24.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.24.self_attention.dense.weight False cuda:0\n", "transformer.h.24.self_attention.dense.bias False cuda:0\n", "transformer.h.24.post_attention_layernorm.weight False cuda:0\n", "transformer.h.24.post_attention_layernorm.bias False cuda:0\n", "transformer.h.24.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.24.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.24.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.24.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.25.input_layernorm.weight False cuda:0\n", "transformer.h.25.input_layernorm.bias False cuda:0\n", "transformer.h.25.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.25.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.25.self_attention.dense.weight 
False cuda:0\n", "transformer.h.25.self_attention.dense.bias False cuda:0\n", "transformer.h.25.post_attention_layernorm.weight False cuda:0\n", "transformer.h.25.post_attention_layernorm.bias False cuda:0\n", "transformer.h.25.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.25.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.25.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.25.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.26.input_layernorm.weight False cuda:0\n", "transformer.h.26.input_layernorm.bias False cuda:0\n", "transformer.h.26.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.26.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.26.self_attention.dense.weight False cuda:0\n", "transformer.h.26.self_attention.dense.bias False cuda:0\n", "transformer.h.26.post_attention_layernorm.weight False cuda:0\n", "transformer.h.26.post_attention_layernorm.bias False cuda:0\n", "transformer.h.26.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.26.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.26.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.26.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.27.input_layernorm.weight False cuda:0\n", "transformer.h.27.input_layernorm.bias False cuda:0\n", "transformer.h.27.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.27.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.27.self_attention.dense.weight False cuda:0\n", "transformer.h.27.self_attention.dense.bias False cuda:0\n", "transformer.h.27.post_attention_layernorm.weight False cuda:0\n", "transformer.h.27.post_attention_layernorm.bias False cuda:0\n", "transformer.h.27.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.27.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.27.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.27.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.28.input_layernorm.weight False 
cuda:0\n", "transformer.h.28.input_layernorm.bias False cuda:0\n", "transformer.h.28.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.28.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.28.self_attention.dense.weight False cuda:0\n", "transformer.h.28.self_attention.dense.bias False cuda:0\n", "transformer.h.28.post_attention_layernorm.weight False cuda:0\n", "transformer.h.28.post_attention_layernorm.bias False cuda:0\n", "transformer.h.28.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.28.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.28.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.28.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.h.29.input_layernorm.weight False cuda:0\n", "transformer.h.29.input_layernorm.bias False cuda:0\n", "transformer.h.29.self_attention.query_key_value.weight False cuda:0\n", "transformer.h.29.self_attention.query_key_value.bias False cuda:0\n", "transformer.h.29.self_attention.dense.weight False cuda:0\n", "transformer.h.29.self_attention.dense.bias False cuda:0\n", "transformer.h.29.post_attention_layernorm.weight False cuda:0\n", "transformer.h.29.post_attention_layernorm.bias False cuda:0\n", "transformer.h.29.mlp.dense_h_to_4h.weight False cuda:0\n", "transformer.h.29.mlp.dense_h_to_4h.bias False cuda:0\n", "transformer.h.29.mlp.dense_4h_to_h.weight False cuda:0\n", "transformer.h.29.mlp.dense_4h_to_h.bias False cuda:0\n", "transformer.ln_f.weight False cuda:0\n", "transformer.ln_f.bias False cuda:0\n", "transformer.prompt_embeddings.weight True cuda:0\n", "Parameter containing:\n", "tensor([[ 0.0289, 0.0230, 0.0049, ..., -0.0024, -0.0144, -0.0053],\n", " [-0.0165, 0.0022, 0.0458, ..., -0.0156, 0.0053, -0.0038],\n", " [ 0.0039, -0.0245, -0.0135, ..., -0.0011, 0.0008, 0.0165],\n", " ...,\n", " [-0.0042, 0.0283, 0.0045, ..., -0.0233, 0.0101, 0.0013],\n", " [-0.0258, -0.0271, 0.0120, ..., 0.0169, 0.0161, -0.0055],\n", " [ 0.0060, -0.0394, 0.0309, ..., 0.0285, 
-0.0300, 0.0055]],\n", " device='cuda:0', requires_grad=True)\n", "transformer.intermediate_prompt_embeddings.weight True cuda:0\n", "Parameter containing:\n", "tensor([[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0', requires_grad=True)\n", "score.weight True cuda:0\n", "Parameter containing:\n", "tensor([[ 0.0007, -0.0262, 0.0132, ..., 0.0213, -0.0057, -0.0152],\n", " [ 0.0161, 0.0150, -0.0190, ..., -0.0135, -0.0326, 0.0023]],\n", " device='cuda:0', requires_grad=True)\n" ] } ], "source": [ "model = BloomForSequenceClassification.from_pretrained(\n", " MODEL_NAME, \n", "# num_prefix_tokens=NUM_PREFIX_TOKENS, \n", " num_labels=NUM_LABELS\n", ").to(DEVICE)\n", " \n", "for name, p in model.named_parameters():\n", " if 'score' in name or 'prompt' in name:\n", " p.requires_grad = True\n", " else:\n", " p.requires_grad = False\n", "\n", "model.transformer.intermediate_prompt_embeddings.weight.data.zero_()\n", "print_params(model)" ] }, { "cell_type": "code", "execution_count": 4, "id": "95b03a47", "metadata": {}, "outputs": [], "source": [ "# params = dict(our_model.named_parameters())\n", "# for name, p in model.named_parameters():\n", "# print(f'Check: {name}')\n", "# assert torch.allclose(p, params[name]), name" ] }, { "cell_type": "markdown", "id": "901b50fd", "metadata": {}, "source": [ "## Dataset\n", "\n", "This example operates on the SST-2 dataset for binary sentence classification." 
] }, { "cell_type": "code", "execution_count": 5, "id": "864fd3dc", "metadata": {}, "outputs": [], "source": [ "import random\n", "from IPython.display import display, HTML\n", "\n", "def show_random_elements(dataset, num_examples=10):\n", " assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n", " picks = []\n", " for _ in range(num_examples):\n", " pick = random.randint(0, len(dataset)-1)\n", " while pick in picks:\n", " pick = random.randint(0, len(dataset)-1)\n", " picks.append(pick)\n", " \n", " df = pd.DataFrame(dataset[picks])\n", " for column, typ in dataset.features.items():\n", " if isinstance(typ, datasets.ClassLabel):\n", " df[column] = df[column].transform(lambda i: typ.names[i])\n", " display(HTML(df.to_html()))" ] }, { "cell_type": "code", "execution_count": 6, "id": "18ca24c8", "metadata": {}, "outputs": [], "source": [ "GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]\n", "task = 'sst2'" ] }, { "cell_type": "code", "execution_count": 7, "id": "5d110a5e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Jul 26 05:43:35.435 [WARN] [datasets.builder.download_and_prepare:641] Reusing dataset glue (/home/dbaranchuk/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4cadddcffa164eeaa7307bb1dcfba416", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00 of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. 
This warning is only showed once. Subsequent hashing failures won't be showed.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e6d294129ce343f9bba3768a844c03de", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/68 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Val accuracies: [0.9495412844036697, 0.9403669724770642, 0.9403669724770642, 0.944954128440367, 0.9495412844036697, 0.944954128440367]\n" ] } ], "source": [ "progress_bar = tqdm(range(num_training_steps))\n", "\n", "loss_history = []\n", "accuracy_history = []\n", "\n", "for epoch in range(num_epochs):\n", " for batch in train_dataloader:\n", " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", " \n", " model.train()\n", " outputs = model(**batch)\n", " loss = outputs.loss\n", " loss.backward()\n", " loss_history.append(loss.item())\n", "\n", " optimizer.step()\n", " lr_scheduler.step()\n", " optimizer.zero_grad()\n", " progress_bar.update(1)\n", " \n", " print(f'Iteration: {progress_bar.n}')\n", " plt.figure(figsize=[20, 8])\n", " plt.subplot(1,2,1)\n", " plt.title('Train loss over time', fontsize=12); plt.grid();\n", " plt.plot(moving_average(loss_history, span=10))\n", " plt.scatter(range(len(loss_history)), loss_history, alpha=0.1)\n", "\n", " accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n", " accuracy_history.append(accuracy['accuracy'])\n", "\n", " plt.subplot(1,2,2)\n", " plt.title('Val accuracy', fontsize=12); plt.grid();\n", " plt.plot(accuracy_history)\n", " plt.show()\n", " print('Val accuracies: ', accuracy_history)" ] }, { "cell_type": "code", "execution_count": 15, "id": "197b3f4b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration: 10525\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Val accuracies: [0.9472477064220184, 0.9438073394495413, 0.9438073394495413, 0.9495412844036697, 0.9552752293577982]\n" ] } ], "source": [ "print(f'Iteration: {progress_bar.n}')\n", "plt.figure(figsize=[20, 8])\n", "plt.subplot(1,2,1)\n", "plt.title('Train loss over time', fontsize=12); plt.grid();\n", "plt.plot(moving_average(loss_history, span=10))\n", "plt.scatter(range(len(loss_history)), loss_history, alpha=0.1)\n", "\n", "plt.subplot(1,2,2)\n", "plt.title('Val accuracy', fontsize=12); plt.grid();\n", "plt.plot(accuracy_history)\n", "plt.show()\n", "print('Val accuracies: ', accuracy_history)" ] }, { "cell_type": "code", "execution_count": 16, "id": "06daf317", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.10097131777508184" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum(loss_history[-1000:]) / 1000" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }